Skip to content

Commit

Permalink
Fix: missing external enums (#14)
Browse files Browse the repository at this point in the history
* build external enums

* bump version

* upgrade github action setup-protoc

* remove some invalid tests

* refactor
  • Loading branch information
alpancs authored Nov 5, 2021
1 parent 1130134 commit 6651aa2
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 1,456 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
go-version: '^1.17'
- uses: arduino/setup-protoc@v1
with:
version: '3.17.3'
version: '3.19.1'
- name: Run go test
run: go test -v ./...
- name: Install protoc-gen-pubsub-schema
Expand Down
101 changes: 12 additions & 89 deletions content_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (b *contentBuilder) build(protoFile *descriptorpb.FileDescriptorProto) (str
compVersion := b.request.GetCompilerVersion()
fmt.Fprintf(b.output, "// Code generated by protoc-gen-pubsub-schema. DO NOT EDIT.\n")
fmt.Fprintf(b.output, "// versions:\n")
fmt.Fprintf(b.output, "// protoc-gen-pubsub-schema v1.4.3\n")
fmt.Fprintf(b.output, "// protoc-gen-pubsub-schema v1.4.4\n")
fmt.Fprintf(b.output, "// protoc v%d.%d.%d%s\n", compVersion.GetMajor(), compVersion.GetMinor(), compVersion.GetPatch(), compVersion.GetSuffix())
fmt.Fprintf(b.output, "// source: %s\n\n", protoFile.GetName())
fmt.Fprintf(b.output, "syntax = \"%s\";\n", b.schemaSyntax)
Expand All @@ -39,110 +39,33 @@ func (b *contentBuilder) build(protoFile *descriptorpb.FileDescriptorProto) (str
}

func (b *contentBuilder) buildMessages(messages []*descriptorpb.DescriptorProto, level int) {
built := make(map[*descriptorpb.DescriptorProto]bool)
for _, message := range messages {
fmt.Fprintln(b.output)
b.buildMessage(message, level)
}
}

func (b *contentBuilder) buildMessage(message *descriptorpb.DescriptorProto, level int) {
fmt.Fprintf(b.output, "%smessage %s {\n", buildIndent(level), message.GetName())
b.buildFields(message.GetField(), level+1)
b.buildMessages(message.GetNestedType(), level+1)
b.buildEnums(message.GetEnumType(), level+1)
b.buildOtherTypes(message, level+1)
fmt.Fprintf(b.output, "%s}\n", buildIndent(level))
}

func (b *contentBuilder) buildFields(fields []*descriptorpb.FieldDescriptorProto, level int) {
for _, field := range fields {
fmt.Fprint(b.output, buildIndent(level))
label := field.GetLabel()
if b.schemaSyntax == "proto2" || label == descriptorpb.FieldDescriptorProto_LABEL_REPEATED {
fmt.Fprintf(b.output, "%s ", strings.ToLower(strings.TrimPrefix(label.String(), "LABEL_")))
}
fmt.Fprintf(b.output, "%s %s = %d;\n", b.getFieldType(field), field.GetName(), field.GetNumber())
}
}

func (b *contentBuilder) getFieldType(field *descriptorpb.FieldDescriptorProto) string {
typeName := field.GetTypeName()
switch field.GetType() {
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
return wktMapping[typeName]
}
if b.isNestedType(typeName) {
return shortName(typeName)
if built[message] {
continue
}
return pascalCase(typeName)
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
return shortName(typeName)
default:
return strings.ToLower(strings.TrimPrefix(field.GetType().String(), "TYPE_"))
fmt.Fprintln(b.output)
newMessageBuilder(b, message, level).build()
built[message] = true
}
}

func (b *contentBuilder) buildEnums(enums []*descriptorpb.EnumDescriptorProto, level int) {
built := make(map[*descriptorpb.EnumDescriptorProto]bool)
for _, enum := range enums {
if built[enum] {
continue
}
fmt.Fprintln(b.output)
fmt.Fprintf(b.output, "%senum %s {\n", buildIndent(level), enum.GetName())
for _, value := range enum.GetValue() {
fmt.Fprintf(b.output, "%s%s = %d;\n", buildIndent(level+1), value.GetName(), value.GetNumber())
}
fmt.Fprintf(b.output, "%s}\n", buildIndent(level))
built[enum] = true
}
}

func (b *contentBuilder) buildOtherTypes(message *descriptorpb.DescriptorProto, level int) {
built := make(map[string]bool)
for _, field := range message.GetField() {
typeName := field.GetTypeName()
if field.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
continue
}
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
continue
}
if b.isNestedType(typeName) {
continue
}
if built[typeName] {
continue
}
b.buildOtherType(typeName, level)
built[typeName] = true
}
}

func (b *contentBuilder) buildOtherType(typeName string, level int) {
message := b.messageTypes[typeName]
defer func(name *string) { message.Name = name }(message.Name)
*message.Name = pascalCase(typeName)
fmt.Fprintln(b.output)
b.buildMessage(message, level)
}

func (b *contentBuilder) isNestedType(name string) bool {
return b.messageTypes[name[:strings.LastIndexByte(name, '.')]] != nil
}

func buildIndent(level int) string {
return strings.Repeat(" ", level)
}

func shortName(name string) string {
return name[strings.LastIndexByte(name, '.')+1:]
}

func pascalCase(name string) string {
sb := new(strings.Builder)
for i, c := range name {
if i > 0 && name[i-1] == '.' {
sb.WriteString(strings.ToUpper(string(c)))
} else if c != '.' {
sb.WriteRune(c)
}
}
return sb.String()
}
9 changes: 9 additions & 0 deletions example/common/role.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
syntax = "proto3";

package example.common;

enum Role {
OWNER = 0;
EDITOR = 1;
VIEWER = 2;
}
18 changes: 9 additions & 9 deletions example/user_add_comment.pps
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-pubsub-schema. DO NOT EDIT.
// versions:
// protoc-gen-pubsub-schema v1.4.3
// protoc v3.17.3
// protoc-gen-pubsub-schema v1.4.4
// protoc v3.19.1
// source: example/user_add_comment.proto

syntax = "proto2";
Expand All @@ -15,7 +15,7 @@ message UserAddComment {
message User {
required string first_name = 1;
optional string last_name = 2;
required Role role = 3;
required ExampleCommonRole role = 3;
optional bytes avatar = 4;
optional Location location = 5;
optional GoogleProtobufTimestamp created_at = 6;
Expand All @@ -26,16 +26,16 @@ message UserAddComment {
required double latitude = 2;
}

enum Role {
OWNER = 1;
EDITOR = 2;
VIEWER = 3;
}

message GoogleProtobufTimestamp {
optional int64 seconds = 1;
optional int32 nanos = 2;
}

enum ExampleCommonRole {
OWNER = 0;
EDITOR = 1;
VIEWER = 2;
}
}

message ExampleCommonLabel {
Expand Down
9 changes: 2 additions & 7 deletions example/user_add_comment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ syntax = "proto2";
package example;

import "example/common/label.proto";
import "example/common/role.proto";
import "google/protobuf/timestamp.proto";

message UserAddComment {
Expand All @@ -14,7 +15,7 @@ message UserAddComment {
message User {
required string first_name = 1;
optional string last_name = 2;
required Role role = 3;
required example.common.Role role = 3;
optional bytes avatar = 4;
optional Location location = 5;
optional google.protobuf.Timestamp created_at = 6;
Expand All @@ -24,11 +25,5 @@ message UserAddComment {
required double longitude = 1;
required double latitude = 2;
}

enum Role {
OWNER = 1;
EDITOR = 2;
VIEWER = 3;
}
}
}
93 changes: 93 additions & 0 deletions message_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package main

import (
"fmt"
"strings"

"google.golang.org/protobuf/types/descriptorpb"
)

type messageBuilder struct {
*contentBuilder
message *descriptorpb.DescriptorProto
level int
externalMessages []*descriptorpb.DescriptorProto
externalEnums []*descriptorpb.EnumDescriptorProto
}

func newMessageBuilder(b *contentBuilder, message *descriptorpb.DescriptorProto, level int) *messageBuilder {
return &messageBuilder{b, message, level, nil, nil}
}

func (b *messageBuilder) build() {
fmt.Fprintf(b.output, "%smessage %s {\n", buildIndent(b.level), b.message.GetName())
b.buildFields()
b.buildMessages(b.message.GetNestedType(), b.level+1)
b.buildEnums(b.message.GetEnumType(), b.level+1)
fmt.Fprintf(b.output, "%s}\n", buildIndent(b.level))
}

func (b *messageBuilder) buildFields() {
for _, field := range b.message.GetField() {
fmt.Fprint(b.output, buildIndent(b.level+1))
label := field.GetLabel()
if b.schemaSyntax == "proto2" || label == descriptorpb.FieldDescriptorProto_LABEL_REPEATED {
fmt.Fprintf(b.output, "%s ", strings.ToLower(strings.TrimPrefix(label.String(), "LABEL_")))
}
fmt.Fprintf(b.output, "%s %s = %d;\n", b.buildFieldType(field), field.GetName(), field.GetNumber())
}
}

func (b *messageBuilder) buildFieldType(field *descriptorpb.FieldDescriptorProto) string {
typeName := field.GetTypeName()
if b.isNestedType(field) {
return getChildName(typeName)
}
switch field.GetType() {
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
return wktMapping[typeName]
}
internalName := pascalCase(typeName)
internalMessage := b.messageTypes[field.GetTypeName()]
internalMessage.Name = &internalName
b.message.NestedType = append(b.message.NestedType, internalMessage)
return internalName
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
internalName := pascalCase(typeName)
internalEnum := b.enums[field.GetTypeName()]
internalEnum.Name = &internalName
b.message.EnumType = append(b.message.EnumType, internalEnum)
return internalName
default:
return strings.ToLower(strings.TrimPrefix(field.GetType().String(), "TYPE_"))
}
}

func (b *messageBuilder) isNestedType(field *descriptorpb.FieldDescriptorProto) bool {
return b.messageTypes[getParentName(field.GetTypeName())] == b.message
}

func getParentName(name string) string {
lastDotIndex := strings.LastIndexByte(name, '.')
if lastDotIndex == -1 {
return name
}
return name[:lastDotIndex]
}

func getChildName(name string) string {
return name[strings.LastIndexByte(name, '.')+1:]
}

func pascalCase(name string) string {
sb := new(strings.Builder)
for i, c := range name {
if i > 0 && name[i-1] == '.' {
sb.WriteString(strings.ToUpper(string(c)))
} else if c != '.' {
sb.WriteRune(c)
}
}
return sb.String()
}
63 changes: 63 additions & 0 deletions message_builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package main

import (
"testing"
)

func Test_getParentName(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want string
}{
{
name: "empty name",
args: args{""},
want: "",
},
{
name: "normal name",
args: args{".example.UserAddComment.User"},
want: ".example.UserAddComment",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getParentName(tt.args.name); got != tt.want {
t.Errorf("getParentName() = %v, want %v", got, tt.want)
}
})
}
}

func Test_getChildName(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want string
}{
{
name: "empty name",
args: args{""},
want: "",
},
{
name: "normal name",
args: args{".example.UserAddComment.User"},
want: "User",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getChildName(tt.args.name); got != tt.want {
t.Errorf("getChildName() = %v, want %v", got, tt.want)
}
})
}
}
Loading

0 comments on commit 6651aa2

Please sign in to comment.