From 2f84f3e670f68392f3fec7440200f9fad4d9f5cb Mon Sep 17 00:00:00 2001 From: whyrusleeping Date: Thu, 4 Jan 2024 11:16:39 -0800 Subject: [PATCH] some cleanup for easier reading --- deferred.go | 95 ++++++++++++++ helper_types.go | 133 ++++++++++++++++++++ links.go | 125 +++++++++++++++++++ utils.go | 321 ------------------------------------------------ 4 files changed, 353 insertions(+), 321 deletions(-) create mode 100644 deferred.go create mode 100644 helper_types.go create mode 100644 links.go diff --git a/deferred.go b/deferred.go new file mode 100644 index 0000000..c2bb6ec --- /dev/null +++ b/deferred.go @@ -0,0 +1,95 @@ +package typegen + +import ( + "bytes" + "errors" + "fmt" + "io" +) + +type Deferred struct { + Raw []byte +} + +func (d *Deferred) MarshalCBOR(w io.Writer) error { + if d == nil { + _, err := w.Write(CborNull) + return err + } + if d.Raw == nil { + return errors.New("cannot marshal Deferred with nil value for Raw (will not unmarshal)") + } + _, err := w.Write(d.Raw) + return err +} + +func (d *Deferred) UnmarshalCBOR(br io.Reader) (err error) { + // Reuse any existing buffers. + reusedBuf := d.Raw[:0] + d.Raw = nil + buf := bytes.NewBuffer(reusedBuf) + + // Allocate some scratch space. + scratch := make([]byte, maxHeaderSize) + + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() + + // Algorithm: + // + // 1. We start off expecting to read one element. + // 2. If we see a tag, we expect to read one more element so we increment "remaining". + // 3. If see an array, we expect to read "extra" elements so we add "extra" to "remaining". + // 4. If see a map, we expect to read "2*extra" elements so we add "2*extra" to "remaining". + // 5. While "remaining" is non-zero, read more elements. + + // define this once so we don't keep allocating it. + limitedReader := io.LimitedReader{R: br} + for remaining := uint64(1); remaining > 0; remaining-- { + maj, extra, err := CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + hasReadOnce = true + if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil { + return err + } + + switch maj { + case MajUnsignedInt, MajNegativeInt, MajOther: + // nothing fancy to do + case MajByteString, MajTextString: + if extra > ByteArrayMaxLen { + return maxLengthError + } + // Copy the bytes + limitedReader.N = int64(extra) + buf.Grow(int(extra)) + if n, err := buf.ReadFrom(&limitedReader); err != nil { + return err + } else if n < int64(extra) { + return io.ErrUnexpectedEOF + } + case MajTag: + remaining++ + case MajArray: + if extra > MaxLength { + return maxLengthError + } + remaining += extra + case MajMap: + if extra > MaxLength { + return maxLengthError + } + remaining += extra * 2 + default: + return fmt.Errorf("unhandled deferred cbor type: %d", maj) + } + } + d.Raw = buf.Bytes() + return nil +} diff --git a/helper_types.go b/helper_types.go new file mode 100644 index 0000000..dc6bba8 --- /dev/null +++ b/helper_types.go @@ -0,0 +1,133 @@ +package typegen + +import ( + "fmt" + "io" + "time" +) + +var ( + CborBoolFalse = []byte{0xf4} + CborBoolTrue = []byte{0xf5} + CborNull = []byte{0xf6} +) + +func EncodeBool(b bool) []byte { + if b { + return CborBoolTrue + } + return CborBoolFalse +} + +func WriteBool(w io.Writer, b bool) error { + _, err := w.Write(EncodeBool(b)) + return err +} + +type CborBool bool + +func (cb CborBool) MarshalCBOR(w io.Writer) error { + return WriteBool(w, bool(cb)) +} + +func (cb *CborBool) UnmarshalCBOR(r io.Reader) error { + t, val, err := CborReadHeader(r) + if err != nil { + return err + } + + if t != MajOther { + return fmt.Errorf("booleans should be major type 7") + } + + switch val { + case 20: + *cb = false + case 21: + *cb = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", val) + } + return nil +} + +type CborInt int64 + +func (ci CborInt) MarshalCBOR(w io.Writer) error { + v := int64(ci) + if v >= 0 { + if err := WriteMajorTypeHeader(w, MajUnsignedInt, uint64(v)); err != nil { + return err + } + } else { + if err := WriteMajorTypeHeader(w, MajNegativeInt, uint64(-v)-1); err != nil { + return err + } + } + return nil +} + +func (ci *CborInt) UnmarshalCBOR(r io.Reader) error { + maj, extra, err := CborReadHeader(r) + if err != nil { + return err + } + var extraI int64 + switch maj { + case MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + *ci = CborInt(extraI) + return nil +} + +type CborTime time.Time + +func (ct CborTime) MarshalCBOR(w io.Writer) error { + nsecs := ct.Time().UnixNano() + + cbi := CborInt(nsecs) + + return cbi.MarshalCBOR(w) +} + +func (ct *CborTime) UnmarshalCBOR(r io.Reader) error { + var cbi CborInt + if err := cbi.UnmarshalCBOR(r); err != nil { + return err + } + + t := time.Unix(0, int64(cbi)) + + *ct = (CborTime)(t) + return nil +} + +func (ct CborTime) Time() time.Time { + return (time.Time)(ct) +} + +func (ct CborTime) MarshalJSON() ([]byte, error) { + return ct.Time().MarshalJSON() +} + +func (ct *CborTime) UnmarshalJSON(b []byte) error { + var t time.Time + if err := t.UnmarshalJSON(b); err != nil { + return err + } + *(*time.Time)(ct) = t + return nil +} diff --git a/links.go b/links.go new file mode 100644 index 0000000..c91679b --- /dev/null +++ b/links.go @@ -0,0 +1,125 @@ +package typegen + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "math" + + cid "github.com/ipfs/go-cid" +) + +func ScanForLinks(br io.Reader, cb func(cid.Cid)) (err error) { + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() + + scratch := make([]byte, maxCidLength) + for remaining := uint64(1); remaining > 0; remaining-- { + maj, extra, err := CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + hasReadOnce = true + + switch maj { + case MajUnsignedInt, MajNegativeInt, MajOther: + case MajByteString, MajTextString: + if extra > math.MaxInt32 { + return fmt.Errorf("string in cbor input too long") + } + + err := discard(br, int(extra)) + if err != nil { + return err + } + case MajTag: + if extra == 42 { + maj, extra, err = CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if maj != MajByteString { + return fmt.Errorf("expected cbor type 'byte string' in input") + } + + if extra > maxCidLength { + return fmt.Errorf("string in cbor input too long") + } + + if extra == 0 { + return fmt.Errorf("string in cbor input is empty") + } + + if _, err := io.ReadAtLeast(br, scratch[:extra], int(extra)); err != nil { + return err + } + + c, err := cid.Cast(scratch[1:extra]) + if err != nil { + return err + } + cb(c) + + } else { + remaining++ + } + case MajArray: + remaining += extra + case MajMap: + remaining += (extra * 2) + default: + return fmt.Errorf("unhandled cbor type: %d", maj) + } + } + return nil +} + +// discard is a helper function to discard data from a reader, special-casing +// the most common readers we encounter in this library for a significant +// performance boost. +func discard(br io.Reader, n int) error { + // If we're expecting no bytes, don't even try to read. Otherwise, we may read an EOF. + if n == 0 { + return nil + } + + switch r := br.(type) { + case *bytes.Buffer: + buf := r.Next(n) + if len(buf) == 0 { + return io.EOF + } else if len(buf) < n { + return io.ErrUnexpectedEOF + } + return nil + case *bytes.Reader: + if r.Len() == 0 { + return io.EOF + } else if r.Len() < n { + _, _ = r.Seek(0, io.SeekEnd) + return io.ErrUnexpectedEOF + } + _, err := r.Seek(int64(n), io.SeekCurrent) + return err + case *bufio.Reader: + discarded, err := r.Discard(n) + if discarded != 0 && discarded < n && err == io.EOF { + return io.ErrUnexpectedEOF + } + return err + default: + discarded, err := io.CopyN(ioutil.Discard, br, int64(n)) + if discarded != 0 && discarded < int64(n) && err == io.EOF { + return io.ErrUnexpectedEOF + } + + return err + } +} diff --git a/utils.go b/utils.go index dfaadb7..7c11993 100644 --- a/utils.go +++ b/utils.go @@ -4,13 +4,10 @@ import ( "bufio" "bytes" "encoding/binary" - "errors" "fmt" "io" - "io/ioutil" "math" "sync" - "time" cid "github.com/ipfs/go-cid" ) @@ -20,111 +17,6 @@ const ( maxHeaderSize = 9 ) -// discard is a helper function to discard data from a reader, special-casing -// the most common readers we encounter in this library for a significant -// performance boost. -func discard(br io.Reader, n int) error { - // If we're expecting no bytes, don't even try to read. Otherwise, we may read an EOF. - if n == 0 { - return nil - } - - switch r := br.(type) { - case *bytes.Buffer: - buf := r.Next(n) - if len(buf) == 0 { - return io.EOF - } else if len(buf) < n { - return io.ErrUnexpectedEOF - } - return nil - case *bytes.Reader: - if r.Len() == 0 { - return io.EOF - } else if r.Len() < n { - _, _ = r.Seek(0, io.SeekEnd) - return io.ErrUnexpectedEOF - } - _, err := r.Seek(int64(n), io.SeekCurrent) - return err - case *bufio.Reader: - discarded, err := r.Discard(n) - if discarded != 0 && discarded < n && err == io.EOF { - return io.ErrUnexpectedEOF - } - return err - default: - discarded, err := io.CopyN(ioutil.Discard, br, int64(n)) - if discarded != 0 && discarded < int64(n) && err == io.EOF { - return io.ErrUnexpectedEOF - } - - return err - } -} - -func ScanForLinks(br io.Reader, cb func(cid.Cid)) (err error) { - hasReadOnce := false - defer func() { - if err == io.EOF && hasReadOnce { - err = io.ErrUnexpectedEOF - } - }() - - scratch := make([]byte, maxCidLength) - for remaining := uint64(1); remaining > 0; remaining-- { - maj, extra, err := CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - hasReadOnce = true - - switch maj { - case MajUnsignedInt, MajNegativeInt, MajOther: - case MajByteString, MajTextString: - err := discard(br, int(extra)) - if err != nil { - return err - } - case MajTag: - if extra == 42 { - maj, extra, err = CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - - if maj != MajByteString { - return fmt.Errorf("expected cbor type 'byte string' in input") - } - - if extra > maxCidLength { - return fmt.Errorf("string in cbor input too long") - } - - if _, err := io.ReadAtLeast(br, scratch[:extra], int(extra)); err != nil { - return err - } - - c, err := cid.Cast(scratch[1:extra]) - if err != nil { - return err - } - cb(c) - - } else { - remaining++ - } - case MajArray: - remaining += extra - case MajMap: - remaining += (extra * 2) - default: - return fmt.Errorf("unhandled cbor type: %d", maj) - } - } - return nil -} - const ( MajUnsignedInt = 0 MajNegativeInt = 1 @@ -146,93 +38,6 @@ type CBORMarshaler interface { MarshalCBOR(io.Writer) error } -type Deferred struct { - Raw []byte -} - -func (d *Deferred) MarshalCBOR(w io.Writer) error { - if d == nil { - _, err := w.Write(CborNull) - return err - } - if d.Raw == nil { - return errors.New("cannot marshal Deferred with nil value for Raw (will not unmarshal)") - } - _, err := w.Write(d.Raw) - return err -} - -func (d *Deferred) UnmarshalCBOR(br io.Reader) (err error) { - // Reuse any existing buffers. - reusedBuf := d.Raw[:0] - d.Raw = nil - buf := bytes.NewBuffer(reusedBuf) - - // Allocate some scratch space. - scratch := make([]byte, maxHeaderSize) - - hasReadOnce := false - defer func() { - if err == io.EOF && hasReadOnce { - err = io.ErrUnexpectedEOF - } - }() - - // Algorithm: - // - // 1. We start off expecting to read one element. - // 2. If we see a tag, we expect to read one more element so we increment "remaining". - // 3. If see an array, we expect to read "extra" elements so we add "extra" to "remaining". - // 4. If see a map, we expect to read "2*extra" elements so we add "2*extra" to "remaining". - // 5. While "remaining" is non-zero, read more elements. - - // define this once so we don't keep allocating it. - limitedReader := io.LimitedReader{R: br} - for remaining := uint64(1); remaining > 0; remaining-- { - maj, extra, err := CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - hasReadOnce = true - if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil { - return err - } - - switch maj { - case MajUnsignedInt, MajNegativeInt, MajOther: - // nothing fancy to do - case MajByteString, MajTextString: - if extra > ByteArrayMaxLen { - return maxLengthError - } - // Copy the bytes - limitedReader.N = int64(extra) - buf.Grow(int(extra)) - if n, err := buf.ReadFrom(&limitedReader); err != nil { - return err - } else if n < int64(extra) { - return io.ErrUnexpectedEOF - } - case MajTag: - remaining++ - case MajArray: - if extra > MaxLength { - return maxLengthError - } - remaining += extra - case MajMap: - if extra > MaxLength { - return maxLengthError - } - remaining += extra * 2 - default: - return fmt.Errorf("unhandled deferred cbor type: %d", maj) - } - } - d.Raw = buf.Bytes() - return nil -} - func readByte(r io.Reader) (byte, error) { // try to cast to a concrete type, it's much faster than casting to an // interface. @@ -555,24 +360,6 @@ func WriteByteArray(bw io.Writer, bytes []byte) error { return nil } -var ( - CborBoolFalse = []byte{0xf4} - CborBoolTrue = []byte{0xf5} - CborNull = []byte{0xf6} -) - -func EncodeBool(b bool) []byte { - if b { - return CborBoolTrue - } - return CborBoolFalse -} - -func WriteBool(w io.Writer, b bool) error { - _, err := w.Write(EncodeBool(b)) - return err -} - var stringBufPool = sync.Pool{ New: func() interface{} { b := make([]byte, MaxLength) @@ -693,111 +480,3 @@ func WriteCidBuf(buf []byte, w io.Writer, c cid.Cid) error { return nil } - -type CborBool bool - -func (cb CborBool) MarshalCBOR(w io.Writer) error { - return WriteBool(w, bool(cb)) -} - -func (cb *CborBool) UnmarshalCBOR(r io.Reader) error { - t, val, err := CborReadHeader(r) - if err != nil { - return err - } - - if t != MajOther { - return fmt.Errorf("booleans should be major type 7") - } - - switch val { - case 20: - *cb = false - case 21: - *cb = true - default: - return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", val) - } - return nil -} - -type CborInt int64 - -func (ci CborInt) MarshalCBOR(w io.Writer) error { - v := int64(ci) - if v >= 0 { - if err := WriteMajorTypeHeader(w, MajUnsignedInt, uint64(v)); err != nil { - return err - } - } else { - if err := WriteMajorTypeHeader(w, MajNegativeInt, uint64(-v)-1); err != nil { - return err - } - } - return nil -} - -func (ci *CborInt) UnmarshalCBOR(r io.Reader) error { - maj, extra, err := CborReadHeader(r) - if err != nil { - return err - } - var extraI int64 - switch maj { - case MajUnsignedInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 positive overflow") - } - case MajNegativeInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 negative overflow") - } - extraI = -1 - extraI - default: - return fmt.Errorf("wrong type for int64 field: %d", maj) - } - - *ci = CborInt(extraI) - return nil -} - -type CborTime time.Time - -func (ct CborTime) MarshalCBOR(w io.Writer) error { - nsecs := ct.Time().UnixNano() - - cbi := CborInt(nsecs) - - return cbi.MarshalCBOR(w) -} - -func (ct *CborTime) UnmarshalCBOR(r io.Reader) error { - var cbi CborInt - if err := cbi.UnmarshalCBOR(r); err != nil { - return err - } - - t := time.Unix(0, int64(cbi)) - - *ct = (CborTime)(t) - return nil -} - -func (ct CborTime) Time() time.Time { - return (time.Time)(ct) -} - -func (ct CborTime) MarshalJSON() ([]byte, error) { - return ct.Time().MarshalJSON() -} - -func (ct *CborTime) UnmarshalJSON(b []byte) error { - var t time.Time - if err := t.UnmarshalJSON(b); err != nil { - return err - } - *(*time.Time)(ct) = t - return nil -}