Skip to content

Commit

Permalink
Merge pull request #220 from Yamashou/fix-bug-marshal
Browse files Browse the repository at this point in the history
Fix Bug Marshal json
  • Loading branch information
Yamashou authored Apr 23, 2024
2 parents bfc7c4f + d4cb58d commit c9749ff
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 211 deletions.
313 changes: 103 additions & 210 deletions clientv2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,47 +404,59 @@ func MarshalJSON(v interface{}) ([]byte, error) {
return []byte("null"), nil // Return "null" for nil pointer or invalid reflect value
}

encoderFunc := getTypeEncoder(reflect.TypeOf(v))
return encoderFunc(v)
return encode(val)
}

// getTypeEncoder returns an appropriate encoder function for the provided type.
func getTypeEncoder(t reflect.Type) func(any2 any) ([]byte, error) {
if t.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) || (t.Kind() == reflect.Ptr && reflect.PtrTo(t).Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem())) {
return gqlMarshalerEncoder
func checkImplements[I any](v reflect.Value) bool {
t := v.Type()
interfaceType := reflect.TypeOf((*I)(nil)).Elem()

// Check if the type implements the interface directly or as a pointer.
return t.Implements(interfaceType) || (t.Kind() == reflect.Ptr && reflect.PtrTo(t).Implements(interfaceType))
}

// encode returns an appropriate encoder function for the provided value.
func encode(v reflect.Value) ([]byte, error) {
if checkImplements[graphql.Marshaler](v) {
return encodeGQLMarshaler(v.Interface())
}

if checkImplements[json.Marshaler](v) {
return encodeJsonMarshaler(v.Interface())
}

t := v.Type() // Get the type from the value
switch t.Kind() {
case reflect.Ptr:
return newPtrEncoder(t)
return encodePtr(v)
case reflect.Struct:
return newStructEncoder(t)
return encodeStruct(v)
case reflect.Map:
return newMapEncoder(t)
return encodeMap(v)
case reflect.Slice:
return newSliceEncoder(t)
return encodeSlice(v)
case reflect.Array:
return newArrayEncoder(t)
return encodeArray(v)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return newIntEncoder()
return encodeInt(v)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return newUintEncoder()
return encodeUint(v)
case reflect.String:
return newStringEncoder()
return encodeString(v)
case reflect.Bool:
return newBoolEncoder()
return encodeBool(v)
case reflect.Float32, reflect.Float64:
return newFloatEncoder()
return encodeFloat(v)
case reflect.Interface:
return newInterfaceEncoder()
return encodeInterface(v)
case reflect.Invalid, reflect.Complex64, reflect.Complex128, reflect.Chan, reflect.Func, reflect.UnsafePointer:
panic(fmt.Sprintf("unsupported type: %s", t))
default:
panic(fmt.Sprintf("unsupported type: %s", t))
}
}

func gqlMarshalerEncoder(v any) ([]byte, error) {
func encodeGQLMarshaler(v any) ([]byte, error) {
var buf bytes.Buffer
if val, ok := v.(graphql.Marshaler); ok {
val.MarshalGQL(&buf)
Expand All @@ -455,47 +467,40 @@ func gqlMarshalerEncoder(v any) ([]byte, error) {
return buf.Bytes(), nil
}

func newBoolEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
if v, ok := v.(bool); ok {
boolValue, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("failed to encode bool: %v", v)
}
return boolValue, nil
} else {
return nil, fmt.Errorf("failed to encode bool: %v", v)
}
func encodeJsonMarshaler(v any) ([]byte, error) {
if val, ok := v.(json.Marshaler); ok {
return val.MarshalJSON()
} else {
return nil, fmt.Errorf("failed to encode json.Marshaler: %v", v)
}
}

func newIntEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
return []byte(fmt.Sprintf("%d", v)), nil
func encodeBool(v reflect.Value) ([]byte, error) {
boolValue, err := json.Marshal(v.Bool())
if err != nil {
return nil, fmt.Errorf("failed to encode bool: %v", v)
}
return boolValue, nil
}

func newUintEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
return []byte(fmt.Sprintf("%d", v)), nil
}
func encodeInt(v reflect.Value) ([]byte, error) {
return []byte(fmt.Sprintf("%d", v.Int())), nil
}

func newFloatEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
return []byte(fmt.Sprintf("%f", v)), nil
}
func encodeUint(v reflect.Value) ([]byte, error) {
return []byte(fmt.Sprintf("%d", v.Uint())), nil
}

func newStringEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
stringValue, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("failed to encode string: %v", v)
}
func encodeFloat(v reflect.Value) ([]byte, error) {
return []byte(fmt.Sprintf("%f", v.Float())), nil
}

return stringValue, nil
func encodeString(v reflect.Value) ([]byte, error) {
stringValue, err := json.Marshal(v.String())
if err != nil {
return nil, fmt.Errorf("failed to encode string: %v", v)
}
return stringValue, nil
}

type fieldInfo struct {
Expand Down Expand Up @@ -531,97 +536,22 @@ func prepareFields(t reflect.Type) []fieldInfo {
return fields
}

func checkMarshalerFields(t reflect.Type) bool {
switch t.Kind() {
case reflect.Ptr:
return checkMarshalerFields(t.Elem())

case reflect.Struct:
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if isMarshalerType(f.Type) {
return true
}
// Recursively check for nested structs
if checkMarshalerFields(f.Type) {
return true
}

// If the value type is interface{}, we need to handle it at runtime
if f.Type.Kind() == reflect.Interface {
return true // Assume it could implement Marshaler at runtime
}
}

case reflect.Map:
// Check key type for Marshaler implementation (usually not needed unless custom types used as keys)
keyType := t.Key()
if isMarshalerType(keyType) {
return true
}

// Check value type for Marshaler implementation
valueType := t.Elem()
if isMarshalerType(valueType) {
return true
}

// If the value type is interface{}, we need to handle it at runtime
if valueType.Kind() == reflect.Interface {
return true // Assume it could implement Marshaler at runtime
func encodeStruct(v reflect.Value) ([]byte, error) {
fields := prepareFields(v.Type())
result := make(map[string]json.RawMessage)
for _, field := range fields {
fieldValue := v.FieldByName(field.name)
if !fieldValue.IsValid() || (fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil()) {
continue // Skip invalid or nil pointers to avoid panics
}

// Recursively check the map value type
return checkMarshalerFields(valueType)

case reflect.Slice, reflect.Array:
// Recursively check the element type
return checkMarshalerFields(t.Elem())
case reflect.Interface, reflect.Invalid, reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return false
default:
return false
}

return false
}

func isMarshalerType(t reflect.Type) bool {
if t.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) {
return true
}
if reflect.PtrTo(t).Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) {
return true
}
return false
}

func newStructEncoder(t reflect.Type) func(any2 any) ([]byte, error) {
fields := prepareFields(t)
marshalerFieldExists := checkMarshalerFields(t)

return func(v any) ([]byte, error) {
// If no field implements the MarshalerGQL interface, use standard JSON marshaling
if !marshalerFieldExists {
return json.Marshal(v)
}

val := reflect.ValueOf(v)
result := make(map[string]json.RawMessage)
for _, field := range fields {
fieldValue := val.FieldByName(field.name)
if !fieldValue.IsValid() || (fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil()) {
continue // Skip invalid or nil pointers to avoid panics
}
encoder := getTypeEncoder(field.typ)
encodedValue, err := encoder(fieldValue.Interface())
if err != nil {
return nil, err
}
result[field.jsonName] = encodedValue
encodedValue, err := encode(fieldValue)
if err != nil {
return nil, err
}
return json.Marshal(result)
result[field.jsonName] = encodedValue
}
return json.Marshal(result)
}

func trimQuotes(s string) string {
Expand All @@ -632,99 +562,62 @@ func trimQuotes(s string) string {
return s
}

func newMapEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
keyEncoder := getTypeEncoder(t.Key())

return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
result := make(map[string]json.RawMessage)
for _, key := range val.MapKeys() {
encodedKey, err := keyEncoder(key.Interface())
if err != nil {
return nil, err
}
keyStr := string(encodedKey)
keyStr = trimQuotes(keyStr)

value := val.MapIndex(key)
valueEncoder := getTypeEncoder(value.Type())
encodedValue, err := valueEncoder(value.Interface())
if err != nil {
return nil, err
}
result[keyStr] = encodedValue
func encodeMap(v reflect.Value) ([]byte, error) {
result := make(map[string]json.RawMessage)
for _, key := range v.MapKeys() {
encodedKey, err := encode(key)
if err != nil {
return nil, err
}
keyStr := string(encodedKey)
keyStr = trimQuotes(keyStr)

return json.Marshal(result)
value := v.MapIndex(key)
encodedValue, err := encode(value)
if err != nil {
return nil, err
}
result[keyStr] = encodedValue
}
return json.Marshal(result)
}

func newSliceEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
elemEncoder := getTypeEncoder(t.Elem())
return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
result := make([]json.RawMessage, val.Len())
for i := 0; i < val.Len(); i++ {
encodedValue, err := elemEncoder(val.Index(i).Interface())
if err != nil {
return nil, err
}
result[i] = encodedValue
func encodeSlice(v reflect.Value) ([]byte, error) {
result := make([]json.RawMessage, v.Len())
for i := 0; i < v.Len(); i++ {
encodedValue, err := encode(v.Index(i))
if err != nil {
return nil, err
}

return json.Marshal(result)
result[i] = encodedValue
}
return json.Marshal(result)
}

func newArrayEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
elemEncoder := getTypeEncoder(t.Elem())
return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
result := make([]json.RawMessage, val.Len())
for i := 0; i < val.Len(); i++ {
encodedValue, err := elemEncoder(val.Index(i).Interface())
if err != nil {
return nil, err
}
result[i] = encodedValue
func encodeArray(v reflect.Value) ([]byte, error) {
result := make([]json.RawMessage, v.Len())
for i := 0; i < v.Len(); i++ {
encodedValue, err := encode(v.Index(i))
if err != nil {
return nil, err
}

return json.Marshal(result)
result[i] = encodedValue
}
return json.Marshal(result)
}

func newPtrEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
if t.Elem().Kind() == reflect.Ptr {
return newPtrEncoder(t.Elem())
func encodePtr(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return []byte("null"), nil
}
elemEncoder := getTypeEncoder(t.Elem())
return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
if val.IsNil() {
return []byte("null"), nil
}

return elemEncoder(val.Elem().Interface())
}
return encode(v.Elem())
}

func newInterfaceEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
if v == nil {
return []byte("null"), nil
}
actualValue := reflect.ValueOf(v)
if actualValue.Kind() == reflect.Interface && !actualValue.IsNil() {
// Extract the element inside the interface value
actualValue = actualValue.Elem()
}
if actualValue.IsValid() {
actualType := actualValue.Type()
encoder := getTypeEncoder(actualType)

return encoder(actualValue.Interface())
}

func encodeInterface(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return []byte("null"), nil
}
actualValue := v.Elem()
return encode(actualValue)
}
Loading

0 comments on commit c9749ff

Please sign in to comment.