Skip to content

Commit

Permalink
Refactor zipfile: Adds a register_compressor() API.
Browse files Browse the repository at this point in the history
Uses this API for all of our built-in stdlib-based compressors.

This allows additional compression methods to be officially supplied by
third-party libraries without monkeypatching the zipfile module.

It is designed to obsolete the gross hacks that things like
https://pypi.org/project/zipfile-zstd/ have to do.
  • Loading branch information
gpshead committed Jan 6, 2024
1 parent 4c4b08d commit 41d28bd
Show file tree
Hide file tree
Showing 4 changed files with 667 additions and 136 deletions.
27 changes: 27 additions & 0 deletions Lib/bz2.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,30 @@ def decompress(data):
"end-of-stream marker was reached")
data = decomp.unused_data
return b"".join(results)


# For use by zipfile.register_compressor().
class _ZipBZ2CompressorProxy:
    """Factory that hands back a plain ``BZ2Compressor``.

    Calling this class never produces an instance of it: ``__new__``
    returns a ``BZ2Compressor`` object directly.
    """

    def __new__(cls, compresslevel=None):
        """Return a ``BZ2Compressor``.

        ``compresslevel`` is forwarded positionally when it is not None;
        otherwise ``BZ2Compressor`` is called with no arguments.
        """
        level_args = () if compresslevel is None else (compresslevel,)
        return BZ2Compressor(*level_args)


class _ZipBZ2Decompressor:
    """Wrap a ``BZ2Decompressor`` and add a ``flush()`` method.

    ``decompress`` is the wrapped object's own bound method; ``eof`` and
    ``needs_input`` mirror the wrapped object's attributes.
    """

    def __init__(self):
        inner = BZ2Decompressor()
        self._decomp = inner
        # Expose the bound method itself instead of a forwarding wrapper.
        self.decompress = inner.decompress

    @property
    def needs_input(self):
        """The wrapped decompressor's ``needs_input`` attribute."""
        return self._decomp.needs_input

    @property
    def eof(self):
        """The wrapped decompressor's ``eof`` attribute."""
        return self._decomp.eof

    def flush(self):
        """Return output that is available without feeding new input.

        ``decompress(b'')`` is called only when the wrapped object reports
        neither ``eof`` nor ``needs_input``; otherwise ``b''`` is returned.
        """
        inner = self._decomp
        if inner.eof or inner.needs_input:
            return b''
        return self.decompress(b'')
65 changes: 65 additions & 0 deletions Lib/lzma.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import builtins
import io
import os
import struct
from _lzma import *
from _lzma import _encode_filter_properties, _decode_filter_properties
import _compression
Expand Down Expand Up @@ -354,3 +355,67 @@ def decompress(data, format=FORMAT_AUTO, memlimit=None, filters=None):
if not data:
break
return b"".join(results)



# For use by zipfile.register_compressor(), originally part of zipfile.
class _ZipLZMACompressor:
    """Raw LZMA1 compressor that prefixes its output with a small header.

    The underlying ``LZMACompressor`` is created lazily by the first call
    to ``compress()`` or ``flush()``; that call's return value starts with
    the header built by ``_init()``.
    """

    def __init__(self, compresslevel=None):
        del compresslevel  # unused
        self._comp = None

    def _init(self):
        """Create the raw compressor and return the header bytes.

        The header is ``struct.pack('<BBH', 9, 4, len(props))`` followed by
        the encoded LZMA1 filter properties.
        """
        # NOTE(review): 9 and 4 look like a version field of the zip LZMA
        # header -- confirm against the zip format specification.
        props = _encode_filter_properties({'id': FILTER_LZMA1})
        filter_spec = _decode_filter_properties(FILTER_LZMA1, props)
        self._comp = LZMACompressor(FORMAT_RAW, filters=[filter_spec])
        return struct.pack('<BBH', 9, 4, len(props)) + props

    def compress(self, data):
        """Compress ``data``; the first call's output starts with the header."""
        header = self._init() if self._comp is None else b''
        return header + self._comp.compress(data)

    def flush(self):
        """Finish the stream, emitting the header first if nothing was written."""
        header = self._init() if self._comp is None else b''
        return header + self._comp.flush()


class _ZipLZMADecompressor:
    """Decompressor for raw LZMA1 data preceded by a small header.

    The header is four bytes (the last two being a little-endian size of
    the filter properties) followed by the properties themselves. Input is
    buffered until more than the whole header has arrived; only then is the
    underlying ``LZMADecompressor`` created.
    """

    def __init__(self):
        self._unconsumed = b''
        self._decomp = None

    @property
    def needs_input(self):
        """True until the header is parsed, then the inner ``needs_input``."""
        if self._decomp:
            return self._decomp.needs_input
        return True

    @property
    def eof(self):
        """False until the header is parsed, then the inner ``eof``."""
        if self._decomp:
            return self._decomp.eof
        return False

    def _take_header(self, data):
        """Buffer ``data`` and try to parse the header.

        Returns the bytes that follow the header once the decompressor has
        been created, or None while more input is still required.
        """
        pending = self._unconsumed + data
        self._unconsumed = pending
        if len(pending) <= 4:
            return None
        (psize,) = struct.unpack('<H', pending[2:4])
        body_start = 4 + psize
        if len(pending) <= body_start:
            return None

        self._decomp = LZMADecompressor(FORMAT_RAW, filters=[
            _decode_filter_properties(FILTER_LZMA1, pending[4:body_start])
        ])
        # The staging buffer is no longer needed once the header is parsed.
        del self._unconsumed
        return pending[body_start:]

    def decompress(self, /, data, max_length=-1):
        """Decompress ``data``, returning ``b''`` while the header is incomplete."""
        if self._decomp is None:
            data = self._take_header(data)
            if data is None:
                return b''
        return self._decomp.decompress(data, max_length)

    def flush(self):
        """Return output that is available without feeding new input.

        ``b''`` is returned when no decompressor exists yet, or when the
        inner decompressor reports ``eof`` or ``needs_input``.
        """
        inner = self._decomp
        if not inner or inner.eof or inner.needs_input:
            return b''
        return self.decompress(b'')
202 changes: 193 additions & 9 deletions Lib/test/test_zipfile/test_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import array
import codecs
import contextlib
import importlib
import importlib.util
import io
import itertools
Expand Down Expand Up @@ -2262,19 +2264,201 @@ def test_read_after_seek(self):
fp.seek(1, os.SEEK_CUR)
self.assertEqual(fp.read(-1), b'men!')

@requires_bz2()
def test_decompress_without_3rd_party_library(self):
data = b'PK\x05\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
def tearDown(self):
    # Remove both scratch files used by the tests.
    for scratch_path in (TESTFN, TESTFN2):
        unlink(scratch_path)


class TestDemandLoadedCompressors(unittest.TestCase):
    # Checks that extracting a member raises UnknownCompressionError when
    # the module backing its compression method is unavailable.

    def setUp(self):
        importlib.reload(zipfile)
        # Confirm that the reload() reset these to their initial pre-use state.
        # They replace themselves with an actual implementation on first use
        # when the relevant compression module is present.
        registry = zipfile._compressors_by_appnote_4_4_5_method_id
        for method in (zipfile.ZIP_DEFLATED, zipfile.ZIP_BZIP2,
                       zipfile.ZIP_LZMA):
            self.assertIn("OnDemand",
                          registry[method].decompressor.__name__)

    def _check_extract_fails_without(self, module_name, data):
        # Shared body of the bz2 and lzma tests: replace the module's
        # sys.modules entry with None while extracting 'a.txt' from the
        # in-memory archive `data`, and always restore the saved entry.
        zip_file = io.BytesIO(data)
        saved_module = sys.modules[module_name]
        try:
            sys.modules[module_name] = None
            with zipfile.ZipFile(zip_file) as zf:
                with self.assertRaises(zipfile.UnknownCompressionError):
                    zf.extract('a.txt')
        finally:
            sys.modules[module_name] = saved_module

    @requires_bz2()  # To make coding the save and restore easier.
    def test_decompress_without_3rd_party_bz2_library(self):
        # A zip containing a single bzip2 compressed a.txt file.
        data = b'PK\x03\x04.\x00\x00\x00\x0c\x00\x8a\x8b$XC\xbe\xb7\xe8%\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00a.txtBZh91AY&SY\x19\x93\x9bk\x00\x00\x00\x01\x00 \x00 \x00!\x18F\x82\xeeH\xa7\n\x12\x032sm`PK\x01\x02.\x03.\x00\x00\x00\x0c\x00\x8a\x8b$XC\xbe\xb7\xe8%\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00a.txtPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x00H\x00\x00\x00\x00\x00'
        self._check_extract_fails_without('bz2', data)

    @requires_lzma()  # To make coding the save and restore easier.
    def test_decompress_without_3rd_party_lzma_library(self):
        # A zip containing a single lzma compressed a.txt file.
        data = b'PK\x03\x04?\x00\x02\x00\x0e\x00a\x8e$XC\xbe\xb7\xe8\x14\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00a.txt\t\x04\x05\x00]\x00\x00\x80\x00\x000\xc1\xfb\xff\xff\xff\xe0\x00\x00\x00PK\x01\x02?\x03?\x00\x02\x00\x0e\x00a\x8e$XC\xbe\xb7\xe8\x14\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00a.txtPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00\x00\x00'
        self._check_extract_fails_without('lzma', data)

    def test_decompress_without_3rd_party_zlib_library(self):
        # A zip containing a single zlib (deflate) compressed a.txt file.
        data = b'PK\x03\x04\x14\x00\x00\x00\x08\x00g\x8f$XC\xbe\xb7\xe8\x03\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00a.txtK\x04\x00PK\x01\x02\x14\x03\x14\x00\x00\x00\x08\x00g\x8f$XC\xbe\xb7\xe8\x03\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00a.txtPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x00&\x00\x00\x00\x00\x00'
        zip_file = io.BytesIO(data)
        # The deflate fixture above is read as-is: it must not be rewritten
        # as a bzip2 archive, and only zlib is disabled here.
        #
        # zlib is currently always imported when available so this test is
        # simpler than others: It can just disable zlib within zipfile.
        with mock.patch("zipfile.zlib", None):
            with zipfile.ZipFile(zip_file) as zf:
                with self.assertRaises(zipfile.UnknownCompressionError):
                    zf.extract('a.txt')
                # UnknownCompressionError must also be catchable via these
                # old types that zipfile used to raise for unsupported formats.
                with self.assertRaises(RuntimeError):
                    zf.extract('a.txt')
                with self.assertRaises(NotImplementedError):
                    zf.extract('a.txt')


class RegisterCompressorTests(unittest.TestCase):
    # Exercises zipfile.register_compressor() and its `override` modes.

    def tearDown(self):
        for scratch_path in (TESTFN, TESTFN2):
            unlink(scratch_path)
        # Reload zipfile after every test, following each registration.
        importlib.reload(zipfile)

    def _register_abstract(self, name, compression_type, override):
        # Register the abstract compressor/decompressor pair under `name`.
        zipfile.register_compressor(
            name,
            compression_type=compression_type,
            compressor=zipfile.AbstractCompressor,
            decompressor=zipfile.AbstractDecompressor,
            override=override,
        )

    def test_override_store_fails(self):
        with self.assertRaises(RuntimeError):
            self._register_abstract('fakename', zipfile.ZIP_STORED, "always")

    def test_override_always(self):
        self._register_abstract('fake-deflate', zipfile.ZIP_DEFLATED, "always")

    def test_override_stdlib(self):
        self._register_abstract('fake-deflate', zipfile.ZIP_DEFLATED, "stdlib")

    def test_override_non_stdlib(self):
        unused_id = 127  # Unused as of this writing.
        self._register_abstract('fake-127', unused_id, "stdlib")
        # attempt overriding the above not in "always" mode.
        for mode in ("stdlib", "never"):
            with self.assertRaises(RuntimeError):
                self._register_abstract('different-127', unused_id, mode)

    def test_override_never(self):
        with self.assertRaises(RuntimeError):
            self._register_abstract(
                'fake-deflate', zipfile.ZIP_DEFLATED, "never")

    def test_custom_rot13_compressor(self):
        rot13_id = 113  # Made up unclaimed value for this test.
        self.assertNotIn(rot13_id, zipfile.compressor_names,
                         msg="update the test to use a different value?")

        # A silly "compression" scheme useful for testing purposes.
        class rot13_compressor:
            def __init__(self, compresslevel):
                self.data = b""
                del compresslevel  # unused

            def compress(self, data):
                text = data.decode("latin-1")
                self.data += codecs.encode(text, "rot13").encode("latin-1")
                # This exercises partial return + flush operation.
                if not self.data:
                    return b""
                two_bytes, self.data = self.data[0:2], self.data[2:]
                return two_bytes

            def flush(self):
                # Hand back whatever is still buffered and empty the buffer.
                leftover, self.data = self.data, b""
                return leftover

        class rot13_decompressor:
            def __init__(self):
                self.eof = False
                self._buffer = b""

            def decompress(self, /, data, max_length=0):
                pending = self._buffer + data
                self._buffer = b""
                if max_length > 0:
                    # Keep anything beyond max_length for a later call.
                    pending, self._buffer = (pending[:max_length],
                                             pending[max_length:])
                text = pending.decode("latin-1")
                return codecs.decode(text, "rot13").encode("latin-1")

            def flush(self):
                return self.decompress(b"") if self._buffer else b""

        zipfile.register_compressor(
            'rot13',
            compression_type=rot13_id,
            compressor=rot13_compressor,
            decompressor=rot13_decompressor,
            override="never",
        )

        plaintext = b"fcnz naq ornaf"
        mem_zip_f = io.BytesIO()
        with zipfile.ZipFile(mem_zip_f, "w", compression=rot13_id) as zipf:
            zipf.writestr("EBG13.txt", plaintext)

        # The stored member bytes are the rot13 of the plaintext.
        self.assertIn(b"spam and beans", mem_zip_f.getvalue())
        mem_zip_f.seek(0)
        with zipfile.ZipFile(mem_zip_f, "r") as zipf:
            first_member = zipf.infolist()[0]
            self.assertEqual(first_member.compress_type, rot13_id)
            self.assertEqual(zipf.read("EBG13.txt"), b"fcnz naq ornaf")
            with zipf.open("EBG13.txt", "r") as f:
                self.assertEqual(f.read(5), b"fcnz ")
                self.assertEqual(f.read(), b"naq ornaf")


class AbstractBadCrcTests:
Expand Down
Loading

0 comments on commit 41d28bd

Please sign in to comment.