diff --git a/any_test.go b/any_test.go deleted file mode 100644 index a984acd7..00000000 --- a/any_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package cmds - -import ( - "encoding/json" - "io" - "reflect" - "strings" - "testing" -) - -type Foo struct { - Bar int -} - -type Bar struct { - Foo string -} - -type ValueError struct { - Error error - Value interface{} -} - -type anyTestCase struct { - Types []interface{} - JSON string - Decoded []ValueError -} - -func TestMaybe(t *testing.T) { - testcases := []anyTestCase{ - anyTestCase{ - Types: []interface{}{Foo{}, &Bar{}}, - JSON: `{"Bar":2}{"Foo":"abc"}`, - Decoded: []ValueError{ - ValueError{Error: nil, Value: &Foo{2}}, - ValueError{Error: nil, Value: &Bar{"abc"}}, - }, - }, - } - - for _, tc := range testcases { - a := &Any{} - - for _, t := range tc.Types { - a.Add(t) - } - - r := strings.NewReader(tc.JSON) - d := json.NewDecoder(r) - - var err error - - for _, dec := range tc.Decoded { - err = d.Decode(a) - if err != dec.Error { - t.Fatalf("error is %v, expected %v", err, dec.Error) - } - - rx := a.Interface() - rxIsPtr := reflect.TypeOf(rx).Kind() == reflect.Ptr - - ex := dec.Value - exIsPtr := reflect.TypeOf(ex).Kind() == reflect.Ptr - - if rxIsPtr != exIsPtr { - t.Fatalf("value is %#v, expected %#v", a.Interface(), dec.Value) - } - - if rxIsPtr { - rx = reflect.ValueOf(rx).Elem().Interface() - ex = reflect.ValueOf(ex).Elem().Interface() - } - - if rx != ex { - t.Fatalf("value is %#v, expected %#v", a.Interface(), dec.Value) - } - } - - err = d.Decode(a) - if err != io.EOF { - t.Fatal("data left in decoder:", a.Interface()) - } - } -} diff --git a/http/response.go b/http/response.go index e01665a7..2a8b22ac 100644 --- a/http/response.go +++ b/http/response.go @@ -67,18 +67,15 @@ func (res *Response) RawNext() (interface{}, error) { } } - a := &cmds.Any{} - a.Add(&cmdkit.Error{}) - a.Add(res.req.Command().Type) - - err := res.dec.Decode(a) + m := &cmds.MaybeError{Value: res.req.Command().Type} + err := res.dec.Decode(m) // last error was sent as value, now we get the same error from the headers. ignore and EOF! if err != nil && res.err != nil && err.Error() == res.err.Error() { err = io.EOF } - return a.Interface(), err + return m.Get(), err } func (res *Response) Next() (interface{}, error) { diff --git a/maybeerror_test.go b/maybeerror_test.go new file mode 100644 index 00000000..f3150576 --- /dev/null +++ b/maybeerror_test.go @@ -0,0 +1,111 @@ +package cmds + +import ( + "encoding/json" + "io" + "reflect" + "strings" + "testing" + + "github.com/ipfs/go-ipfs-cmdkit" +) + +type Foo struct { + Bar int +} + +type Bar struct { + Foo string +} + +type ValueError struct { + Error error + Value interface{} +} + +type anyTestCase struct { + Value interface{} + JSON string + Decoded []ValueError +} + +func TestMaybeError(t *testing.T) { + testcases := []anyTestCase{ + anyTestCase{ + Value: &Foo{}, + JSON: `{"Bar":23}{"Bar":42}{"Message":"some error", "Type": "error"}`, + Decoded: []ValueError{ + ValueError{Error: nil, Value: &Foo{23}}, + ValueError{Error: nil, Value: &Foo{42}}, + ValueError{Error: nil, Value: cmdkit.Error{Message: "some error", Code: 0}}, + }, + }, + anyTestCase{ + Value: Foo{}, + JSON: `{"Bar":23}{"Bar":42}{"Message":"some error", "Type": "error"}`, + Decoded: []ValueError{ + ValueError{Error: nil, Value: &Foo{23}}, + ValueError{Error: nil, Value: &Foo{42}}, + ValueError{Error: nil, Value: cmdkit.Error{Message: "some error", Code: 0}}, + }, + }, + anyTestCase{ + Value: &Bar{}, + JSON: `{"Foo":""}{"Foo":"Qmabc"}{"Message":"some error", "Type": "error"}`, + Decoded: []ValueError{ + ValueError{Error: nil, Value: &Bar{""}}, + ValueError{Error: nil, Value: &Bar{"Qmabc"}}, + ValueError{Error: nil, Value: cmdkit.Error{Message: "some error", Code: 0}}, + }, + }, + anyTestCase{ + Value: Bar{}, + JSON: `{"Foo":""}{"Foo":"Qmabc"}{"Message":"some error", "Type": "error"}`, + Decoded: []ValueError{ + ValueError{Error: nil, Value: &Bar{""}}, + ValueError{Error: nil, Value: &Bar{"Qmabc"}}, + ValueError{Error: nil, Value: cmdkit.Error{Message: "some error", Code: 0}}, + }, + }, + } + + for _, tc := range testcases { + m := &MaybeError{Value: tc.Value} + + r := strings.NewReader(tc.JSON) + d := json.NewDecoder(r) + + var err error + + for _, dec := range tc.Decoded { + err = d.Decode(m) + if err != dec.Error { + t.Fatalf("error is %v, expected %v", err, dec.Error) + } + + rx := m.Get() + rxIsPtr := reflect.TypeOf(rx).Kind() == reflect.Ptr + + ex := dec.Value + exIsPtr := reflect.TypeOf(ex).Kind() == reflect.Ptr + + if rxIsPtr != exIsPtr { + t.Fatalf("value is %#v, expected %#v", m.Get(), dec.Value) + } + + if rxIsPtr { + rx = reflect.ValueOf(rx).Elem().Interface() + ex = reflect.ValueOf(ex).Elem().Interface() + } + + if rx != ex { + t.Fatalf("value is %#v, expected %#v", m.Get(), dec.Value) + } + } + + err = d.Decode(m) + if err != io.EOF { + t.Fatal("data left in decoder:", m.Get()) + } + } +} diff --git a/writer.go b/writer.go index 8f48d507..ed12f986 100644 --- a/writer.go +++ b/writer.go @@ -67,34 +67,24 @@ func (r *readerResponse) Length() uint64 { } func (r *readerResponse) RawNext() (interface{}, error) { - a := &Any{} - a.Add(cmdkit.Error{}) - a.Add(r.req.Command().Type) - - err := r.dec.Decode(a) + m := &MaybeError{Value: r.req.Command().Type} + err := r.dec.Decode(m) if err != nil { return nil, err } r.once.Do(func() { close(r.emitted) }) - v := a.Interface() + v := m.Get() return v, nil } func (r *readerResponse) Next() (interface{}, error) { - a := &Any{} - a.Add(cmdkit.Error{}) - a.Add(r.req.Command().Type) - - err := r.dec.Decode(a) + v, err := r.RawNext() if err != nil { return nil, err } - r.once.Do(func() { close(r.emitted) }) - - v := a.Interface() if err, ok := v.(cmdkit.Error); ok { v = &err } @@ -177,81 +167,32 @@ func (re *WriterResponseEmitter) Emit(v interface{}) error { return re.enc.Encode(v) } -type Any struct { - types map[reflect.Type]bool - order []reflect.Type +type MaybeError struct { + Value interface{} // needs to be a pointer + Error cmdkit.Error - v interface{} + isError bool } -func (a *Any) UnmarshalJSON(data []byte) error { - var ( - iv interface{} - err error - ) - - for _, t := range a.order { - v := reflect.New(t).Elem().Addr() - - isNil := func(v reflect.Value) (yup, ok bool) { - ok = true - defer func() { - r := recover() - if r != nil { - ok = false - } - }() - yup = v.IsNil() - return - } - - isZero := func(v reflect.Value, t reflect.Type) (yup, ok bool) { - ok = true - defer func() { - r := recover() - if r != nil { - ok = false - } - }() - yup = v.Elem().Interface() == reflect.Zero(t).Interface() - return - } - - err = json.Unmarshal(data, v.Interface()) - - vIsNil, isNilOk := isNil(v) - vIsZero, isZeroOk := isZero(v, t) - - nilish := (isNilOk && vIsNil) || (isZeroOk && vIsZero) - if err == nil && !nilish { - a.v = v.Interface() - return nil - } +func (m *MaybeError) Get() interface{} { + if m.isError { + return m.Error } - - err = json.Unmarshal(data, &iv) - a.v = iv - - return err + return m.Value } -func (a *Any) Add(v interface{}) { - if v == nil { - return - } - if a.types == nil { - a.types = map[reflect.Type]bool{} - } - t := reflect.TypeOf(v) - isPtr := t.Kind() == reflect.Ptr - if isPtr || t.Kind() == reflect.Interface { - t = t.Elem() +func (m *MaybeError) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &m.Error) + if err == nil { + m.isError = true + return nil } - a.types[t] = isPtr - a.order = append(a.order, t) -} + // make sure we are working with a pointer here + v := reflect.ValueOf(m.Value) + if v.Kind() != reflect.Ptr { + m.Value = reflect.New(v.Type()).Interface() + } -func (a *Any) Interface() interface{} { - return a.v + return json.Unmarshal(data, m.Value) }