diff --git a/common/nexus/failure.go b/common/nexus/failure.go index dd3c5da4b89..2608647114b 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -23,9 +23,11 @@ package nexus import ( + "context" "encoding/json" "errors" "net/http" + "sync/atomic" "github.com/nexus-rpc/sdk-go/nexus" commonpb "go.temporal.io/api/common/v1" @@ -48,6 +50,26 @@ type failureSourceContextKeyType struct{} var FailureSourceContextKey = failureSourceContextKeyType{} +func SetFailureSourceOnContext(ctx context.Context, response *http.Response) { + if response == nil || response.Header == nil { + return + } + + failureSourceHeader := response.Header.Get(FailureSourceHeaderName) + if failureSourceHeader == "" { + return + } + + failureSourceContext := ctx.Value(FailureSourceContextKey) + if failureSourceContext == nil { + return + } + + if val, ok := failureSourceContext.(*atomic.Value); ok { + val.Store(failureSourceHeader) + } +} + var failureTypeString = string((&failurepb.Failure{}).ProtoReflect().Descriptor().FullName()) // ProtoFailureToNexusFailure converts a proto Nexus Failure to a Nexus SDK Failure. diff --git a/components/nexusoperations/fx.go b/components/nexusoperations/fx.go index e8c66d2e24b..d17a81c55f5 100644 --- a/components/nexusoperations/fx.go +++ b/components/nexusoperations/fx.go @@ -26,7 +26,6 @@ import ( "context" "fmt" "net/http" - "sync/atomic" "github.com/nexus-rpc/sdk-go/nexus" "go.temporal.io/api/serviceerror" @@ -140,7 +139,7 @@ func ClientProviderFactory( httpCaller = func(r *http.Request) (*http.Response, error) { r.Header.Set(NexusCallbackSourceHeader, clusterInfo.ClusterID) resp, callErr := httpClient.Do(r) - setFailureSourceOnContext(ctx, resp) + commonnexus.SetFailureSourceOnContext(ctx, resp) return resp, callErr } } @@ -158,23 +157,3 @@ func ClientProviderFactory( func CallbackTokenGeneratorProvider() *commonnexus.CallbackTokenGenerator { return commonnexus.NewCallbackTokenGenerator() } - -func setFailureSourceOnContext(ctx context.Context, response *http.Response) { - if response == nil || response.Header == nil { - return - } - - failureSourceHeader := response.Header.Get(commonnexus.FailureSourceHeaderName) - if failureSourceHeader == "" { - return - } - - failureSourceContext := ctx.Value(commonnexus.FailureSourceContextKey) - if failureSourceContext == nil { - return - } - - if val, ok := failureSourceContext.(*atomic.Value); ok { - val.Store(failureSourceHeader) - } -}