Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Request Inputs in MarshalGQL #209

Merged
merged 3 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
# See https://github.com/golangci/golangci-lint/blob/master/.golangci.example.yml
run:
linters-settings:
govet:
enable-all: true
disable:
- shadow
unused:
check-exported: true
unparam:
check-exported: true
varcheck:
exported-fields: true
structcheck:
exported-fields: true
nakedret:
max-func-lines: 1


linters:
enable-all: true
disable:
Expand Down Expand Up @@ -55,11 +49,18 @@ linters:
- depguard
- musttag
- paralleltest
fast: false
- nlreturn
fast: true

issues:
fix: true
exclude-files:
- _test\.go
- examples/**/*\.go
max-issues-per-linter: 0
max-same-issues: 0
exclude-dirs:
- examples
exclude-rules:
# Test
- path: _test\.go
Expand Down
8 changes: 4 additions & 4 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ type Client struct {

// Request represents an outgoing GraphQL request
type Request struct {
Query string `json:"query"`
Variables map[string]interface{} `json:"variables,omitempty"`
OperationName string `json:"operationName,omitempty"`
Query string `json:"query"`
Variables map[string]any `json:"variables,omitempty"`
OperationName string `json:"operationName,omitempty"`
}

// NewClient creates a new http client wrapper
Expand All @@ -38,7 +38,7 @@ func NewClient(client *http.Client, baseURL string, options ...HTTPRequestOption
}
}

func (c *Client) newRequest(ctx context.Context, operationName, query string, vars map[string]interface{}, httpRequestOptions []HTTPRequestOption) (*http.Request, error) {
func (c *Client) newRequest(ctx context.Context, operationName, query string, vars map[string]any, httpRequestOptions []HTTPRequestOption) (*http.Request, error) {
r := &Request{
Query: query,
Variables: vars,
Expand Down
2 changes: 1 addition & 1 deletion clientgenv2/template.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

{{- if $.GenerateClient }}
func (c *Client) {{ $model.Name|go }} (ctx context.Context{{- range $arg := .Args }}, {{ $arg.Variable | goPrivate }} {{ $arg.Type | ref }} {{- end }}, interceptors ...clientv2.RequestInterceptor) (*{{ $model.ResponseStructName | go }}, error) {
vars := map[string]interface{}{
vars := map[string]any{
{{- range $args := .VariableDefinitions}}
"{{ $args.Variable }}": {{ $args.Variable | goPrivate }},
{{- end }}
Expand Down
259 changes: 258 additions & 1 deletion clientv2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"io"
"mime/multipart"
"net/http"
"reflect"
"strconv"
"strings"

"github.com/99designs/gqlgen/graphql"
"github.com/Yamashou/gqlgenc/graphqljson"
Expand Down Expand Up @@ -182,7 +184,7 @@ func (c *Client) Post(ctx context.Context, operationName, query string, respData

headers = append(headers, header{key: "Content-Type", value: contentType})
} else {
requestBody, err := json.Marshal(r)
requestBody, err := MarshalJSON(r)
if err != nil {
return fmt.Errorf("encode: %w", err)
}
Expand Down Expand Up @@ -391,3 +393,258 @@ func (c *Client) unmarshal(data []byte, res interface{}) error {

return err
}

func MarshalJSON(v interface{}) ([]byte, error) {
encoderFunc := getTypeEncoder(reflect.TypeOf(v))
return encoderFunc(v)
}

// getTypeEncoder returns an appropriate encoder function for the provided type.
func getTypeEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
if t.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) {
return gqlMarshalerEncoder
}

switch t.Kind() {
case reflect.Ptr:
return newPtrEncoder(t)
case reflect.Struct:
return newStructEncoder(t)
case reflect.Map:
return newMapEncoder(t)
case reflect.Slice:
return newSliceEncoder(t)
case reflect.Array:
return newArrayEncoder(t)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return newIntEncoder()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return newUintEncoder()
case reflect.String:
return newStringEncoder()
case reflect.Bool:
return newBoolEncoder()
case reflect.Float32, reflect.Float64:
return newFloatEncoder()
case reflect.Interface:
return newInterfaceEncoder()
case reflect.Invalid, reflect.Complex64, reflect.Complex128, reflect.Chan, reflect.Func, reflect.UnsafePointer:
panic(fmt.Sprintf("unsupported type: %s", t))
default:
panic(fmt.Sprintf("unsupported type: %s", t))
}
}

func gqlMarshalerEncoder(v interface{}) ([]byte, error) {
var buf bytes.Buffer
if val, ok := v.(graphql.Marshaler); ok {
val.MarshalGQL(&buf)
} else {
return nil, fmt.Errorf("failed to encode graphql.Marshaler: %v", v)
}

return buf.Bytes(), nil
}

func newBoolEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
if v, ok := v.(bool); ok {
boolValue, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("failed to encode bool: %v", v)
}
return boolValue, nil
} else {
return nil, fmt.Errorf("failed to encode bool: %v", v)
}
}
}

func newIntEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
return []byte(fmt.Sprintf("%d", v)), nil
}
}

func newUintEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
return []byte(fmt.Sprintf("%d", v)), nil
}
}

func newFloatEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
return []byte(fmt.Sprintf("%f", v)), nil
}
}

func newStringEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
stringValue, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("failed to encode string: %v", v)
}

return stringValue, nil
}
}

type fieldInfo struct {
name string
jsonName string
typ reflect.Type
}

func prepareFields(t reflect.Type) []fieldInfo {
num := t.NumField()
fields := make([]fieldInfo, 0, num)
for i := 0; i < num; i++ {
f := t.Field(i)
if f.PkgPath != "" && !f.Anonymous { // Skip unexported fields unless they are embedded
continue
}
jsonTag := f.Tag.Get("json")
if jsonTag == "-" {
continue // Skip fields explicitly marked to be ignored
}
jsonName := f.Name
if jsonTag != "" {
parts := strings.Split(jsonTag, ",")
jsonName = parts[0] // Use the name specified in the JSON tag
}
fields = append(fields, fieldInfo{
name: f.Name,
jsonName: jsonName,
typ: f.Type,
})
}

return fields
}

func newStructEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
fields := prepareFields(t) // Prepare and cache fields information
return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
result := make(map[string]json.RawMessage)
for _, field := range fields {
fieldValue := val.FieldByName(field.name)
if !fieldValue.IsValid() {
continue
}
encoder := getTypeEncoder(field.typ)
encodedValue, err := encoder(fieldValue.Interface())
if err != nil {
return nil, err
}
result[field.jsonName] = encodedValue
}

return json.Marshal(result)
}
}

func trimQuotes(s string) string {
if len(s) > 1 && s[0] == '"' && s[len(s)-1] == '"' {
return s[1 : len(s)-1]
}

return s
}

func newMapEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
keyEncoder := getTypeEncoder(t.Key())
valueEncoder := getTypeEncoder(t.Elem())

return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
result := make(map[string]json.RawMessage)
for _, key := range val.MapKeys() {
encodedKey, err := keyEncoder(key.Interface())
if err != nil {
return nil, err
}
keyStr := string(encodedKey)
keyStr = trimQuotes(keyStr)

value := val.MapIndex(key)
encodedValue, err := valueEncoder(value.Interface())
if err != nil {
return nil, err
}
result[keyStr] = json.RawMessage(encodedValue) // Use json.RawMessage to avoid double encoding
}

return json.Marshal(result)
}
}

func newSliceEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
elemEncoder := getTypeEncoder(t.Elem())
return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
result := make([]json.RawMessage, val.Len())
for i := 0; i < val.Len(); i++ {
encodedValue, err := elemEncoder(val.Index(i).Interface())
if err != nil {
return nil, err
}
result[i] = encodedValue
}

return json.Marshal(result)
}
}

func newArrayEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
elemEncoder := getTypeEncoder(t.Elem())
return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
result := make([]json.RawMessage, val.Len())
for i := 0; i < val.Len(); i++ {
encodedValue, err := elemEncoder(val.Index(i).Interface())
if err != nil {
return nil, err
}
result[i] = encodedValue
}

return json.Marshal(result)
}
}

func newPtrEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
if t.Elem().Kind() == reflect.Ptr {
return newPtrEncoder(t.Elem())
}
elemEncoder := getTypeEncoder(t.Elem())
return func(v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
if val.IsNil() {
return []byte("null"), nil
}

return elemEncoder(val.Elem().Interface())
}
}

func newInterfaceEncoder() func(interface{}) ([]byte, error) {
return func(v interface{}) ([]byte, error) {
if v == nil {
return []byte("null"), nil
}
actualValue := reflect.ValueOf(v)
if actualValue.Kind() == reflect.Interface && !actualValue.IsNil() {
// Extract the element inside the interface value
actualValue = actualValue.Elem()
}
if actualValue.IsValid() {
actualType := actualValue.Type()
encoder := getTypeEncoder(actualType)

return encoder(actualValue.Interface())
}

return []byte("null"), nil
}
}
Loading