From 75b1f22a1c1edffed2e075bcd98cc2a57a2741cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Mu=C5=82a?= Date: Thu, 28 Apr 2022 11:57:13 +0200 Subject: [PATCH] zstd: Allow to ignore checksum checking (#572) Fixes #571 --- zstd/decoder.go | 6 ++--- zstd/decoder_options.go | 9 ++++++++ zstd/decoder_test.go | 51 +++++++++++++++++++++++++++++++++++++++++ zstd/framedec.go | 50 +++++++++++++++++++++++++++++----------- 4 files changed, 99 insertions(+), 17 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index c65ea9795f..b04e36f2ff 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -439,7 +439,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { println("got", len(d.current.b), "bytes, error:", d.current.err, "data crc:", tmp) } - if len(next.b) > 0 { + if !d.o.ignoreChecksum && len(next.b) > 0 { n, err := d.current.crc.Write(next.b) if err == nil { if n != len(next.b) { @@ -451,7 +451,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { got := d.current.crc.Sum64() var tmp [4]byte binary.LittleEndian.PutUint32(tmp[:], uint32(got)) - if !bytes.Equal(tmp[:], next.d.checkCRC) && !ignoreCRC { + if !d.o.ignoreChecksum && !bytes.Equal(tmp[:], next.d.checkCRC) && !ignoreCRC { if debugDecoder { println("CRC Check Failed:", tmp[:], " (got) !=", next.d.checkCRC, "(on stream)") } @@ -534,7 +534,7 @@ func (d *Decoder) nextBlockSync() (ok bool) { } // Update/Check CRC - if d.frame.HasCheckSum { + if !d.o.ignoreChecksum && d.frame.HasCheckSum { d.frame.crc.Write(d.current.b) if d.current.d.Last { d.current.err = d.frame.checkCRC() diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index fc52ebc403..c70e6fa0f7 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -19,6 +19,7 @@ type decoderOptions struct { maxDecodedSize uint64 maxWindowSize uint64 dicts []dict + ignoreChecksum bool } func (o *decoderOptions) setDefault() { @@ -112,3 +113,11 @@ func WithDecoderMaxWindow(size uint64) DOption { return nil } } + +// 
IgnoreChecksum allows checksum checking to be forcibly ignored. +func IgnoreChecksum(b bool) DOption { + return func(o *decoderOptions) error { + o.ignoreChecksum = b + return nil + } +} diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 46c169b874..cba5dd80f7 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1737,6 +1737,57 @@ func TestResetNil(t *testing.T) { } } +func TestIgnoreChecksum(t *testing.T) { + // zstd file containing text "compress\n" and has a xxhash checksum + zstdBlob := []byte{0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x09, 0x49, 0x00, 0x00, 'C', 'o', 'm', 'p', 'r', 'e', 's', 's', '\n', 0x79, 0x6e, 0xe0, 0xd2} + + // replace letter 'c' with 'C', so decoding should fail. + zstdBlob[9] = 'C' + + { + // Check if the file is indeed incorrect + dec, err := NewReader(nil) + if err != nil { + t.Fatal(err) + } + defer dec.Close() + + dec.Reset(bytes.NewBuffer(zstdBlob)) + + _, err = ioutil.ReadAll(dec) + if err == nil { + t.Fatal("Expected decoding error") + } + + if !errors.Is(err, ErrCRCMismatch) { + t.Fatalf("Expected checksum error, got '%s'", err) + } + } + + { + // Ignore CRC error and decompress the content + dec, err := NewReader(nil, IgnoreChecksum(true)) + if err != nil { + t.Fatal(err) + } + defer dec.Close() + + dec.Reset(bytes.NewBuffer(zstdBlob)) + + res, err := ioutil.ReadAll(dec) + if err != nil { + t.Fatalf("Unexpected error: '%s'", err) + } + + want := []byte{'C', 'o', 'm', 'p', 'r', 'e', 's', 's', '\n'} + if !bytes.Equal(res, want) { + t.Logf("want: %s", want) + t.Logf("got: %s", res) + t.Fatalf("Wrong output") + } + } +} + func timeout(after time.Duration) (cancel func()) { if isRaceTest { return func() {} diff --git a/zstd/framedec.go b/zstd/framedec.go index 509d5cecea..4b15b2acc9 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -290,13 +290,6 @@ func (d *frameDec) checkCRC() error { if !d.HasCheckSum { return nil } - var tmp [4]byte - got := d.crc.Sum64() - // Flip to match file order. 
- tmp[0] = byte(got >> 0) - tmp[1] = byte(got >> 8) - tmp[2] = byte(got >> 16) - tmp[3] = byte(got >> 24) // We can overwrite upper tmp now want, err := d.rawInput.readSmall(4) @@ -305,6 +298,18 @@ func (d *frameDec) checkCRC() error { return err } + if d.o.ignoreChecksum { + return nil + } + + var tmp [4]byte + got := d.crc.Sum64() + // Flip to match file order. + tmp[0] = byte(got >> 0) + tmp[1] = byte(got >> 8) + tmp[2] = byte(got >> 16) + tmp[3] = byte(got >> 24) + if !bytes.Equal(tmp[:], want) && !ignoreCRC { if debugDecoder { println("CRC Check Failed:", tmp[:], "!=", want) @@ -317,6 +322,19 @@ func (d *frameDec) checkCRC() error { return nil } +// consumeCRC reads the checksum data if the frame has one. +func (d *frameDec) consumeCRC() error { + if d.HasCheckSum { + _, err := d.rawInput.readSmall(4) + if err != nil { + println("CRC missing?", err) + return err + } + } + + return nil +} + // runDecoder will create a sync decoder that will decode a block of data. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { saved := d.history.b @@ -373,13 +391,17 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { if d.FrameContentSize != fcsUnknown && uint64(len(d.history.b)-crcStart) != d.FrameContentSize { err = ErrFrameSizeMismatch } else if d.HasCheckSum { - var n int - n, err = d.crc.Write(dst[crcStart:]) - if err == nil { - if n != len(dst)-crcStart { - err = io.ErrShortWrite - } else { - err = d.checkCRC() + if d.o.ignoreChecksum { + err = d.consumeCRC() + } else { + var n int + n, err = d.crc.Write(dst[crcStart:]) + if err == nil { + if n != len(dst)-crcStart { + err = io.ErrShortWrite + } else { + err = d.checkCRC() + } } } }