From 2f750260fed707879e28e4c5995dafacdcf356c1 Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Thu, 28 Mar 2024 18:37:03 +0100 Subject: [PATCH] refactor: avoid convulted code --- src/picklescan/scanner.py | 54 +++++++++++++++----------- tests/data/malicious-invalid-bytes.pkl | 14 +++++++ tests/test_scanner.py | 49 +++++++++++++++++++++-- 3 files changed, 90 insertions(+), 27 deletions(-) create mode 100644 tests/data/malicious-invalid-bytes.pkl diff --git a/src/picklescan/scanner.py b/src/picklescan/scanner.py index 349f60f..61c8af7 100644 --- a/src/picklescan/scanner.py +++ b/src/picklescan/scanner.py @@ -50,8 +50,9 @@ def merge(self, sr: "ScanResult"): class GenOpsError(Exception): - def __init__(self, msg: str): + def __init__(self, msg: str, globals: Optional[Set[Tuple[str, str]]]): self.msg = msg + self.globals = globals super().__init__() def __str__(self) -> str: @@ -177,16 +178,10 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str] try: ops = list(pickletools.genops(data)) except Exception as e: - # XXX: pickle will happily load files that contain arbitrarily placed new lines whereas pickletools errors in such cases. - # below is code to circumvent or skip these newlines while succeeding at parsing the opcodes. - err = str(e) - if "opcode b'\\n' unknown" not in err: - raise GenOpsError(err) - else: - pos = int(err.split(",")[0].replace("at position ", "")) - data.seek(-(pos + 1), 1) - ops = list(pickletools.genops(data.read(pos))) - data.seek(1, 1) + # XXX: given we can have multiple pickles in a file, we may have already successfully extracted globals from a valid pickle. + # Thus we return the already found globals in the error & to let the caller decide what to do. + globals_opt = globals if len(globals) > 0 else None + raise GenOpsError(str(e), globals_opt) last_byte = data.read(1) data.seek(-1, 1) @@ -241,18 +236,12 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str] return globals -def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanResult: - """Disassemble a Pickle stream and report issues""" - +def _build_scan_result_from_raw_globals( + raw_globals: Set[Tuple[str, str]], + file_id, + scan_err=False, +) -> ScanResult: globals = [] - try: - raw_globals = _list_globals(data, multiple_pickles) - except GenOpsError as e: - _log.error(f"ERROR: parsing pickle in {file_id}: {e}") - return ScanResult(globals, scan_err=True) - - _log.debug("Global imports in %s: %s", file_id, raw_globals) - issues_count = 0 for rg in raw_globals: g = Global(rg[0], rg[1], SafetyLevel.Dangerous) @@ -278,7 +267,26 @@ def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanRe g.safety = SafetyLevel.Suspicious globals.append(g) - return ScanResult(globals, 1, issues_count, 1 if issues_count > 0 else 0, False) + return ScanResult(globals, 1, issues_count, 1 if issues_count > 0 else 0, scan_err) + + +def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanResult: + """Disassemble a Pickle stream and report issues""" + + try: + raw_globals = _list_globals(data, multiple_pickles) + except GenOpsError as e: + _log.error(f"ERROR: parsing pickle in {file_id}: {e}") + if e.globals is not None: + return _build_scan_result_from_raw_globals( + e.globals, file_id, scan_err=True + ) + else: + return ScanResult([], scan_err=True) + + _log.debug("Global imports in %s: %s", file_id, raw_globals) + + return _build_scan_result_from_raw_globals(raw_globals, file_id) def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult: diff --git a/tests/data/malicious-invalid-bytes.pkl b/tests/data/malicious-invalid-bytes.pkl new file mode 100644 index 0000000..841cca8 --- /dev/null +++ b/tests/data/malicious-invalid-bytes.pkl @@ -0,0 +1,14 @@ +Vos +p2 +0Vsystem +p3 +0Vtorch +p0 +0VLongStorage +p1 +0g2 +g3 +“(Vcat flag.txt +tR. + + \ No newline at end of file diff --git a/tests/test_scanner.py b/tests/test_scanner.py index 5193143..19847d4 100644 --- a/tests/test_scanner.py +++ b/tests/test_scanner.py @@ -243,6 +243,35 @@ def initialize_pickle_files(): ), ) + initialize_data_file( + f"{_root_path}/data/malicious-invalid-bytes.pkl", + b"".join( + [ + pickle.UNICODE + b"os\n", + pickle.PUT + b"2\n", + pickle.POP, + pickle.UNICODE + b"system\n", + pickle.PUT + b"3\n", + pickle.POP, + pickle.UNICODE + b"torch\n", + pickle.PUT + b"0\n", + pickle.POP, + pickle.UNICODE + b"LongStorage\n", + pickle.PUT + b"1\n", + pickle.POP, + pickle.GET + b"2\n", + pickle.GET + b"3\n", + pickle.STACK_GLOBAL, + pickle.MARK, + pickle.UNICODE + b"cat flag.txt\n", + pickle.TUPLE, + pickle.REDUCE, + pickle.STOP, + b"\n\n\t\t", + ] + ), + ) + # Code which created malicious12.pkl using pickleassem (see https://github.com/gousaiyang/pickleassem) # # p = PickleAssembler(proto=4) @@ -351,7 +380,6 @@ def test_scan_pickle_bytes(): def test_scan_zip_bytes(): - buffer = io.BytesIO() with zipfile.ZipFile(buffer, "w") as zip: zip.writestr("data.pkl", pickle.dumps(Malicious1())) @@ -559,15 +587,17 @@ def test_scan_directory_path(): Global("torch", "_utils", SafetyLevel.Suspicious), Global("__builtin__", "exec", SafetyLevel.Dangerous), Global("os", "system", SafetyLevel.Dangerous), + Global("os", "system", SafetyLevel.Dangerous), Global("operator", "attrgetter", SafetyLevel.Dangerous), Global("builtins", "__import__", SafetyLevel.Suspicious), Global("pickle", "loads", SafetyLevel.Dangerous), Global("_pickle", "loads", SafetyLevel.Dangerous), Global("_codecs", "encode", SafetyLevel.Suspicious), ], - scanned_files=26, - issues_count=24, - infected_files=21, + scanned_files=27, + issues_count=25, + infected_files=22, + scan_err=True, ) compare_scan_results(scan_directory_path(f"{_root_path}/data/"), sr) @@ -610,3 +640,14 @@ def test_pickle_files(): assert pickle.load(file) == 12345 with open(f"{_root_path}/data/malicious13b.pkl", "rb") as file: assert pickle.load(file) == 12345 + + +def test_invalid_bytes_err(): + malicious_invalid_bytes = ScanResult( + [Global("os", "system", SafetyLevel.Dangerous)], 1, 1, 1, True + ) + with open(f"{_root_path}/data/malicious-invalid-bytes.pkl", "rb") as file: + compare_scan_results( + scan_pickle_bytes(file, f"{_root_path}/data/malicious-invalid-bytes.pkl"), + malicious_invalid_bytes, + )