Skip to content

Commit

Permalink
Merge pull request #155 from ipfs/fix/pointer-reflection
Browse files Browse the repository at this point in the history
typed encoder: improve pointer reflection
  • Loading branch information
Stebalien authored Mar 21, 2019
2 parents 8f644c8 + c99a709 commit 3c3985a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
27 changes: 23 additions & 4 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,35 @@ func MakeTypedEncoder(f interface{}) func(*Request) func(io.Writer) Encoder {
panic("MakeTypedEncoder must receive a function matching func(*Request, io.Writer, ...)")
}

valType := t.In(2)
valTypePtr := reflect.PtrTo(valType)
var (
valType, valTypeAlt reflect.Type
)

valType = t.In(2)
valTypeIsPtr := valType.Kind() == reflect.Ptr
if valTypeIsPtr {
valTypeAlt = valType.Elem()
} else {
valTypeAlt = reflect.PtrTo(valType)
}

return MakeEncoder(func(req *Request, w io.Writer, i interface{}) error {
iType := reflect.TypeOf(i)
iValue := reflect.ValueOf(i)
switch iType {
case valType:
case valTypePtr:
iValue = iValue.Elem()
case valTypeAlt:
if valTypeIsPtr {
if iValue.CanAddr() {
iValue = iValue.Addr()
} else {
oldValue := iValue
iValue = reflect.New(iType)
iValue.Elem().Set(oldValue)
}
} else {
iValue = iValue.Elem()
}
default:
return fmt.Errorf("unexpected type %T, expected %v", i, valType)
}
Expand Down
25 changes: 25 additions & 0 deletions encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,31 @@ func TestMakeTypedEncoderByValue(t *testing.T) {
}
}

func TestMakeTypedEncoderByPointer(t *testing.T) {
expErr := fmt.Errorf("command fooTestObj failed")
f := MakeTypedEncoder(func(req *Request, w io.Writer, v *fooTestObj) error {
if v.Good {
return nil
}
return expErr
})

req := &Request{}

encoderFunc := f(req)

buf := new(bytes.Buffer)
encoder := encoderFunc(buf)

if err := encoder.Encode(fooTestObj{true}); err != nil {
t.Fatal(err)
}

if err := encoder.Encode(fooTestObj{false}); err != expErr {
t.Fatal("expected: ", expErr)
}
}

func TestMakeTypedEncoderArrays(t *testing.T) {
f := MakeTypedEncoder(func(req *Request, w io.Writer, v []fooTestObj) error {
if len(v) != 2 {
Expand Down

0 comments on commit 3c3985a

Please sign in to comment.