diff --git a/clientgenv2/source_generator.go b/clientgenv2/source_generator.go index 05d6bfe..bf67fd3 100644 --- a/clientgenv2/source_generator.go +++ b/clientgenv2/source_generator.go @@ -3,6 +3,7 @@ package clientgenv2 import ( "fmt" "go/types" + "sort" "strings" "github.com/99designs/gqlgen/codegen/config" @@ -27,6 +28,12 @@ type ResponseField struct { ResponseFields ResponseFieldList } +func (r ResponseField) FieldTypeString() string { + fullFieldType := r.Type.String() + parts := strings.Split(fullFieldType, ".") + return parts[len(parts)-1] +} + type ResponseFieldList []*ResponseField func (rs ResponseFieldList) IsFragmentSpread() bool { @@ -40,23 +47,10 @@ func (rs ResponseFieldList) IsFragmentSpread() bool { func (rs ResponseFieldList) StructType() *types.Struct { vars := make([]*types.Var, 0) structTags := make([]string, 0) - for _, filed := range rs { - // クエリーのフィールドの子階層がFragmentの場合、このフィールドにそのFragmentの型を追加する - if filed.IsFragmentSpread { - typ, ok := filed.ResponseFields.StructType().Underlying().(*types.Struct) - if !ok { - continue - } - for j := range typ.NumFields() { - vars = append(vars, typ.Field(j)) - structTags = append(structTags, typ.Tag(j)) - } - } else { - vars = append(vars, types.NewVar(0, nil, templates.ToGo(filed.Name), filed.Type)) - structTags = append(structTags, strings.Join(filed.Tags, " ")) - } + for _, field := range rs { + vars = append(vars, types.NewVar(0, nil, templates.ToGo(field.Name), field.Type)) + structTags = append(structTags, strings.Join(field.Tags, " ")) } - return types.NewStruct(vars, structTags) } @@ -76,6 +70,123 @@ func (rs ResponseFieldList) IsStructType() bool { return len(rs) > 0 && !rs.IsFragment() } +func (rs ResponseFieldList) MapByName() map[string]*ResponseField { + res := make(map[string]*ResponseField) + for _, field := range rs { + res[field.Name] = field + } + return res +} + +func (rs ResponseFieldList) SortByName() ResponseFieldList { + sort.Slice(rs, func(i, j int) bool { + return rs[i].Name < rs[j].Name + }) + return rs +} + +type StructGenerator struct { + currentResponseFieldList ResponseFieldList // Create fields based on this ResponseFieldList + preMergedStructSources []*StructSource // Struct sources that will no longer be created due to merging + postMergedStructSources []*StructSource // Struct sources that will be created due to merging +} + +func NewStructGenerator(responseFieldList ResponseFieldList) *StructGenerator { + currentFields := make(ResponseFieldList, 0) + fragmentChildrenFields := make(ResponseFieldList, 0) + for _, field := range responseFieldList { + if field.IsFragmentSpread { + fragmentChildrenFields = append(fragmentChildrenFields, field.ResponseFields...) + } else { + currentFields = append(currentFields, field) + } + } + + preMergedStructSources := make([]*StructSource, 0) + + for _, field := range responseFieldList { + if field.IsFragmentSpread { + preMergedStructSources = append(preMergedStructSources, &StructSource{ + Name: field.FieldTypeString(), + Type: field.ResponseFields.StructType(), + }) + } + } + + currentFields, preMergedStructSources, postMergedStructSources := mergeFieldsRecursively(currentFields, fragmentChildrenFields, preMergedStructSources, nil) + return &StructGenerator{ + currentResponseFieldList: currentFields, + preMergedStructSources: preMergedStructSources, + postMergedStructSources: postMergedStructSources, + } +} + +func mergeFieldsRecursively(targetFields ResponseFieldList, sourceFields ResponseFieldList, preMerged, postMerged []*StructSource) (ResponseFieldList, []*StructSource, []*StructSource) { + responseFieldList := make(ResponseFieldList, 0) + targetFieldsMap := targetFields.MapByName() + newPreMerged := preMerged + newPostMerged := postMerged + + for _, sourceField := range sourceFields { + if targetField, ok := targetFieldsMap[sourceField.Name]; ok { + if targetField.ResponseFields.IsBasicType() { + continue + } + newPreMerged = append(newPreMerged, &StructSource{ + Name: sourceField.FieldTypeString(), + Type: sourceField.ResponseFields.StructType(), + }) + newPreMerged = append(newPreMerged, &StructSource{ + Name: targetField.FieldTypeString(), + Type: targetField.ResponseFields.StructType(), + }) + + targetField.ResponseFields, newPreMerged, newPostMerged = mergeFieldsRecursively(targetField.ResponseFields, sourceField.ResponseFields, newPreMerged, newPostMerged) + newPostMerged = append(newPostMerged, &StructSource{ + Name: targetField.FieldTypeString(), + Type: targetField.ResponseFields.StructType(), + }) + } else { + responseFieldList = append(responseFieldList, sourceField) + } + } + for _, field := range targetFieldsMap { + responseFieldList = append(responseFieldList, field) + } + responseFieldList = responseFieldList.SortByName() + return responseFieldList, newPreMerged, newPostMerged +} + +func structSourcesMapByTypeName(sources []*StructSource) map[string]*StructSource { + res := make(map[string]*StructSource) + for _, source := range sources { + res[source.Name] = source + } + return res +} + +func (g *StructGenerator) MergedStructSources(sources []*StructSource) []*StructSource { + preMergedStructSourcesMap := structSourcesMapByTypeName(g.preMergedStructSources) + res := make([]*StructSource, 0) + // remove pre-merged struct + for _, source := range sources { + // when name is same, remove it + if _, ok := preMergedStructSourcesMap[source.Name]; ok { + continue + } + res = append(res, source) + } + + // append post-merged struct + res = append(res, g.postMergedStructSources...) + + return res +} + +func (g *StructGenerator) GetCurrentResponseFieldList() ResponseFieldList { + return g.currentResponseFieldList +} + type StructSource struct { Name string Type types.Type @@ -130,7 +241,15 @@ func (r *SourceGenerator) NewResponseField(selection ast.Selection, typeName str // if a child field is fragment, this field type became fragment. baseType = fieldsResponseFields[0].Type case fieldsResponseFields.IsStructType(): - structType := fieldsResponseFields.StructType() + // 子フィールドにFragmentがある場合は、現在のフィールドとマージする + // if there is a fragment in child fields, merge it with the current field + generator := NewStructGenerator(fieldsResponseFields) + + // restruct struct sources + r.StructSources = generator.MergedStructSources(r.StructSources) + + // append current struct + structType := generator.GetCurrentResponseFieldList().StructType() r.StructSources = append(r.StructSources, &StructSource{ Name: typeName, Type: structType, diff --git a/example/custom-fragment/.gqlgenc.yml b/example/custom-fragment/.gqlgenc.yml new file mode 100644 index 0000000..12f1359 --- /dev/null +++ b/example/custom-fragment/.gqlgenc.yml @@ -0,0 +1,17 @@ +model: + package: generated + filename: ./model/models_gen.go +client: + package: generated + filename: ./gen/client.go +models: + Int: + model: github.com/99designs/gqlgen/graphql.Int64 + Date: + model: github.com/99designs/gqlgen/graphql.Time +schema: + - "./schema/**/*.graphql" +query: + - "./query/*.graphql" +generate: + onlyUsedModels: true diff --git a/example/custom-fragment/gen/client.go b/example/custom-fragment/gen/client.go new file mode 100644 index 0000000..664d88f --- /dev/null +++ b/example/custom-fragment/gen/client.go @@ -0,0 +1,186 @@ +// Code generated by github.com/Yamashou/gqlgenc, DO NOT EDIT. + +package generated + +import ( + "context" + + "github.com/Yamashou/gqlgenc/clientv2" +) + +type Client struct { + Client *clientv2.Client +} + +func NewClient(cli clientv2.HttpClient, baseURL string, options *clientv2.Options, interceptors ...clientv2.RequestInterceptor) *Client { + return &Client{Client: clientv2.NewClient(cli, baseURL, options, interceptors...)} +} + +type UserFragment struct { + ID string "json:\"id\" graphql:\"id\"" + Profile UserFragment_Profile "json:\"profile\" graphql:\"profile\"" +} + +func (t *UserFragment) GetID() string { + if t == nil { + t = &UserFragment{} + } + return t.ID +} +func (t *UserFragment) GetProfile() *UserFragment_Profile { + if t == nil { + t = &UserFragment{} + } + return &t.Profile +} + +type UserFragment_Profile_Detail struct { + BirthDate string "json:\"birthDate\" graphql:\"birthDate\"" +} + +func (t *UserFragment_Profile_Detail) GetBirthDate() string { + if t == nil { + t = &UserFragment_Profile_Detail{} + } + return t.BirthDate +} + +type UserFragment_Profile struct { + Detail UserFragment_Profile_Detail "json:\"detail\" graphql:\"detail\"" + ID string "json:\"id\" graphql:\"id\"" +} + +func (t *UserFragment_Profile) GetDetail() *UserFragment_Profile_Detail { + if t == nil { + t = &UserFragment_Profile{} + } + return &t.Detail +} +func (t *UserFragment_Profile) GetID() string { + if t == nil { + t = &UserFragment_Profile{} + } + return t.ID +} + +type UserDetail_User_Profile_Detail struct { + BirthDate string "json:\"birthDate\" graphql:\"birthDate\"" + ID string "json:\"id\" graphql:\"id\"" +} + +func (t *UserDetail_User_Profile_Detail) GetBirthDate() string { + if t == nil { + t = &UserDetail_User_Profile_Detail{} + } + return t.BirthDate +} +func (t *UserDetail_User_Profile_Detail) GetID() string { + if t == nil { + t = &UserDetail_User_Profile_Detail{} + } + return t.ID +} + +type UserDetail_User_Profile struct { + Company string "json:\"company\" graphql:\"company\"" + Detail UserDetail_User_Profile_Detail "json:\"detail\" graphql:\"detail\"" + ID string "json:\"id\" graphql:\"id\"" + Name string "json:\"name\" graphql:\"name\"" +} + +func (t *UserDetail_User_Profile) GetCompany() string { + if t == nil { + t = &UserDetail_User_Profile{} + } + return t.Company +} +func (t *UserDetail_User_Profile) GetDetail() *UserDetail_User_Profile_Detail { + if t == nil { + t = &UserDetail_User_Profile{} + } + return &t.Detail +} +func (t *UserDetail_User_Profile) GetID() string { + if t == nil { + t = &UserDetail_User_Profile{} + } + return t.ID +} +func (t *UserDetail_User_Profile) GetName() string { + if t == nil { + t = &UserDetail_User_Profile{} + } + return t.Name +} + +type UserDetail_User struct { + ID string "json:\"id\" graphql:\"id\"" + Profile UserDetail_User_Profile "json:\"profile\" graphql:\"profile\"" +} + +func (t *UserDetail_User) GetID() string { + if t == nil { + t = &UserDetail_User{} + } + return t.ID +} +func (t *UserDetail_User) GetProfile() *UserDetail_User_Profile { + if t == nil { + t = &UserDetail_User{} + } + return &t.Profile +} + +type UserDetail struct { + User UserDetail_User "json:\"user\" graphql:\"user\"" +} + +func (t *UserDetail) GetUser() *UserDetail_User { + if t == nil { + t = &UserDetail{} + } + return &t.User +} + +const UserDetailDocument = `query UserDetail { + user { + ... UserFragment + id + profile { + name + company + detail { + id + } + } + } +} +fragment UserFragment on User { + id + profile { + id + detail { + birthDate + } + } +} +` + +func (c *Client) UserDetail(ctx context.Context, interceptors ...clientv2.RequestInterceptor) (*UserDetail, error) { + vars := map[string]any{} + + var res UserDetail + if err := c.Client.Post(ctx, "UserDetail", UserDetailDocument, &res, vars, interceptors...); err != nil { + if c.Client.ParseDataWhenErrors { + return &res, err + } + + return nil, err + } + + return &res, nil +} + +var DocumentOperationNames = map[string]string{ + UserDetailDocument: "UserDetail", +} diff --git a/example/custom-fragment/model/models_gen.go b/example/custom-fragment/model/models_gen.go new file mode 100644 index 0000000..6fdd9c0 --- /dev/null +++ b/example/custom-fragment/model/models_gen.go @@ -0,0 +1,3 @@ +// Code generated by github.com/99designs/gqlgen, DO NOT EDIT. + +package generated diff --git a/example/custom-fragment/query/query.graphql b/example/custom-fragment/query/query.graphql new file mode 100644 index 0000000..aec1444 --- /dev/null +++ b/example/custom-fragment/query/query.graphql @@ -0,0 +1,23 @@ +fragment UserFragment on User { + id + profile { + id + detail { + birthDate + } + } +} + +query UserDetail { + user { + ...UserFragment + id + profile { + name + company + detail { + id + } + } + } +} diff --git a/example/custom-fragment/schema/schema.graphql b/example/custom-fragment/schema/schema.graphql new file mode 100644 index 0000000..6d85a8d --- /dev/null +++ b/example/custom-fragment/schema/schema.graphql @@ -0,0 +1,24 @@ +schema { + query: Query +} + +type Query { + user: User! +} + +type User { + id: ID! + profile: Profile! +} + +type Profile { + id: ID! + name: String! + company: String! + detail: ProfileDetail! +} + +type ProfileDetail { + id: ID! + birthDate: String! +}