Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: missing external enums #14

Merged
merged 5 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
go-version: '^1.17'
- uses: arduino/setup-protoc@v1
with:
version: '3.17.3'
version: '3.19.1'
- name: Run go test
run: go test -v ./...
- name: Install protoc-gen-pubsub-schema
Expand Down
101 changes: 12 additions & 89 deletions content_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (b *contentBuilder) build(protoFile *descriptorpb.FileDescriptorProto) (str
compVersion := b.request.GetCompilerVersion()
fmt.Fprintf(b.output, "// Code generated by protoc-gen-pubsub-schema. DO NOT EDIT.\n")
fmt.Fprintf(b.output, "// versions:\n")
fmt.Fprintf(b.output, "// protoc-gen-pubsub-schema v1.4.3\n")
fmt.Fprintf(b.output, "// protoc-gen-pubsub-schema v1.4.4\n")
fmt.Fprintf(b.output, "// protoc v%d.%d.%d%s\n", compVersion.GetMajor(), compVersion.GetMinor(), compVersion.GetPatch(), compVersion.GetSuffix())
fmt.Fprintf(b.output, "// source: %s\n\n", protoFile.GetName())
fmt.Fprintf(b.output, "syntax = \"%s\";\n", b.schemaSyntax)
Expand All @@ -39,110 +39,33 @@ func (b *contentBuilder) build(protoFile *descriptorpb.FileDescriptorProto) (str
}

func (b *contentBuilder) buildMessages(messages []*descriptorpb.DescriptorProto, level int) {
built := make(map[*descriptorpb.DescriptorProto]bool)
for _, message := range messages {
fmt.Fprintln(b.output)
b.buildMessage(message, level)
}
}

func (b *contentBuilder) buildMessage(message *descriptorpb.DescriptorProto, level int) {
fmt.Fprintf(b.output, "%smessage %s {\n", buildIndent(level), message.GetName())
b.buildFields(message.GetField(), level+1)
b.buildMessages(message.GetNestedType(), level+1)
b.buildEnums(message.GetEnumType(), level+1)
b.buildOtherTypes(message, level+1)
fmt.Fprintf(b.output, "%s}\n", buildIndent(level))
}

func (b *contentBuilder) buildFields(fields []*descriptorpb.FieldDescriptorProto, level int) {
for _, field := range fields {
fmt.Fprint(b.output, buildIndent(level))
label := field.GetLabel()
if b.schemaSyntax == "proto2" || label == descriptorpb.FieldDescriptorProto_LABEL_REPEATED {
fmt.Fprintf(b.output, "%s ", strings.ToLower(strings.TrimPrefix(label.String(), "LABEL_")))
}
fmt.Fprintf(b.output, "%s %s = %d;\n", b.getFieldType(field), field.GetName(), field.GetNumber())
}
}

func (b *contentBuilder) getFieldType(field *descriptorpb.FieldDescriptorProto) string {
typeName := field.GetTypeName()
switch field.GetType() {
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
return wktMapping[typeName]
}
if b.isNestedType(typeName) {
return shortName(typeName)
if built[message] {
continue
}
return pascalCase(typeName)
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
return shortName(typeName)
default:
return strings.ToLower(strings.TrimPrefix(field.GetType().String(), "TYPE_"))
fmt.Fprintln(b.output)
newMessageBuilder(b, message, level).build()
built[message] = true
}
}

func (b *contentBuilder) buildEnums(enums []*descriptorpb.EnumDescriptorProto, level int) {
built := make(map[*descriptorpb.EnumDescriptorProto]bool)
for _, enum := range enums {
if built[enum] {
continue
}
fmt.Fprintln(b.output)
fmt.Fprintf(b.output, "%senum %s {\n", buildIndent(level), enum.GetName())
for _, value := range enum.GetValue() {
fmt.Fprintf(b.output, "%s%s = %d;\n", buildIndent(level+1), value.GetName(), value.GetNumber())
}
fmt.Fprintf(b.output, "%s}\n", buildIndent(level))
built[enum] = true
}
}

func (b *contentBuilder) buildOtherTypes(message *descriptorpb.DescriptorProto, level int) {
built := make(map[string]bool)
for _, field := range message.GetField() {
typeName := field.GetTypeName()
if field.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
continue
}
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
continue
}
if b.isNestedType(typeName) {
continue
}
if built[typeName] {
continue
}
b.buildOtherType(typeName, level)
built[typeName] = true
}
}

func (b *contentBuilder) buildOtherType(typeName string, level int) {
message := b.messageTypes[typeName]
defer func(name *string) { message.Name = name }(message.Name)
*message.Name = pascalCase(typeName)
fmt.Fprintln(b.output)
b.buildMessage(message, level)
}

func (b *contentBuilder) isNestedType(name string) bool {
return b.messageTypes[name[:strings.LastIndexByte(name, '.')]] != nil
}

func buildIndent(level int) string {
return strings.Repeat(" ", level)
}

func shortName(name string) string {
return name[strings.LastIndexByte(name, '.')+1:]
}

func pascalCase(name string) string {
sb := new(strings.Builder)
for i, c := range name {
if i > 0 && name[i-1] == '.' {
sb.WriteString(strings.ToUpper(string(c)))
} else if c != '.' {
sb.WriteRune(c)
}
}
return sb.String()
}
9 changes: 9 additions & 0 deletions example/common/role.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
syntax = "proto3";

package example.common;

enum Role {
OWNER = 0;
EDITOR = 1;
VIEWER = 2;
}
18 changes: 9 additions & 9 deletions example/user_add_comment.pps
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-pubsub-schema. DO NOT EDIT.
// versions:
// protoc-gen-pubsub-schema v1.4.3
// protoc v3.17.3
// protoc-gen-pubsub-schema v1.4.4
// protoc v3.19.1
// source: example/user_add_comment.proto

syntax = "proto2";
Expand All @@ -15,7 +15,7 @@ message UserAddComment {
message User {
required string first_name = 1;
optional string last_name = 2;
required Role role = 3;
required ExampleCommonRole role = 3;
optional bytes avatar = 4;
optional Location location = 5;
optional GoogleProtobufTimestamp created_at = 6;
Expand All @@ -26,16 +26,16 @@ message UserAddComment {
required double latitude = 2;
}

enum Role {
OWNER = 1;
EDITOR = 2;
VIEWER = 3;
}

message GoogleProtobufTimestamp {
optional int64 seconds = 1;
optional int32 nanos = 2;
}

enum ExampleCommonRole {
OWNER = 0;
EDITOR = 1;
VIEWER = 2;
}
}

message ExampleCommonLabel {
Expand Down
9 changes: 2 additions & 7 deletions example/user_add_comment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ syntax = "proto2";
package example;

import "example/common/label.proto";
import "example/common/role.proto";
import "google/protobuf/timestamp.proto";

message UserAddComment {
Expand All @@ -14,7 +15,7 @@ message UserAddComment {
message User {
required string first_name = 1;
optional string last_name = 2;
required Role role = 3;
required example.common.Role role = 3;
optional bytes avatar = 4;
optional Location location = 5;
optional google.protobuf.Timestamp created_at = 6;
Expand All @@ -24,11 +25,5 @@ message UserAddComment {
required double longitude = 1;
required double latitude = 2;
}

enum Role {
OWNER = 1;
EDITOR = 2;
VIEWER = 3;
}
}
}
93 changes: 93 additions & 0 deletions message_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package main

import (
"fmt"
"strings"

"google.golang.org/protobuf/types/descriptorpb"
)

type messageBuilder struct {
*contentBuilder
message *descriptorpb.DescriptorProto
level int
externalMessages []*descriptorpb.DescriptorProto
externalEnums []*descriptorpb.EnumDescriptorProto
}

func newMessageBuilder(b *contentBuilder, message *descriptorpb.DescriptorProto, level int) *messageBuilder {
return &messageBuilder{b, message, level, nil, nil}
}

func (b *messageBuilder) build() {
fmt.Fprintf(b.output, "%smessage %s {\n", buildIndent(b.level), b.message.GetName())
b.buildFields()
b.buildMessages(b.message.GetNestedType(), b.level+1)
b.buildEnums(b.message.GetEnumType(), b.level+1)
fmt.Fprintf(b.output, "%s}\n", buildIndent(b.level))
}

func (b *messageBuilder) buildFields() {
for _, field := range b.message.GetField() {
fmt.Fprint(b.output, buildIndent(b.level+1))
label := field.GetLabel()
if b.schemaSyntax == "proto2" || label == descriptorpb.FieldDescriptorProto_LABEL_REPEATED {
fmt.Fprintf(b.output, "%s ", strings.ToLower(strings.TrimPrefix(label.String(), "LABEL_")))
}
fmt.Fprintf(b.output, "%s %s = %d;\n", b.buildFieldType(field), field.GetName(), field.GetNumber())
}
}

func (b *messageBuilder) buildFieldType(field *descriptorpb.FieldDescriptorProto) string {
typeName := field.GetTypeName()
if b.isNestedType(field) {
return getChildName(typeName)
}
switch field.GetType() {
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
return wktMapping[typeName]
}
internalName := pascalCase(typeName)
internalMessage := b.messageTypes[field.GetTypeName()]
internalMessage.Name = &internalName
b.message.NestedType = append(b.message.NestedType, internalMessage)
return internalName
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
internalName := pascalCase(typeName)
internalEnum := b.enums[field.GetTypeName()]
internalEnum.Name = &internalName
b.message.EnumType = append(b.message.EnumType, internalEnum)
return internalName
default:
return strings.ToLower(strings.TrimPrefix(field.GetType().String(), "TYPE_"))
}
}

func (b *messageBuilder) isNestedType(field *descriptorpb.FieldDescriptorProto) bool {
return b.messageTypes[getParentName(field.GetTypeName())] == b.message
}

func getParentName(name string) string {
lastDotIndex := strings.LastIndexByte(name, '.')
if lastDotIndex == -1 {
return name
}
return name[:lastDotIndex]
}

func getChildName(name string) string {
return name[strings.LastIndexByte(name, '.')+1:]
}

func pascalCase(name string) string {
sb := new(strings.Builder)
for i, c := range name {
if i > 0 && name[i-1] == '.' {
sb.WriteString(strings.ToUpper(string(c)))
} else if c != '.' {
sb.WriteRune(c)
}
}
return sb.String()
}
63 changes: 63 additions & 0 deletions message_builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package main

import (
"testing"
)

func Test_getParentName(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want string
}{
{
name: "empty name",
args: args{""},
want: "",
},
{
name: "normal name",
args: args{".example.UserAddComment.User"},
want: ".example.UserAddComment",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getParentName(tt.args.name); got != tt.want {
t.Errorf("getParentName() = %v, want %v", got, tt.want)
}
})
}
}

func Test_getChildName(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want string
}{
{
name: "empty name",
args: args{""},
want: "",
},
{
name: "normal name",
args: args{".example.UserAddComment.User"},
want: "User",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getChildName(tt.args.name); got != tt.want {
t.Errorf("getChildName() = %v, want %v", got, tt.want)
}
})
}
}
Loading