Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

some cleanup for easier reading #89

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions deferred.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package typegen

import (
"bytes"
"errors"
"fmt"
"io"
)

type Deferred struct {
Raw []byte
}

func (d *Deferred) MarshalCBOR(w io.Writer) error {
if d == nil {
_, err := w.Write(CborNull)
return err
}
if d.Raw == nil {
return errors.New("cannot marshal Deferred with nil value for Raw (will not unmarshal)")
}
_, err := w.Write(d.Raw)
return err
}

func (d *Deferred) UnmarshalCBOR(br io.Reader) (err error) {
// Reuse any existing buffers.
reusedBuf := d.Raw[:0]
d.Raw = nil
buf := bytes.NewBuffer(reusedBuf)

// Allocate some scratch space.
scratch := make([]byte, maxHeaderSize)

hasReadOnce := false
defer func() {
if err == io.EOF && hasReadOnce {
err = io.ErrUnexpectedEOF
}
}()

// Algorithm:
//
// 1. We start off expecting to read one element.
// 2. If we see a tag, we expect to read one more element so we increment "remaining".
// 3. If see an array, we expect to read "extra" elements so we add "extra" to "remaining".
// 4. If see a map, we expect to read "2*extra" elements so we add "2*extra" to "remaining".
// 5. While "remaining" is non-zero, read more elements.

// define this once so we don't keep allocating it.
limitedReader := io.LimitedReader{R: br}
for remaining := uint64(1); remaining > 0; remaining-- {
maj, extra, err := CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
hasReadOnce = true
if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil {
return err
}

switch maj {
case MajUnsignedInt, MajNegativeInt, MajOther:
// nothing fancy to do
case MajByteString, MajTextString:
if extra > ByteArrayMaxLen {
return maxLengthError
}
// Copy the bytes
limitedReader.N = int64(extra)
buf.Grow(int(extra))
if n, err := buf.ReadFrom(&limitedReader); err != nil {
return err
} else if n < int64(extra) {
return io.ErrUnexpectedEOF
}
case MajTag:
remaining++
case MajArray:
if extra > MaxLength {
return maxLengthError
}
remaining += extra
case MajMap:
if extra > MaxLength {
return maxLengthError
}
remaining += extra * 2
default:
return fmt.Errorf("unhandled deferred cbor type: %d", maj)
}
}
d.Raw = buf.Bytes()
return nil
}
133 changes: 133 additions & 0 deletions helper_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package typegen

import (
"fmt"
"io"
"time"
)

var (
CborBoolFalse = []byte{0xf4}
CborBoolTrue = []byte{0xf5}
CborNull = []byte{0xf6}
)

func EncodeBool(b bool) []byte {
if b {
return CborBoolTrue
}
return CborBoolFalse
}

func WriteBool(w io.Writer, b bool) error {
_, err := w.Write(EncodeBool(b))
return err
}

type CborBool bool

func (cb CborBool) MarshalCBOR(w io.Writer) error {
return WriteBool(w, bool(cb))
}

func (cb *CborBool) UnmarshalCBOR(r io.Reader) error {
t, val, err := CborReadHeader(r)
if err != nil {
return err
}

if t != MajOther {
return fmt.Errorf("booleans should be major type 7")
}

switch val {
case 20:
*cb = false
case 21:
*cb = true
default:
return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", val)
}
return nil
}

type CborInt int64

func (ci CborInt) MarshalCBOR(w io.Writer) error {
v := int64(ci)
if v >= 0 {
if err := WriteMajorTypeHeader(w, MajUnsignedInt, uint64(v)); err != nil {
return err
}
} else {
if err := WriteMajorTypeHeader(w, MajNegativeInt, uint64(-v)-1); err != nil {
return err
}
}
return nil
}

func (ci *CborInt) UnmarshalCBOR(r io.Reader) error {
maj, extra, err := CborReadHeader(r)
if err != nil {
return err
}
var extraI int64
switch maj {
case MajUnsignedInt:
extraI = int64(extra)
if extraI < 0 {
return fmt.Errorf("int64 positive overflow")
}
case MajNegativeInt:
extraI = int64(extra)
if extraI < 0 {
return fmt.Errorf("int64 negative overflow")
}
extraI = -1 - extraI
default:
return fmt.Errorf("wrong type for int64 field: %d", maj)
}

*ci = CborInt(extraI)
return nil
}

type CborTime time.Time

func (ct CborTime) MarshalCBOR(w io.Writer) error {
nsecs := ct.Time().UnixNano()

cbi := CborInt(nsecs)

return cbi.MarshalCBOR(w)
}

func (ct *CborTime) UnmarshalCBOR(r io.Reader) error {
var cbi CborInt
if err := cbi.UnmarshalCBOR(r); err != nil {
return err
}

t := time.Unix(0, int64(cbi))

*ct = (CborTime)(t)
return nil
}

func (ct CborTime) Time() time.Time {
return (time.Time)(ct)
}

func (ct CborTime) MarshalJSON() ([]byte, error) {
return ct.Time().MarshalJSON()
}

func (ct *CborTime) UnmarshalJSON(b []byte) error {
var t time.Time
if err := t.UnmarshalJSON(b); err != nil {
return err
}
*(*time.Time)(ct) = t
return nil
}
125 changes: 125 additions & 0 deletions links.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package typegen

import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"math"

cid "github.com/ipfs/go-cid"
)

func ScanForLinks(br io.Reader, cb func(cid.Cid)) (err error) {
hasReadOnce := false
defer func() {
if err == io.EOF && hasReadOnce {
err = io.ErrUnexpectedEOF
}
}()

scratch := make([]byte, maxCidLength)
for remaining := uint64(1); remaining > 0; remaining-- {
maj, extra, err := CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
hasReadOnce = true

switch maj {
case MajUnsignedInt, MajNegativeInt, MajOther:
case MajByteString, MajTextString:
if extra > math.MaxInt32 {
return fmt.Errorf("string in cbor input too long")
}

err := discard(br, int(extra))
if err != nil {
return err
}
case MajTag:
if extra == 42 {
maj, extra, err = CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}

if maj != MajByteString {
return fmt.Errorf("expected cbor type 'byte string' in input")
}

if extra > maxCidLength {
return fmt.Errorf("string in cbor input too long")
}

if extra == 0 {
return fmt.Errorf("string in cbor input is empty")
}

if _, err := io.ReadAtLeast(br, scratch[:extra], int(extra)); err != nil {
return err
}

c, err := cid.Cast(scratch[1:extra])
if err != nil {
return err
}
cb(c)

} else {
remaining++
}
case MajArray:
remaining += extra
case MajMap:
remaining += (extra * 2)
default:
return fmt.Errorf("unhandled cbor type: %d", maj)
}
}
return nil
}

// discard is a helper function to discard data from a reader, special-casing
// the most common readers we encounter in this library for a significant
// performance boost.
func discard(br io.Reader, n int) error {
// If we're expecting no bytes, don't even try to read. Otherwise, we may read an EOF.
if n == 0 {
return nil
}

switch r := br.(type) {
case *bytes.Buffer:
buf := r.Next(n)
if len(buf) == 0 {
return io.EOF
} else if len(buf) < n {
return io.ErrUnexpectedEOF
}
return nil
case *bytes.Reader:
if r.Len() == 0 {
return io.EOF
} else if r.Len() < n {
_, _ = r.Seek(0, io.SeekEnd)
return io.ErrUnexpectedEOF
}
_, err := r.Seek(int64(n), io.SeekCurrent)
return err
case *bufio.Reader:
discarded, err := r.Discard(n)
if discarded != 0 && discarded < n && err == io.EOF {
return io.ErrUnexpectedEOF
}
return err
default:
discarded, err := io.CopyN(ioutil.Discard, br, int64(n))
if discarded != 0 && discarded < int64(n) && err == io.EOF {
return io.ErrUnexpectedEOF
}

return err
}
}
Loading