Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a few streaming interceptor bugs #3655

Merged
merged 2 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions codegen/service/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func interceptorFile(svc *Data, server bool) *codegen.File {
Data: interceptors,
FuncMap: map[string]any{
"hasPrivateImplementationTypes": hasPrivateImplementationTypes,
"hasEndpointStruct": hasEndpointStruct(server),
},
})
}
Expand Down Expand Up @@ -226,6 +227,17 @@ func hasPrivateImplementationTypes(interceptors []*InterceptorData) bool {
return false
}

// hasEndpointStruct returns a function that returns true if the method has an endpoint struct
// if server is true, otherwise it returns false.
func hasEndpointStruct(server bool) func(*MethodInterceptorData) bool {
if !server {
return func(*MethodInterceptorData) bool { return false }
}
return func(m *MethodInterceptorData) bool {
return m.ServerStream != nil && m.ServerStream.EndpointStruct != ""
}
}

// collectWrappedStreams returns a slice of streams to be wrapped by interceptor wrapper functions.
func collectWrappedStreams(interceptors []*InterceptorData, server bool) []*StreamInterceptorData {
var (
Expand Down
7 changes: 5 additions & 2 deletions codegen/service/service_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ type (
// function.
MustClose bool
// EndpointStruct is the name of the endpoint struct that holds a payload
// reference (if any) and the endpoint server stream. It is set only if the
// client sends a normal payload and server streams a result.
// reference (if any) and the endpoint server stream.
EndpointStruct string
// Kind is the kind of the stream (payload, result or bidirectional).
Kind expr.StreamKind
Expand Down Expand Up @@ -345,6 +344,9 @@ type (
// MustClose indicates whether the stream should implement the Close()
// function.
MustClose bool
// EndpointStruct is the name of the endpoint struct that holds a payload
// reference (if any) and the endpoint server stream.
EndpointStruct string
}

// AttributeData describes a single attribute.
Expand Down Expand Up @@ -1314,6 +1316,7 @@ func buildInterceptorMethodData(i *expr.InterceptorExpr, md *MethodData) *Method
RecvWithContextName: md.ServerStream.RecvWithContextName,
RecvTypeRef: md.ServerStream.RecvTypeRef,
MustClose: md.ServerStream.MustClose,
EndpointStruct: md.ServerStream.EndpointStruct,
}
}
if md.ClientStream != nil {
Expand Down
20 changes: 19 additions & 1 deletion codegen/service/templates/interceptors.go.tpl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{{- if hasPrivateImplementationTypes . }}
// Public accessor methods for Info types
{{- range . }}

Expand Down Expand Up @@ -29,13 +28,31 @@ func (info *{{ .Name }}Info) Payload() {{ .Name }}Payload {
switch info.Method() {
{{- range .Methods }}
case "{{ .MethodName }}":
{{- if hasEndpointStruct . }}
switch pay := info.RawPayload().(type) {
case *{{ .ServerStream.EndpointStruct }}:
return &{{ .PayloadAccess }}{payload: pay.Payload}
default:
return &{{ .PayloadAccess }}{payload: pay.({{ .PayloadRef }})}
}
{{- else }}
return &{{ .PayloadAccess }}{payload: info.RawPayload().({{ .PayloadRef }})}
{{- end }}
{{- end }}
default:
return nil
}
{{- else }}
{{- if hasEndpointStruct (index .Methods 0) }}
switch pay := info.RawPayload().(type) {
case *{{ (index .Methods 0).ServerStream.EndpointStruct }}:
return &{{ (index .Methods 0).PayloadAccess }}{payload: pay.Payload}
default:
return &{{ (index .Methods 0).PayloadAccess }}{payload: pay.({{ (index .Methods 0).PayloadRef }})}
}
{{- else }}
return &{{ (index .Methods 0).PayloadAccess }}{payload: info.RawPayload().({{ (index .Methods 0).PayloadRef }})}
{{- end }}
{{- end }}
}
{{- end }}
Expand Down Expand Up @@ -131,6 +148,7 @@ func (info *{{ .Name }}Info) ServerStreamingResult() {{ .Name }}StreamingResult
{{- end }}
{{- end }}

{{- if hasPrivateImplementationTypes . }}
// Private implementation methods
{{- range . }}
{{ $interceptor := . }}
Expand Down
31 changes: 14 additions & 17 deletions codegen/service/templates/server_interceptor_wrappers.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,8 @@
func wrap{{ .MethodName }}{{ $interceptor.Name }}(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint {
return func(ctx context.Context, req any) (any, error) {
{{- if or $interceptor.HasStreamingPayloadAccess $interceptor.HasStreamingResultAccess }}
{{- if $interceptor.HasPayloadAccess }}
info := &{{ $interceptor.Name }}Info{
service: "{{ $.Service }}",
method: "{{ .MethodName }}",
callType: goa.InterceptorUnary,
rawPayload: req,
}
res, err := i.{{ $interceptor.Name }}(ctx, info, endpoint)
{{- else }}
res, err := endpoint(ctx, req)
{{- end }}
if err != nil {
return res, err
}
stream := res.({{ .ServerStream.Interface }})
return &wrapped{{ .ServerStream.Interface }}{
stream := req.(*{{ .ServerStream.EndpointStruct }}).Stream
req.(*{{ .ServerStream.EndpointStruct }}).Stream = &wrapped{{ .ServerStream.Interface }}{
ctx: ctx,
{{- if $interceptor.HasStreamingResultAccess }}
sendWithContext: func(ctx context.Context, req {{ .ServerStream.SendTypeRef }}) error {
Expand Down Expand Up @@ -53,7 +39,18 @@ func wrap{{ .MethodName }}{{ $interceptor.Name }}(endpoint goa.Endpoint, i Serve
},
{{- end }}
stream: stream,
}, nil
}
{{- if $interceptor.HasPayloadAccess }}
info := &{{ $interceptor.Name }}Info{
service: "{{ $.Service }}",
method: "{{ .MethodName }}",
callType: goa.InterceptorUnary,
rawPayload: req,
}
return i.{{ $interceptor.Name }}(ctx, info, endpoint)
{{- else }}
return endpoint(ctx, req)
{{- end }}
{{- else }}
info := &{{ $interceptor.Name }}Info{
service: "{{ $.Service }}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,44 @@ func WrapMethodClientEndpoint(endpoint goa.Endpoint, i ClientInterceptors) goa.E
return endpoint
}

// Public accessor methods for Info types

// Service returns the name of the service handling the request.
func (info *Test2Info) Service() string {
return info.service
}

// Method returns the name of the method handling the request.
func (info *Test2Info) Method() string {
return info.method
}

// CallType returns the type of call the interceptor is handling.
func (info *Test2Info) CallType() goa.InterceptorCallType {
return info.callType
}

// RawPayload returns the raw payload of the request.
func (info *Test2Info) RawPayload() any {
return info.rawPayload
}

// Service returns the name of the service handling the request.
func (info *Test4Info) Service() string {
return info.service
}

// Method returns the name of the method handling the request.
func (info *Test4Info) Method() string {
return info.method
}

// CallType returns the type of call the interceptor is handling.
func (info *Test4Info) CallType() goa.InterceptorCallType {
return info.callType
}

// RawPayload returns the raw payload of the request.
func (info *Test4Info) RawPayload() any {
return info.rawPayload
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,44 @@ func WrapMethodEndpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoin
return endpoint
}

// Public accessor methods for Info types

// Service returns the name of the service handling the request.
func (info *TestInfo) Service() string {
return info.service
}

// Method returns the name of the method handling the request.
func (info *TestInfo) Method() string {
return info.method
}

// CallType returns the type of call the interceptor is handling.
func (info *TestInfo) CallType() goa.InterceptorCallType {
return info.callType
}

// RawPayload returns the raw payload of the request.
func (info *TestInfo) RawPayload() any {
return info.rawPayload
}

// Service returns the name of the service handling the request.
func (info *Test3Info) Service() string {
return info.service
}

// Method returns the name of the method handling the request.
func (info *Test3Info) Method() string {
return info.method
}

// CallType returns the type of call the interceptor is handling.
func (info *Test3Info) CallType() goa.InterceptorCallType {
return info.callType
}

// RawPayload returns the raw payload of the request.
func (info *Test3Info) RawPayload() any {
return info.rawPayload
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,24 @@ func WrapMethod2Endpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoi
return endpoint
}

// Public accessor methods for Info types

// Service returns the name of the service handling the request.
func (info *LoggingInfo) Service() string {
return info.service
}

// Method returns the name of the method handling the request.
func (info *LoggingInfo) Method() string {
return info.method
}

// CallType returns the type of call the interceptor is handling.
func (info *LoggingInfo) CallType() goa.InterceptorCallType {
return info.callType
}

// RawPayload returns the raw payload of the request.
func (info *LoggingInfo) RawPayload() any {
return info.rawPayload
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,24 @@ func WrapMethodClientEndpoint(endpoint goa.Endpoint, i ClientInterceptors) goa.E
return endpoint
}

// Public accessor methods for Info types

// Service returns the name of the service handling the request.
func (info *TracingInfo) Service() string {
return info.service
}

// Method returns the name of the method handling the request.
func (info *TracingInfo) Method() string {
return info.method
}

// CallType returns the type of call the interceptor is handling.
func (info *TracingInfo) CallType() goa.InterceptorCallType {
return info.callType
}

// RawPayload returns the raw payload of the request.
func (info *TracingInfo) RawPayload() any {
return info.rawPayload
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,24 @@ func WrapMethodEndpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoin
return endpoint
}

// Public accessor methods for Info types

// Service returns the name of the service handling the request.
func (info *LoggingInfo) Service() string {
return info.service
}

// Method returns the name of the method handling the request.
func (info *LoggingInfo) Method() string {
return info.method
}

// CallType returns the type of call the interceptor is handling.
func (info *LoggingInfo) CallType() goa.InterceptorCallType {
return info.callType
}

// RawPayload returns the raw payload of the request.
func (info *LoggingInfo) RawPayload() any {
return info.rawPayload
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,24 @@ func WrapMethod2Endpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoi
return endpoint
}

// Public accessor methods for Info types

// Service returns the name of the service handling the request.
func (info *LoggingInfo) Service() string {
return info.service
}

// Method returns the name of the method handling the request.
func (info *LoggingInfo) Method() string {
return info.method
}

// CallType returns the type of call the interceptor is handling.
func (info *LoggingInfo) CallType() goa.InterceptorCallType {
return info.callType
}

// RawPayload returns the raw payload of the request.
func (info *LoggingInfo) RawPayload() any {
return info.rawPayload
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,8 @@ type wrappedMethodClientStream struct {
// wrapLoggingMethod applies the logging server interceptor to endpoints.
func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint {
return func(ctx context.Context, req any) (any, error) {
info := &LoggingInfo{
service: "StreamingInterceptorsWithReadPayloadAndReadStreamingPayload",
method: "Method",
callType: goa.InterceptorUnary,
rawPayload: req,
}
res, err := i.Logging(ctx, info, endpoint)
if err != nil {
return res, err
}
stream := res.(MethodServerStream)
return &wrappedMethodServerStream{
stream := req.(*MethodEndpointInput).Stream
req.(*MethodEndpointInput).Stream = &wrappedMethodServerStream{
ctx: ctx,
recvWithContext: func(ctx context.Context) (*MethodStreamingPayload, error) {
info := &LoggingInfo{
Expand All @@ -45,7 +35,14 @@ func wrapMethodLogging(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint
return castRes, err
},
stream: stream,
}, nil
}
info := &LoggingInfo{
service: "StreamingInterceptorsWithReadPayloadAndReadStreamingPayload",
method: "Method",
callType: goa.InterceptorUnary,
rawPayload: req,
}
return i.Logging(ctx, info, endpoint)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ func (info *LoggingInfo) RawPayload() any {

// Payload returns a type-safe accessor for the method payload.
func (info *LoggingInfo) Payload() LoggingPayload {
return &loggingMethodPayload{payload: info.RawPayload().(*MethodPayload)}
switch pay := info.RawPayload().(type) {
case *MethodEndpointInput:
return &loggingMethodPayload{payload: pay.Payload}
default:
return &loggingMethodPayload{payload: pay.(*MethodPayload)}
}
}

// ClientStreamingPayload returns a type-safe accessor for the method streaming payload for a client-side interceptor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ func (info *LoggingInfo) RawPayload() any {

// Payload returns a type-safe accessor for the method payload.
func (info *LoggingInfo) Payload() LoggingPayload {
return &loggingMethodPayload{payload: info.RawPayload().(*MethodPayload)}
switch pay := info.RawPayload().(type) {
case *MethodEndpointInput:
return &loggingMethodPayload{payload: pay.Payload}
default:
return &loggingMethodPayload{payload: pay.(*MethodPayload)}
}
}

// Private implementation methods
Expand Down
Loading
Loading