Files
opus-apps/compress/apis/deflatelua.lua
2020-05-31 23:51:04 -06:00

530 lines
11 KiB
Lua

--[[
see: https://github.com/davidm/lua-compress-deflatelua/
for licensing / details
--]]
-- Module table; it is returned at the end of the file. The underscore-prefixed
-- fields are descriptive metadata only.
local M = {_TYPE='module', _NAME='compress.deflatelua', _VERSION='0.3.20111128'}
-- Cache the globals and library functions used below in locals.
local assert = assert
local error = error
local ipairs = ipairs
local pairs = pairs
local tostring = tostring
local type = type
local setmetatable = setmetatable
local io = io
local math = math
local table_sort = table.sort
local math_max = math.max
local string_char = string.char
-- Bit operations come from the bit32 library, which must be present as a
-- global in the host environment (loading fails on the next line otherwise).
local band = bit32.band
local lshift = bit32.lshift
local rshift = bit32.rshift
--- Raise an error attributed to a function further up the call stack.
-- @param s      error value (normally a message string)
-- @param level  optional number of frames above this helper to blame;
--               defaults to 1 (the function that called runtime_error)
local function runtime_error(s, level)
  local depth = level or 1
  -- One extra level skips this helper's own frame. This must not be written
  -- as a tail call, or the frame count would be off by one.
  error(s, depth + 1)
end
--- Create the output state shared by the block decoders.
-- @param outbs  function called with each decoded byte
-- @return table with fields `outbs`, `window` (initially empty history of
--         emitted bytes) and `window_pos` (next slot to write, starting at 1)
local function make_outstate(outbs)
  return {
    outbs = outbs,
    window = {},
    window_pos = 1,
  }
end
--- Emit one decoded byte and record it in the sliding window.
-- @param outstate  state table created by make_outstate
-- @param byte      byte value to emit
local function output(outstate, byte)
  local slot = outstate.window_pos
  -- Hand the byte to the sink first; the window is only updated if that
  -- call returns normally.
  outstate.outbs(byte)
  outstate.window[slot] = byte
  -- Slots run 1..32768 and wrap back to 1 (32K window).
  local next_slot = (slot % 32768) + 1
  outstate.window_pos = next_slot
end
--- Pass a value through, failing if the stream ran dry.
-- @param val  result of a stream read; nil/false means the input ended
-- @return val unchanged when it is truthy
-- Raises 'unexpected end of file' otherwise (via assert, so no position
-- prefix is added to the message).
local function noeof(val)
  local eof_message = 'unexpected end of file'
  return assert(val, eof_message)
end
--- Test whether a flag is set in a non-negative integer.
-- @param bits  value to inspect
-- @param bit   flag to test, given as a power of two (e.g. 2^3)
-- @return true when that flag is set in `bits`
local function hasbit(bits, bit)
  -- Reducing modulo twice the flag keeps only the flag's bit and the bits
  -- below it; the flag is set exactly when the remainder reaches `bit`.
  local span = bit + bit
  local low_part = bits % span
  return low_part >= bit
end
--- Build a table that computes missing entries on demand and caches them.
-- @param f  one-argument function producing the value for a key
-- @return table whose lookups call `f(key)` the first time a key is read and
--         return the stored result afterwards
local function memoize(f)
  local cache = {}
  local function compute_and_store(_, key)
    local value = f(key)
    cache[key] = value
    return value
  end
  return setmetatable(cache, {__index = compute_and_store})
end
-- small optimization (lookup table for powers of 2)
-- pow2[n] evaluates 2^n on first access and caches the result in the table.
local pow2 = memoize(function(n) return 2^n end)
--local tbits = memoize(
-- function(bits)
-- return memoize( function(bit) return getbit(bits, bit) end )
-- end )
-- Registry marking objects as bitstreams: a table keyed by the stream object.
-- Keys are held weakly (__mode='k'), so an entry here does not by itself keep
-- a stream alive.
local is_bitstream = setmetatable({}, {__mode='k'})
--- Adapt a file-like handle to the byte stream interface.
-- @param fh  object with a `read` method; it is asked for one character at a time
-- @return table whose `read()` yields the next byte as a number, or nothing
--         once `fh:read(1)` returns a false value
local function bytestream_from_file(fh)
  return {
    read = function()
      local chunk = fh:read(1)
      if not chunk then return end
      return chunk:byte()
    end,
  }
end
--- Adapt a string to the byte stream interface.
-- @param s  string holding the raw input bytes
-- @return table whose `read()` yields successive byte values of `s`, then nil
local function bytestream_from_string(s)
  local consumed = 0 -- number of bytes already handed out
  local reader = {}
  function reader.read()
    if consumed >= #s then
      return nil
    end
    consumed = consumed + 1
    return s:byte(consumed)
  end
  return reader
end
--- Adapt a producer function to the byte stream interface.
-- @param f  function called with no arguments for each read
-- @return table whose `read()` returns whatever `f()` returns
local function bytestream_from_function(f)
  return {
    -- Deliberately a wrapper rather than `read = f`: callers may invoke
    -- `read` with method syntax, and `f` must still be called without arguments.
    read = function()
      return f()
    end,
  }
end
--- Wrap a byte stream in a reader that hands out bits, least significant first.
-- @param bys  byte stream: an object whose `read` returns the next byte as a
--             number, or a false value when the input is exhausted
-- @return bitstream object (registered in `is_bitstream`) with:
--   `nbits_left_in_byte()`  number of buffered bits not yet consumed
--   `read(nbits)`           next `nbits` bits as a number (default 1), or
--                           nothing if the input ends before enough bits
--                           are available
-- NOTE(review): the accumulator is built with bit32.lshift, so a read that
-- needs more than 32 buffered bits (e.g. a 32-bit read while bits from an
-- earlier byte are still pending) would lose high bits. Confirm callers only
-- issue wide reads when the buffer is empty.
local function bitstream_from_bytestream(bys)
  local pending = 0   -- buffered bits; bit 0 is the next bit to hand out
  local npending = 0  -- how many low bits of `pending` are valid
  local stream = {}

  function stream.nbits_left_in_byte()
    return npending
  end

  function stream:read(nbits)
    nbits = nbits or 1
    -- Pull whole bytes until the request can be satisfied. Bytes already
    -- pulled stay buffered if the input ends part-way through.
    while npending < nbits do
      local byte = bys:read()
      if not byte then return end
      pending = pending + lshift(byte, npending)
      npending = npending + 8
    end
    npending = npending - nbits
    if nbits == 0 then
      return 0
    end
    if nbits == 32 then
      -- The mask/shift below is not used for a full-width read; hand over
      -- the whole accumulator instead.
      local whole = pending
      pending = 0
      return whole
    end
    local mask = rshift(0xffffffff, 32 - nbits)
    local low = band(pending, mask)
    pending = rshift(pending, nbits)
    return low
  end

  is_bitstream[stream] = true
  return stream
end
--- Coerce a supported input source into a bitstream.
-- @param o  an existing bitstream (returned as-is), an open file handle,
--           a string of raw bytes, or a function producing bytes
-- @return bitstream reading from `o`
-- Raises 'unrecognized type' for anything else.
local function get_bitstream(o)
  if is_bitstream[o] then
    return o
  end
  local source
  if io.type(o) == 'file' then
    source = bytestream_from_file(o)
  elseif type(o) == 'string' then
    source = bytestream_from_string(o)
  elseif type(o) == 'function' then
    source = bytestream_from_function(o)
  else
    runtime_error 'unrecognized type'
  end
  return bitstream_from_bytestream(source)
end
--- Coerce a supported output sink into a function accepting one byte value.
-- @param o  an open file handle (each byte is written as a single character)
--           or a function (used directly)
-- @return function(byte)
-- Raises 'unrecognized type: <o>' for anything else.
local function get_obytestream(o)
  if io.type(o) == 'file' then
    return function(sbyte)
      o:write(string_char(sbyte))
    end
  end
  if type(o) == 'function' then
    return o
  end
  runtime_error('unrecognized type: ' .. tostring(o))
end
--- Build a Huffman decoder from code lengths.
-- @param init     when `is_full` is true: a map from symbol value to code
--                 length. Otherwise: a flat list
--                 {firstval, nbits, nextval, nbits, ..., lastval} in which
--                 each value/length pair covers symbols up to (but not
--                 including) the following value.
-- @param is_full  selects which of the two `init` layouts is used
-- @return table holding one {val, nbits, code} entry per symbol with non-zero
--         length, plus a `read(bs)` method that decodes the next symbol from
--         bitstream `bs`
local function HuffmanTable(init, is_full)
  local t = {}
  if is_full then
    for symbol, length in pairs(init) do
      if length ~= 0 then
        t[#t+1] = {val=symbol, nbits=length}
      end
    end
  else
    for i = 1, #init - 2, 2 do
      local range_start, length, range_end = init[i], init[i+1], init[i+2]
      if length ~= 0 then
        for symbol = range_start, range_end - 1 do
          t[#t+1] = {val=symbol, nbits=length}
        end
      end
    end
  end

  -- Shorter codes first; equal lengths ordered by symbol value.
  local function shorter_then_smaller(a, b)
    if a.nbits ~= b.nbits then
      return a.nbits < b.nbits
    end
    return a.val < b.val
  end
  table_sort(t, shorter_then_smaller)

  -- Assign consecutive codes in sorted order. Every code carries a leading
  -- 1 marker bit, so codes of different lengths never share a lookup key.
  local next_code = 1
  local current_len = 0
  local minbits = math.huge
  local look = {}
  for idx = 1, #t do
    local entry = t[idx]
    if entry.nbits ~= current_len then
      next_code = next_code * pow2[entry.nbits - current_len]
      current_len = entry.nbits
    end
    entry.code = next_code
    look[next_code] = entry.val
    minbits = math.min(minbits, entry.nbits)
    next_code = next_code + 1
  end

  -- Reverse the low `count` bits of `bits`: the bitstream delivers bits
  -- LSB-first while codes are compared MSB-first.
  local function reverse_bits(bits, count)
    local reversed = 0
    for _ = 1, count do
      reversed = lshift(reversed, 1) + band(bits, 1)
      bits = rshift(bits, 1)
    end
    return reversed
  end

  -- Cached translation of a `minbits`-wide raw read into a marked code.
  local tfirstcode = memoize(function(bits)
    return pow2[minbits] + reverse_bits(bits, minbits)
  end)

  function t:read(bs)
    -- No code is shorter than `minbits`, so fetch that many bits in one go,
    -- then extend one bit at a time until the code matches a symbol.
    local code = tfirstcode[noeof(bs:read(minbits))]
    local symbol = look[code]
    while not symbol do
      local b = noeof(bs:read())
      code = code * 2 + b -- MSB first
      symbol = look[code]
    end
    return symbol
  end

  return t
end
--- Consume a gzip member header from a bitstream.
-- Checks the two magic bytes, then skips the fixed fields and whichever
-- optional fields the flag byte announces. Nothing is returned; on success
-- the stream is positioned at the first byte of compressed data.
-- The compression-method byte is read but not validated, and the optional
-- header CRC16 is skipped without being checked.
-- @param bs  bitstream positioned at the start of the header
-- Raises 'not in gzip format' when the magic bytes do not match and
-- 'invalid header' when the input ends inside the header.
local function parse_gzip_header(bs)
  -- local FLG_FTEXT = 2^0
  local FLG_FHCRC = 2^1
  local FLG_FEXTRA = 2^2
  local FLG_FNAME = 2^3
  local FLG_FCOMMENT = 2^4

  -- Read `nbits` bits, failing if the input ends. Every field is checked
  -- individually: after a failed wide read the bitstream can still satisfy a
  -- later, narrower one from bytes it already buffered, so testing only the
  -- last field is not sufficient.
  local function need(nbits)
    local value = bs:read(nbits)
    if not value then runtime_error('invalid header', 2) end
    return value
  end

  -- Skip a zero-terminated string.
  local function skip_zstring()
    repeat
      local by = need(8)
    until by == 0
  end

  local id1 = bs:read(8)
  local id2 = bs:read(8)
  if id1 ~= 31 or id2 ~= 139 then
    runtime_error 'not in gzip format'
  end
  need(8) -- compression method
  local flg = need(8) -- FLaGs
  need(32) -- Modification TIME
  need(8) -- eXtra FLags
  need(8) -- Operating System

  if hasbit(flg, FLG_FEXTRA) then
    local xlen = need(16)
    for _ = 1, xlen do
      need(8)
    end
  end
  if hasbit(flg, FLG_FNAME) then
    skip_zstring()
  end
  if hasbit(flg, FLG_FCOMMENT) then
    skip_zstring()
  end
  if hasbit(flg, FLG_FHCRC) then
    need(16)
    -- IMPROVE: check CRC. where is an example .gz file that
    -- has this set?
  end
end
--- Parse and validate the two-byte zlib stream header.
-- @param bs  bitstream positioned at the start of the header
-- @return window size in bytes, 2^(CINFO + 8)
-- Raises 'unexpected end of file' when the input ends inside the header, and
-- a descriptive error for a compression method other than 8 (deflate), a
-- CINFO above 7, a failed FCHECK test, or a preset dictionary (FDICT), which
-- is not supported.
local function parse_zlib_header(bs)
  local cm = noeof(bs:read(4)) -- Compression Method
  local cinfo = noeof(bs:read(4)) -- Compression info
  local fcheck = noeof(bs:read(5)) -- FLaGs: FCHECK (check bits for CMF and FLG)
  local fdict = noeof(bs:read(1)) -- FLaGs: FDICT (present dictionary)
  local flevel = noeof(bs:read(2)) -- FLaGs: FLEVEL (compression level)
  local cmf = cinfo * 16 + cm -- CMF (Compression Method and flags)
  local flg = fcheck + fdict * 32 + flevel * 64 -- FLaGs
  if cm ~= 8 then -- not "deflate"
    -- `..` is Lua's concatenation operator; `+` would raise an arithmetic
    -- error instead of producing this message.
    runtime_error("unrecognized zlib compression method: " .. tostring(cm))
  end
  if cinfo > 7 then
    runtime_error("invalid zlib window size: cinfo=" .. tostring(cinfo))
  end
  local window_size = 2^(cinfo + 8)
  -- CMF and FLG, read as one 16-bit big-endian value, must be a multiple of 31.
  if (cmf*256 + flg) % 31 ~= 0 then
    runtime_error("invalid zlib header (bad fcheck sum)")
  end
  if fdict == 1 then
    runtime_error("FIX:TODO - FDICT not currently implemented")
  end
  return window_size
end
--- Read the Huffman tables of a dynamic-Huffman DEFLATE block.
-- @param bs  bitstream positioned just after the block type bits
-- @return littable, disttable  decoders for the literal/length and the
--         distance alphabets
-- Raises 'unexpected end of file' on truncated input and
-- 'invalid bit length repeat' when a repeat code has no previous length to
-- copy or runs past the announced number of code lengths.
local function parse_huffmantables(bs)
  local hlit = noeof(bs:read(5)) -- # of literal/length codes - 257
  local hdist = noeof(bs:read(5)) -- # of distance codes - 1
  local hclen = noeof(bs:read(4)) -- # of code length codes - 4
  local ncodelen_codes = hclen + 4
  local codelen_init = {}
  -- Order in which the code-length-alphabet lengths are stored in the stream.
  local codelen_vals = {
    16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
  for i = 1, ncodelen_codes do
    local nbits = noeof(bs:read(3))
    codelen_init[codelen_vals[i]] = nbits
  end
  local codelentable = HuffmanTable(codelen_init, true)

  local nlit_codes = hlit + 257
  local ndist_codes = hdist + 1
  local ntotal = nlit_codes + ndist_codes

  -- RFC 1951 section 3.2.7: the code lengths of both alphabets form ONE
  -- sequence of HLIT + HDIST + 258 values, and a repeat code may run across
  -- the boundary between them. Decode the whole sequence first and split it
  -- afterwards; decoding the two alphabets separately mis-reads such streams.
  local lengths = {}
  local nbits -- most recent explicit length (nil until one has been seen)
  local pos = 0
  while pos < ntotal do
    local codelen = codelentable:read(bs)
    local nrepeat
    if codelen <= 15 then
      nrepeat = 1
      nbits = codelen
    elseif codelen == 16 then
      -- Copy the previous length 3..6 times; there must be one to copy.
      if nbits == nil then
        runtime_error 'invalid bit length repeat'
      end
      nrepeat = 3 + noeof(bs:read(2))
    elseif codelen == 17 then
      nrepeat = 3 + noeof(bs:read(3))
      nbits = 0
    elseif codelen == 18 then
      nrepeat = 11 + noeof(bs:read(7))
      nbits = 0
    else
      error 'ASSERT'
    end
    if pos + nrepeat > ntotal then
      runtime_error 'invalid bit length repeat'
    end
    for _ = 1, nrepeat do
      lengths[pos] = nbits
      pos = pos + 1
    end
  end

  -- Split the combined sequence into the two zero-based length maps.
  local lit_init, dist_init = {}, {}
  for val = 0, nlit_codes - 1 do
    lit_init[val] = lengths[val]
  end
  for val = 0, ndist_codes - 1 do
    dist_init[val] = lengths[nlit_codes + val]
  end

  local littable = HuffmanTable(lit_init, true)
  local disttable = HuffmanTable(dist_init, true)
  return littable, disttable
end
-- Lookup tables for length/distance decoding, built on first use and then
-- shared by every call.
local tdecode_len_base
local tdecode_len_nextrabits
local tdecode_dist_base
local tdecode_dist_nextrabits

--- Decode one item of a Huffman-compressed block: a literal byte, the
-- end-of-block marker, or a length/distance back-reference.
-- @param bs         bitstream to read from
-- @param outstate   output state (sink plus sliding window)
-- @param littable   literal/length Huffman decoder
-- @param disttable  distance Huffman decoder
-- @return true when the end-of-block marker was read, false otherwise
-- Raises 'unexpected end of file' on truncated input, 'invalid length code'
-- / 'invalid distance code' for symbols outside the defined ranges, and
-- 'invalid distance' when a back-reference points at a window slot that has
-- not been written yet.
local function parse_compressed_item(bs, outstate, littable, disttable)
  local val = littable:read(bs)
  if val < 256 then -- literal
    output(outstate, val)
    return false
  elseif val == 256 then -- end of block
    return true
  end

  -- Length symbols 257..285: base length and number of extra bits.
  if not tdecode_len_base then
    local t = {[257]=3}
    local skip = 1
    for i=258,285,4 do
      for j=i,i+3 do t[j] = t[j-1] + skip end
      if i ~= 258 then skip = skip * 2 end
    end
    t[285] = 258
    tdecode_len_base = t
  end
  if not tdecode_len_nextrabits then
    local t = {}
    for i=257,285 do
      local j = math_max(i - 261, 0)
      t[i] = rshift(j, 2)
    end
    t[285] = 0
    tdecode_len_nextrabits = t
  end
  local len_base = tdecode_len_base[val]
  if not len_base then
    -- e.g. symbols 286/287, which the fixed code can produce but which
    -- have no length assigned.
    runtime_error('invalid length code: ' .. tostring(val))
  end
  local len = len_base + noeof(bs:read(tdecode_len_nextrabits[val]))

  -- Distance symbols 0..29: base distance and number of extra bits.
  if not tdecode_dist_base then
    local t = {[0]=1}
    local skip = 1
    for i=1,29,2 do
      for j=i,i+1 do t[j] = t[j-1] + skip end
      if i ~= 1 then skip = skip * 2 end
    end
    tdecode_dist_base = t
  end
  if not tdecode_dist_nextrabits then
    local t = {}
    for i=0,29 do
      local j = math_max(i - 2, 0)
      t[i] = rshift(j, 1)
    end
    tdecode_dist_nextrabits = t
  end
  local dist_val = disttable:read(bs)
  local dist_base = tdecode_dist_base[dist_val]
  if not dist_base then
    runtime_error('invalid distance code: ' .. tostring(dist_val))
  end
  local dist = dist_base + noeof(bs:read(tdecode_dist_nextrabits[dist_val]))

  -- Copy byte by byte: source and destination may overlap, and window_pos
  -- advances with every byte written.
  for _ = 1, len do
    local pos = (outstate.window_pos - 1 - dist) % 32768 + 1 -- 32K
    output(outstate, assert(outstate.window[pos], 'invalid distance'))
  end
  return false
end
--- Decode one DEFLATE block and send its bytes to the output state.
-- @param bs        bitstream positioned at the block header
-- @param outstate  output state (sink plus sliding window)
-- @param throttle  optional function called after every decoded item of a
--                  Huffman-compressed block; skipped when nil
-- @return true when this was the final block of the stream
-- Raises 'unexpected end of file' on truncated input,
-- 'invalid stored block lengths' when LEN and NLEN of a stored block do not
-- match, and 'unrecognized compression type' for the reserved block type.
local function parse_block(bs, outstate, throttle)
  local BTYPE_NO_COMPRESSION = 0
  local BTYPE_FIXED_HUFFMAN = 1
  local BTYPE_DYNAMIC_HUFFMAN = 2

  local bfinal = noeof(bs:read(1))
  local btype = noeof(bs:read(2))

  if btype == BTYPE_NO_COMPRESSION then
    -- Stored data starts at the next byte boundary.
    bs:read(bs:nbits_left_in_byte())
    local len = noeof(bs:read(16))
    local nlen = noeof(bs:read(16))
    -- NLEN is the one's complement of LEN, so the two 16-bit values sum to 0xFFFF.
    if len + nlen ~= 0xFFFF then
      runtime_error 'invalid stored block lengths'
    end
    for _ = 1, len do
      local by = noeof(bs:read(8))
      output(outstate, by)
    end
  elseif btype == BTYPE_FIXED_HUFFMAN or btype == BTYPE_DYNAMIC_HUFFMAN then
    local littable, disttable
    if btype == BTYPE_DYNAMIC_HUFFMAN then
      littable, disttable = parse_huffmantables(bs)
    else
      -- Fixed code lengths, in HuffmanTable's range-list form.
      littable = HuffmanTable {0,8, 144,9, 256,7, 280,8, 288,nil}
      disttable = HuffmanTable {0,5, 32,nil}
    end
    repeat
      local is_done = parse_compressed_item(
        bs, outstate, littable, disttable)
      -- The callback is optional: callers may pass no throttle at all.
      if throttle then throttle() end
    until is_done
  else
    runtime_error 'unrecognized compression type'
  end
  return bfinal ~= 0
end
--- Decompress a raw DEFLATE stream.
-- @param t  table with fields:
--   input     bitstream, open file handle, string, or byte-producing function
--   output    open file handle, or function called with each decoded byte
--   throttle  optional function invoked repeatedly while decoding; defaults
--             to a no-op
function M.inflate(t)
  local bs = get_bitstream(t.input)
  local outbs = get_obytestream(t.output)
  local outstate = make_outstate(outbs)
  -- The block decoder calls the throttle unconditionally, so supply a no-op
  -- when the caller did not provide one (M.gunzip does the same).
  local throttle = t.throttle or function() end
  repeat
    local is_final = parse_block(bs, outstate, throttle)
  until is_final
end
local inflate = M.inflate
--- Decompress a gzip member: parse its header, then inflate its body.
-- After the compressed data the remaining bits of the current byte are
-- discarded; nothing beyond that point is read by this function.
-- @param t  table with fields:
--   input     bitstream, open file handle, string, or byte-producing function
--   output    open file handle, or function called with each decoded byte
--   throttle  optional function invoked repeatedly while decoding; defaults
--             to a no-op
function M.gunzip(t)
  local source = get_bitstream(t.input)
  local sink = get_obytestream(t.output)
  parse_gzip_header(source)

  local throttle = t.throttle
  if not throttle then
    throttle = function() end
  end
  inflate{input=source, output=sink, throttle=throttle}

  -- Drop whatever is left of the partially consumed byte.
  local leftover = source:nbits_left_in_byte()
  source:read(leftover)
end
return M