From 567207b824217cecd4c7b643d0232174b915657f Mon Sep 17 00:00:00 2001 From: rkodev <43806892+rkodev@users.noreply.github.com> Date: Thu, 26 Jan 2023 19:27:59 +0300 Subject: [PATCH 1/2] Force revert --- .github/workflows/go.yml | 2 +- .github/workflows/sonarcloud.yml | 2 +- CHANGELOG.md | 6 ++++++ go.mod | 2 +- go.sum | 4 ++-- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index a38e498..a0ba281 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -23,5 +23,5 @@ jobs: run: go build working-directory: ${{ env.relativePath }} - name: Test project - run: go test + run: go test ./... working-directory: ${{ env.relativePath }} diff --git a/.github/workflows/sonarcloud.yml b/.github/workflows/sonarcloud.yml index df21501..52d3b0a 100644 --- a/.github/workflows/sonarcloud.yml +++ b/.github/workflows/sonarcloud.yml @@ -24,7 +24,7 @@ jobs: run: go build working-directory: ${{ env.relativePath }} - name: Run unit tests - run: go test -o result.out -coverprofile cover.out + run: go test -coverprofile cover.out -coverpkg=./... ./... working-directory: ${{ env.relativePath }} - name: SonarCloud Scan uses: SonarSource/sonarcloud-github-action@master diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e3b6ed..961db70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +## [0.14.0] - 2023-01-25 + +### Added + +- Added implementation methods for backing store. + ## [0.13.0] - 2023-01-10 ### Added diff --git a/go.mod b/go.mod index dfc7254..99c8629 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/google/uuid v1.3.0 - github.com/microsoft/kiota-abstractions-go v0.16.0 + github.com/microsoft/kiota-abstractions-go v0.17.0 github.com/stretchr/testify v1.8.1 go.opentelemetry.io/otel v1.11.2 go.opentelemetry.io/otel/trace v1.11.2 diff --git a/go.sum b/go.sum index d2aada5..0f594da 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,8 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/microsoft/kiota-abstractions-go v0.16.0 h1:DZ1L4YsRsQw39iPGnVq2fQkqLXMsazdPwmWsnaH4EZg= -github.com/microsoft/kiota-abstractions-go v0.16.0/go.mod h1:RT/s9sCzg49i4iO7e2qhyWmX+DlJDgC0P+Wp8fKQQfo= +github.com/microsoft/kiota-abstractions-go v0.17.0 h1:Ye2DTk8ko9Na0uCvhcCV7TQPWt72trT+kyD37btDtsI= +github.com/microsoft/kiota-abstractions-go v0.17.0/go.mod h1:RT/s9sCzg49i4iO7e2qhyWmX+DlJDgC0P+Wp8fKQQfo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= From bac6b1ab0961e96205404d1caa9f334e8c76f658 Mon Sep 17 00:00:00 2001 From: rkodev <43806892+rkodev@users.noreply.github.com> Date: Thu, 26 Jan 2023 19:43:46 +0300 Subject: [PATCH 2/2] Enable backing store --- internal/mock_parse_node_factory.go | 20 + nethttp_request_adapter.go | 1597 ++++++++++++++------------- nethttp_request_adapter_test.go | 507 ++++----- 3 files changed, 1082 insertions(+), 1042 deletions(-) diff --git a/internal/mock_parse_node_factory.go b/internal/mock_parse_node_factory.go index 1a77f7d..20a083e 100644 --- a/internal/mock_parse_node_factory.go +++ b/internal/mock_parse_node_factory.go @@ -20,6 +20,26 @@ func (e *MockParseNodeFactory) GetRootParseNode(contentType string, content []by type MockParseNode struct { } +func (e *MockParseNode) GetOnBeforeAssignFieldValues() absser.ParsableAction { + //TODO implement me + panic("implement me") +} + +func (e *MockParseNode) SetOnBeforeAssignFieldValues(action absser.ParsableAction) error { + //TODO implement me + panic("implement me") +} + +func (e *MockParseNode) GetOnAfterAssignFieldValues() absser.ParsableAction { + //TODO implement me + panic("implement me") +} + +func (e *MockParseNode) SetOnAfterAssignFieldValues(action absser.ParsableAction) error { + //TODO implement me + panic("implement me") +} + func (*MockParseNode) GetRawValue() (interface{}, error) { return nil, nil } diff --git a/nethttp_request_adapter.go b/nethttp_request_adapter.go index d5ed538..816bf96 100644 --- a/nethttp_request_adapter.go +++ b/nethttp_request_adapter.go @@ -1,796 +1,801 @@ -package nethttplibrary - -import ( - "bytes" - "context" - "errors" - "io" - "io/ioutil" - nethttp "net/http" - "reflect" - "regexp" - "strconv" - "strings" - "time" - - abs "github.com/microsoft/kiota-abstractions-go" - absauth "github.com/microsoft/kiota-abstractions-go/authentication" - absser "github.com/microsoft/kiota-abstractions-go/serialization" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" -) - -// nopCloser is an alternate io.nopCloser implementation which -// provides io.ReadSeekCloser instead of io.ReadCloser as we need -// Seek for retries -type nopCloser struct { - io.ReadSeeker -} - -func NopCloser(r io.ReadSeeker) io.ReadSeekCloser { - return nopCloser{r} -} - -func (nopCloser) Close() error { return nil } - -// NetHttpRequestAdapter implements the RequestAdapter interface using net/http -type NetHttpRequestAdapter struct { - // serializationWriterFactory is the factory used to create serialization writers - serializationWriterFactory absser.SerializationWriterFactory - // parseNodeFactory is the factory used to create parse nodes - parseNodeFactory absser.ParseNodeFactory - // httpClient is the client used to send requests - httpClient *nethttp.Client - // authenticationProvider is the provider used to authenticate requests - authenticationProvider absauth.AuthenticationProvider - // The base url for every request. - baseUrl string - // The observation options for the request adapter. - observabilityOptions ObservabilityOptions -} - -// NewNetHttpRequestAdapter creates a new NetHttpRequestAdapter with the given parameters -func NewNetHttpRequestAdapter(authenticationProvider absauth.AuthenticationProvider) (*NetHttpRequestAdapter, error) { - return NewNetHttpRequestAdapterWithParseNodeFactory(authenticationProvider, nil) -} - -// NewNetHttpRequestAdapterWithParseNodeFactory creates a new NetHttpRequestAdapter with the given parameters -func NewNetHttpRequestAdapterWithParseNodeFactory(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory) (*NetHttpRequestAdapter, error) { - return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory(authenticationProvider, parseNodeFactory, nil) -} - -// NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory creates a new NetHttpRequestAdapter with the given parameters -func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory absser.SerializationWriterFactory) (*NetHttpRequestAdapter, error) { - return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(authenticationProvider, parseNodeFactory, serializationWriterFactory, nil) -} - -// NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient creates a new NetHttpRequestAdapter with the given parameters -func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory absser.SerializationWriterFactory, httpClient *nethttp.Client) (*NetHttpRequestAdapter, error) { - return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions(authenticationProvider, parseNodeFactory, serializationWriterFactory, httpClient, ObservabilityOptions{}) -} - -// NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions creates a new NetHttpRequestAdapter with the given parameters -func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory absser.SerializationWriterFactory, httpClient *nethttp.Client, observabilityOptions ObservabilityOptions) (*NetHttpRequestAdapter, error) { - if authenticationProvider == nil { - return nil, errors.New("authenticationProvider cannot be nil") - } - result := &NetHttpRequestAdapter{ - serializationWriterFactory: serializationWriterFactory, - parseNodeFactory: parseNodeFactory, - httpClient: httpClient, - authenticationProvider: authenticationProvider, - baseUrl: "", - observabilityOptions: observabilityOptions, - } - if result.httpClient == nil { - defaultClient := GetDefaultClient() - result.httpClient = defaultClient - } - if result.serializationWriterFactory == nil { - result.serializationWriterFactory = absser.DefaultSerializationWriterFactoryInstance - } - if result.parseNodeFactory == nil { - result.parseNodeFactory = absser.DefaultParseNodeFactoryInstance - } - return result, nil -} - -// GetSerializationWriterFactory returns the serialization writer factory currently in use for the request adapter service. -func (a *NetHttpRequestAdapter) GetSerializationWriterFactory() absser.SerializationWriterFactory { - return a.serializationWriterFactory -} - -// EnableBackingStore enables the backing store proxies for the SerializationWriters and ParseNodes in use. -func (a *NetHttpRequestAdapter) EnableBackingStore() { - //TODO implement when backing store is available for go -} - -// SetBaseUrl sets the base url for every request. -func (a *NetHttpRequestAdapter) SetBaseUrl(baseUrl string) { - a.baseUrl = baseUrl -} - -// GetBaseUrl gets the base url for every request. -func (a *NetHttpRequestAdapter) GetBaseUrl() string { - return a.baseUrl -} - -func (a *NetHttpRequestAdapter) getHttpResponseMessage(ctx context.Context, requestInfo *abs.RequestInformation, claims string, spanForAttributes trace.Span) (*nethttp.Response, error) { - ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getHttpResponseMessage") - defer span.End() - if ctx == nil { - ctx = context.Background() - } - a.setBaseUrlForRequestInformation(requestInfo) - additionalContext := make(map[string]any) - if claims != "" { - additionalContext[claimsKey] = claims - } - err := a.authenticationProvider.AuthenticateRequest(ctx, requestInfo, additionalContext) - if err != nil { - return nil, err - } - request, err := a.getRequestFromRequestInformation(ctx, requestInfo, spanForAttributes) - if err != nil { - return nil, err - } - response, err := (*a.httpClient).Do(request) - if err != nil { - spanForAttributes.RecordError(err) - return nil, err - } - if response != nil { - contentLenHeader := response.Header.Get("Content-Length") - if contentLenHeader != "" { - contentLen, _ := strconv.Atoi(contentLenHeader) - spanForAttributes.SetAttributes(attribute.Int("http.response_content_length", contentLen)) - } - contentTypeHeader := response.Header.Get("Content-Type") - if contentTypeHeader != "" { - spanForAttributes.SetAttributes(attribute.String("http.response_content_type", contentTypeHeader)) - } - spanForAttributes.SetAttributes( - attribute.Int("http.status_code", response.StatusCode), - attribute.String("http.flavor", response.Proto), - ) - } - return a.retryCAEResponseIfRequired(ctx, response, requestInfo, claims, spanForAttributes) -} - -const claimsKey = "claims" - -var reBearer = regexp.MustCompile(`(?i)^Bearer\s`) -var reClaims = regexp.MustCompile(`\"([^\"]*)\"`) - -const AuthenticateChallengedEventKey = "com.microsoft.kiota.authenticate_challenge_received" - -func (a *NetHttpRequestAdapter) retryCAEResponseIfRequired(ctx context.Context, response *nethttp.Response, requestInfo *abs.RequestInformation, claims string, spanForAttributes trace.Span) (*nethttp.Response, error) { - ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "retryCAEResponseIfRequired") - defer span.End() - if response.StatusCode == 401 && - claims == "" { //avoid infinite loop, we only retry once - authenticateHeaderVal := response.Header.Get("WWW-Authenticate") - if authenticateHeaderVal != "" && reBearer.Match([]byte(authenticateHeaderVal)) { - span.AddEvent(AuthenticateChallengedEventKey) - spanForAttributes.SetAttributes(attribute.Int("http.retry_count", 1)) - responseClaims := "" - parametersRaw := string(reBearer.ReplaceAll([]byte(authenticateHeaderVal), []byte(""))) - parameters := strings.Split(parametersRaw, ",") - for _, parameter := range parameters { - if strings.HasPrefix(strings.Trim(parameter, " "), claimsKey) { - responseClaims = reClaims.FindStringSubmatch(parameter)[1] - break - } - } - if responseClaims != "" { - defer a.purge(response) - return a.getHttpResponseMessage(ctx, requestInfo, responseClaims, spanForAttributes) - } - } - } - return response, nil -} - -func (a *NetHttpRequestAdapter) getResponsePrimaryContentType(response *nethttp.Response) string { - if response.Header == nil { - return "" - } - rawType := response.Header.Get("Content-Type") - splat := strings.Split(rawType, ";") - return strings.ToLower(splat[0]) -} - -func (a *NetHttpRequestAdapter) setBaseUrlForRequestInformation(requestInfo *abs.RequestInformation) { - requestInfo.PathParameters["baseurl"] = a.GetBaseUrl() -} - -const requestTimeOutInSeconds = 100 - -func (a *NetHttpRequestAdapter) prepareContext(ctx context.Context, requestInfo *abs.RequestInformation) context.Context { - if ctx == nil { - ctx = context.Background() - } - // set deadline if not set in receiving context - if _, deadlineSet := ctx.Deadline(); !deadlineSet { - ctx, _ = context.WithTimeout(ctx, time.Second*requestTimeOutInSeconds) - } - - for _, value := range requestInfo.GetRequestOptions() { - ctx = context.WithValue(ctx, value.GetKey(), value) - } - obsOptionsSet := false - if reqObsOpt := ctx.Value(observabilityOptionsKeyValue); reqObsOpt != nil { - if _, ok := reqObsOpt.(ObservabilityOptionsInt); ok { - obsOptionsSet = true - } - } - if !obsOptionsSet { - ctx = context.WithValue(ctx, observabilityOptionsKeyValue, &a.observabilityOptions) - } - return ctx -} - -// ConvertToNativeRequest converts the given RequestInformation into a native HTTP request. -func (a *NetHttpRequestAdapter) ConvertToNativeRequest(context context.Context, requestInfo *abs.RequestInformation) (any, error) { - err := a.authenticationProvider.AuthenticateRequest(context, requestInfo, nil) - if err != nil { - return nil, err - } - request, err := a.getRequestFromRequestInformation(context, requestInfo, nil) - if err != nil { - return nil, err - } - return request, nil -} - -func (a *NetHttpRequestAdapter) getRequestFromRequestInformation(ctx context.Context, requestInfo *abs.RequestInformation, spanForAttributes trace.Span) (*nethttp.Request, error) { - ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getRequestFromRequestInformation") - defer span.End() - if spanForAttributes == nil { - spanForAttributes = span - } - spanForAttributes.SetAttributes(attribute.String("http.method", requestInfo.Method.String())) - uri, err := requestInfo.GetUri() - if err != nil { - spanForAttributes.RecordError(err) - return nil, err - } - spanForAttributes.SetAttributes( - attribute.String("http.scheme", uri.Scheme), - attribute.String("http.host", uri.Host), - ) - - if a.observabilityOptions.IncludeEUIIAttributes { - spanForAttributes.SetAttributes(attribute.String("http.uri", uri.String())) - } - - request, err := nethttp.NewRequestWithContext(ctx, requestInfo.Method.String(), uri.String(), nil) - - if err != nil { - spanForAttributes.RecordError(err) - return nil, err - } - if len(requestInfo.Content) > 0 { - reader := bytes.NewReader(requestInfo.Content) - request.Body = NopCloser(reader) - } - if request.Header == nil { - request.Header = make(nethttp.Header) - } - if requestInfo.Headers != nil { - for _, key := range requestInfo.Headers.ListKeys() { - values := requestInfo.Headers.Get(key) - for _, v := range values { - request.Header.Add(key, v) - } - } - if request.Header.Get("Content-Type") != "" { - spanForAttributes.SetAttributes( - attribute.String("http.request_content_type", request.Header.Get("Content-Type")), - ) - } - if request.Header.Get("Content-Length") != "" { - contentLenVal, _ := strconv.Atoi(request.Header.Get("Content-Length")) - spanForAttributes.SetAttributes( - attribute.Int("http.request_content_length", contentLenVal), - ) - } - } - - return request, nil -} - -const EventResponseHandlerInvokedKey = "com.microsoft.kiota.response_handler_invoked" - -var queryParametersCleanupRegex = regexp.MustCompile(`\{\?[^\}]+}`) - -func (a *NetHttpRequestAdapter) startTracingSpan(ctx context.Context, requestInfo *abs.RequestInformation, methodName string) (context.Context, trace.Span) { - decodedUriTemplate := decodeUriEncodedString(requestInfo.UrlTemplate, []byte{'-', '.', '~', '$'}) - telemetryPathValue := queryParametersCleanupRegex.ReplaceAll([]byte(decodedUriTemplate), []byte("")) - ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, methodName+" - "+string(telemetryPathValue)) - span.SetAttributes(attribute.String("http.uri_template", decodedUriTemplate)) - return ctx, span -} - -// Send executes the HTTP request specified by the given RequestInformation and returns the deserialized response model. -func (a *NetHttpRequestAdapter) Send(ctx context.Context, requestInfo *abs.RequestInformation, constructor absser.ParsableFactory, errorMappings abs.ErrorMappings) (absser.Parsable, error) { - if requestInfo == nil { - return nil, errors.New("requestInfo cannot be nil") - } - ctx = a.prepareContext(ctx, requestInfo) - ctx, span := a.startTracingSpan(ctx, requestInfo, "Send") - defer span.End() - response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) - if err != nil { - return nil, err - } - - responseHandler := getResponseHandler(ctx) - if responseHandler != nil { - span.AddEvent(EventResponseHandlerInvokedKey) - result, err := responseHandler(response, errorMappings) - if err != nil { - span.RecordError(err) - return nil, err - } - return result.(absser.Parsable), nil - } else if response != nil { - defer a.purge(response) - err = a.throwIfFailedResponse(ctx, response, errorMappings, span) - if err != nil { - return nil, err - } - if a.shouldReturnNil(response) { - return nil, nil - } - parseNode, _, err := a.getRootParseNode(ctx, response, span) - if err != nil { - return nil, err - } - if parseNode == nil { - return nil, nil - } - _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetObjectValue") - defer deserializeSpan.End() - result, err := parseNode.GetObjectValue(constructor) - a.setResponseType(result, span) - if err != nil { - span.RecordError(err) - } - return result, err - } else { - return nil, errors.New("response is nil") - } -} - -func (a *NetHttpRequestAdapter) setResponseType(result any, span trace.Span) { - if result != nil { - span.SetAttributes(attribute.String("com.microsoft.kiota.response.type", reflect.TypeOf(result).String())) - } -} - -// SendEnum executes the HTTP request specified by the given RequestInformation and returns the deserialized response model. -func (a *NetHttpRequestAdapter) SendEnum(ctx context.Context, requestInfo *abs.RequestInformation, parser absser.EnumFactory, errorMappings abs.ErrorMappings) (any, error) { - if requestInfo == nil { - return nil, errors.New("requestInfo cannot be nil") - } - ctx = a.prepareContext(ctx, requestInfo) - ctx, span := a.startTracingSpan(ctx, requestInfo, "SendEnum") - defer span.End() - response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) - if err != nil { - return nil, err - } - - responseHandler := getResponseHandler(ctx) - if responseHandler != nil { - span.AddEvent(EventResponseHandlerInvokedKey) - result, err := responseHandler(response, errorMappings) - if err != nil { - span.RecordError(err) - return nil, err - } - return result.(absser.Parsable), nil - } else if response != nil { - defer a.purge(response) - err = a.throwIfFailedResponse(ctx, response, errorMappings, span) - if err != nil { - return nil, err - } - if a.shouldReturnNil(response) { - return nil, nil - } - parseNode, _, err := a.getRootParseNode(ctx, response, span) - if err != nil { - return nil, err - } - if parseNode == nil { - return nil, nil - } - _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetEnumValue") - defer deserializeSpan.End() - result, err := parseNode.GetEnumValue(parser) - a.setResponseType(result, span) - if err != nil { - span.RecordError(err) - } - return result, err - } else { - return nil, errors.New("response is nil") - } -} - -// SendCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized response model collection. -func (a *NetHttpRequestAdapter) SendCollection(ctx context.Context, requestInfo *abs.RequestInformation, constructor absser.ParsableFactory, errorMappings abs.ErrorMappings) ([]absser.Parsable, error) { - if requestInfo == nil { - return nil, errors.New("requestInfo cannot be nil") - } - ctx = a.prepareContext(ctx, requestInfo) - ctx, span := a.startTracingSpan(ctx, requestInfo, "SendCollection") - defer span.End() - response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) - if err != nil { - return nil, err - } - - responseHandler := getResponseHandler(ctx) - if responseHandler != nil { - span.AddEvent(EventResponseHandlerInvokedKey) - result, err := responseHandler(response, errorMappings) - if err != nil { - span.RecordError(err) - return nil, err - } - return result.([]absser.Parsable), nil - } else if response != nil { - defer a.purge(response) - err = a.throwIfFailedResponse(ctx, response, errorMappings, span) - if err != nil { - return nil, err - } - if a.shouldReturnNil(response) { - return nil, nil - } - parseNode, _, err := a.getRootParseNode(ctx, response, span) - if err != nil { - return nil, err - } - if parseNode == nil { - return nil, nil - } - _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfObjectValues") - defer deserializeSpan.End() - result, err := parseNode.GetCollectionOfObjectValues(constructor) - a.setResponseType(result, span) - if err != nil { - span.RecordError(err) - } - return result, err - } else { - return nil, errors.New("response is nil") - } -} - -// SendEnumCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized response model collection. -func (a *NetHttpRequestAdapter) SendEnumCollection(ctx context.Context, requestInfo *abs.RequestInformation, parser absser.EnumFactory, errorMappings abs.ErrorMappings) ([]any, error) { - if requestInfo == nil { - return nil, errors.New("requestInfo cannot be nil") - } - ctx = a.prepareContext(ctx, requestInfo) - ctx, span := a.startTracingSpan(ctx, requestInfo, "SendEnumCollection") - defer span.End() - response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) - if err != nil { - return nil, err - } - - responseHandler := getResponseHandler(ctx) - if responseHandler != nil { - span.AddEvent(EventResponseHandlerInvokedKey) - result, err := responseHandler(response, errorMappings) - if err != nil { - span.RecordError(err) - return nil, err - } - return result.([]any), nil - } else if response != nil { - defer a.purge(response) - err = a.throwIfFailedResponse(ctx, response, errorMappings, span) - if err != nil { - return nil, err - } - if a.shouldReturnNil(response) { - return nil, nil - } - parseNode, _, err := a.getRootParseNode(ctx, response, span) - if err != nil { - return nil, err - } - if parseNode == nil { - return nil, nil - } - _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfEnumValues") - defer deserializeSpan.End() - result, err := parseNode.GetCollectionOfEnumValues(parser) - a.setResponseType(result, span) - if err != nil { - span.RecordError(err) - } - return result, err - } else { - return nil, errors.New("response is nil") - } -} - -func getResponseHandler(ctx context.Context) abs.ResponseHandler { - var handlerOption = ctx.Value(abs.ResponseHandlerOptionKey) - if handlerOption != nil { - return handlerOption.(abs.RequestHandlerOption).GetResponseHandler() - } - return nil -} - -// SendPrimitive executes the HTTP request specified by the given RequestInformation and returns the deserialized primitive response model. -func (a *NetHttpRequestAdapter) SendPrimitive(ctx context.Context, requestInfo *abs.RequestInformation, typeName string, errorMappings abs.ErrorMappings) (any, error) { - if requestInfo == nil { - return nil, errors.New("requestInfo cannot be nil") - } - ctx = a.prepareContext(ctx, requestInfo) - ctx, span := a.startTracingSpan(ctx, requestInfo, "SendPrimitive") - defer span.End() - response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) - if err != nil { - return nil, err - } - - responseHandler := getResponseHandler(ctx) - if responseHandler != nil { - span.AddEvent(EventResponseHandlerInvokedKey) - result, err := responseHandler(response, errorMappings) - if err != nil { - span.RecordError(err) - return nil, err - } - return result.(absser.Parsable), nil - } else if response != nil { - defer a.purge(response) - err = a.throwIfFailedResponse(ctx, response, errorMappings, span) - if err != nil { - return nil, err - } - if a.shouldReturnNil(response) { - return nil, nil - } - if typeName == "[]byte" { - res, err := ioutil.ReadAll(response.Body) - if err != nil { - span.RecordError(err) - return nil, err - } else if len(res) == 0 { - return nil, nil - } - return res, nil - } - parseNode, _, err := a.getRootParseNode(ctx, response, span) - if err != nil { - return nil, err - } - if parseNode == nil { - return nil, nil - } - _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "Get"+typeName+"Value") - defer deserializeSpan.End() - var result any - switch typeName { - case "string": - result, err = parseNode.GetStringValue() - case "float32": - result, err = parseNode.GetFloat32Value() - case "float64": - result, err = parseNode.GetFloat64Value() - case "int32": - result, err = parseNode.GetInt32Value() - case "int64": - result, err = parseNode.GetInt64Value() - case "bool": - result, err = parseNode.GetBoolValue() - case "Time": - result, err = parseNode.GetTimeValue() - case "UUID": - result, err = parseNode.GetUUIDValue() - default: - return nil, errors.New("unsupported type") - } - a.setResponseType(result, span) - if err != nil { - span.RecordError(err) - } - return result, err - } else { - return nil, errors.New("response is nil") - } -} - -// SendPrimitiveCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized primitive response model collection. -func (a *NetHttpRequestAdapter) SendPrimitiveCollection(ctx context.Context, requestInfo *abs.RequestInformation, typeName string, errorMappings abs.ErrorMappings) ([]any, error) { - if requestInfo == nil { - return nil, errors.New("requestInfo cannot be nil") - } - ctx = a.prepareContext(ctx, requestInfo) - ctx, span := a.startTracingSpan(ctx, requestInfo, "SendPrimitiveCollection") - defer span.End() - response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) - if err != nil { - return nil, err - } - - responseHandler := getResponseHandler(ctx) - if responseHandler != nil { - span.AddEvent(EventResponseHandlerInvokedKey) - result, err := responseHandler(response, errorMappings) - if err != nil { - span.RecordError(err) - return nil, err - } - return result.([]any), nil - } else if response != nil { - defer a.purge(response) - err = a.throwIfFailedResponse(ctx, response, errorMappings, span) - if err != nil { - return nil, err - } - if a.shouldReturnNil(response) { - return nil, nil - } - parseNode, _, err := a.getRootParseNode(ctx, response, span) - if err != nil { - return nil, err - } - if parseNode == nil { - return nil, nil - } - _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfPrimitiveValues") - defer deserializeSpan.End() - result, err := parseNode.GetCollectionOfPrimitiveValues(typeName) - a.setResponseType(result, span) - if err != nil { - span.RecordError(err) - } - return result, err - } else { - return nil, errors.New("response is nil") - } -} - -// SendNoContent executes the HTTP request specified by the given RequestInformation with no return content. -func (a *NetHttpRequestAdapter) SendNoContent(ctx context.Context, requestInfo *abs.RequestInformation, errorMappings abs.ErrorMappings) error { - if requestInfo == nil { - return errors.New("requestInfo cannot be nil") - } - ctx = a.prepareContext(ctx, requestInfo) - ctx, span := a.startTracingSpan(ctx, requestInfo, "SendNoContent") - defer span.End() - response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) - if err != nil { - return err - } - - responseHandler := getResponseHandler(ctx) - if responseHandler != nil { - span.AddEvent(EventResponseHandlerInvokedKey) - _, err := responseHandler(response, errorMappings) - if err != nil { - span.RecordError(err) - } - return err - } else if response != nil { - defer a.purge(response) - err = a.throwIfFailedResponse(ctx, response, errorMappings, span) - if err != nil { - return err - } - return nil - } else { - return errors.New("response is nil") - } -} - -func (a *NetHttpRequestAdapter) getRootParseNode(ctx context.Context, response *nethttp.Response, spanForAttributes trace.Span) (absser.ParseNode, context.Context, error) { - ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getRootParseNode") - defer span.End() - body, err := ioutil.ReadAll(response.Body) - if err != nil { - spanForAttributes.RecordError(err) - return nil, ctx, err - } - contentType := a.getResponsePrimaryContentType(response) - if contentType == "" { - return nil, ctx, nil - } - rootNode, err := a.parseNodeFactory.GetRootParseNode(contentType, body) - if err != nil { - spanForAttributes.RecordError(err) - } - return rootNode, ctx, err -} -func (a *NetHttpRequestAdapter) purge(response *nethttp.Response) error { - _, _ = ioutil.ReadAll(response.Body) //we don't care about errors comming from reading the body, just trying to purge anything that maybe left - err := response.Body.Close() - if err != nil { - return err - } - return nil -} -func (a *NetHttpRequestAdapter) shouldReturnNil(response *nethttp.Response) bool { - return response.StatusCode == 204 -} - -// ErrorMappingFoundAttributeName is the attribute name used to indicate whether an error code mapping was found. -const ErrorMappingFoundAttributeName = "com.microsoft.kiota.error.mapping_found" - -// ErrorBodyFoundAttributeName is the attribute name used to indicate whether the error response contained a body -const ErrorBodyFoundAttributeName = "com.microsoft.kiota.error.body_found" - -func (a *NetHttpRequestAdapter) throwIfFailedResponse(ctx context.Context, response *nethttp.Response, errorMappings abs.ErrorMappings, spanForAttributes trace.Span) error { - ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "throwIfFailedResponse") - defer span.End() - if response.StatusCode < 400 { - return nil - } - spanForAttributes.SetStatus(codes.Error, "received_error_response") - - statusAsString := strconv.Itoa(response.StatusCode) - var errorCtor absser.ParsableFactory = nil - if len(errorMappings) != 0 { - if errorMappings[statusAsString] != nil { - errorCtor = errorMappings[statusAsString] - } else if response.StatusCode >= 400 && response.StatusCode < 500 && errorMappings["4XX"] != nil { - errorCtor = errorMappings["4XX"] - } else if response.StatusCode >= 500 && response.StatusCode < 600 && errorMappings["5XX"] != nil { - errorCtor = errorMappings["5XX"] - } - } - - if errorCtor == nil { - spanForAttributes.SetAttributes(attribute.Bool(ErrorMappingFoundAttributeName, false)) - err := &abs.ApiError{ - Message: "The server returned an unexpected status code and no error factory is registered for this code: " + statusAsString, - } - spanForAttributes.RecordError(err) - return err - } - spanForAttributes.SetAttributes(attribute.Bool(ErrorMappingFoundAttributeName, true)) - - rootNode, _, err := a.getRootParseNode(ctx, response, spanForAttributes) - if err != nil { - spanForAttributes.RecordError(err) - return err - } - if rootNode == nil { - spanForAttributes.SetAttributes(attribute.Bool(ErrorBodyFoundAttributeName, false)) - err := &abs.ApiError{ - Message: "The server returned an unexpected status code with no response body: " + statusAsString, - } - spanForAttributes.RecordError(err) - return err - } - spanForAttributes.SetAttributes(attribute.Bool(ErrorBodyFoundAttributeName, true)) - - _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetObjectValue") - defer deserializeSpan.End() - errValue, err := rootNode.GetObjectValue(errorCtor) - if err != nil { - spanForAttributes.RecordError(err) - return err - } else if errValue == nil { - return &abs.ApiError{ - Message: "The server returned an unexpected status code but the error could not be deserialized: " + statusAsString, - } - } - - err = errValue.(error) - spanForAttributes.RecordError(err) - return err -} +package nethttplibrary + +import ( + "bytes" + "context" + "errors" + "github.com/microsoft/kiota-abstractions-go/store" + "io" + "io/ioutil" + nethttp "net/http" + "reflect" + "regexp" + "strconv" + "strings" + "time" + + abs "github.com/microsoft/kiota-abstractions-go" + absauth "github.com/microsoft/kiota-abstractions-go/authentication" + absser "github.com/microsoft/kiota-abstractions-go/serialization" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +// nopCloser is an alternate io.nopCloser implementation which +// provides io.ReadSeekCloser instead of io.ReadCloser as we need +// Seek for retries +type nopCloser struct { + io.ReadSeeker +} + +func NopCloser(r io.ReadSeeker) io.ReadSeekCloser { + return nopCloser{r} +} + +func (nopCloser) Close() error { return nil } + +// NetHttpRequestAdapter implements the RequestAdapter interface using net/http +type NetHttpRequestAdapter struct { + // serializationWriterFactory is the factory used to create serialization writers + serializationWriterFactory absser.SerializationWriterFactory + // parseNodeFactory is the factory used to create parse nodes + parseNodeFactory absser.ParseNodeFactory + // httpClient is the client used to send requests + httpClient *nethttp.Client + // authenticationProvider is the provider used to authenticate requests + authenticationProvider absauth.AuthenticationProvider + // The base url for every request. + baseUrl string + // The observation options for the request adapter. + observabilityOptions ObservabilityOptions +} + +// NewNetHttpRequestAdapter creates a new NetHttpRequestAdapter with the given parameters +func NewNetHttpRequestAdapter(authenticationProvider absauth.AuthenticationProvider) (*NetHttpRequestAdapter, error) { + return NewNetHttpRequestAdapterWithParseNodeFactory(authenticationProvider, nil) +} + +// NewNetHttpRequestAdapterWithParseNodeFactory creates a new NetHttpRequestAdapter with the given parameters +func NewNetHttpRequestAdapterWithParseNodeFactory(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory) (*NetHttpRequestAdapter, error) { + return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory(authenticationProvider, parseNodeFactory, nil) +} + +// NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory creates a new NetHttpRequestAdapter with the given parameters +func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory absser.SerializationWriterFactory) (*NetHttpRequestAdapter, error) { + return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(authenticationProvider, parseNodeFactory, serializationWriterFactory, nil) +} + +// NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient creates a new NetHttpRequestAdapter with the given parameters +func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory absser.SerializationWriterFactory, httpClient *nethttp.Client) (*NetHttpRequestAdapter, error) { + return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions(authenticationProvider, parseNodeFactory, serializationWriterFactory, httpClient, ObservabilityOptions{}) +} + +// NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions creates a new NetHttpRequestAdapter with the given parameters +func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory absser.SerializationWriterFactory, httpClient *nethttp.Client, observabilityOptions ObservabilityOptions) (*NetHttpRequestAdapter, error) { + if authenticationProvider == nil { + return nil, errors.New("authenticationProvider cannot be nil") + } + result := &NetHttpRequestAdapter{ + serializationWriterFactory: serializationWriterFactory, + parseNodeFactory: parseNodeFactory, + httpClient: httpClient, + authenticationProvider: authenticationProvider, + baseUrl: "", + observabilityOptions: observabilityOptions, + } + if result.httpClient == nil { + defaultClient := GetDefaultClient() + result.httpClient = defaultClient + } + if result.serializationWriterFactory == nil { + result.serializationWriterFactory = absser.DefaultSerializationWriterFactoryInstance + } + if result.parseNodeFactory == nil { + result.parseNodeFactory = absser.DefaultParseNodeFactoryInstance + } + return result, nil +} + +// GetSerializationWriterFactory returns the serialization writer factory currently in use for the request adapter service. +func (a *NetHttpRequestAdapter) GetSerializationWriterFactory() absser.SerializationWriterFactory { + return a.serializationWriterFactory +} + +// EnableBackingStore enables the backing store proxies for the SerializationWriters and ParseNodes in use. +func (a *NetHttpRequestAdapter) EnableBackingStore(factory store.BackingStoreFactory) { + a.parseNodeFactory = abs.EnableBackingStoreForParseNodeFactory(a.parseNodeFactory) + a.serializationWriterFactory = abs.EnableBackingStoreForSerializationWriterFactory(a.serializationWriterFactory) + if factory != nil { + store.BackingStoreFactoryInstance = factory + } +} + +// SetBaseUrl sets the base url for every request. +func (a *NetHttpRequestAdapter) SetBaseUrl(baseUrl string) { + a.baseUrl = baseUrl +} + +// GetBaseUrl gets the base url for every request. +func (a *NetHttpRequestAdapter) GetBaseUrl() string { + return a.baseUrl +} + +func (a *NetHttpRequestAdapter) getHttpResponseMessage(ctx context.Context, requestInfo *abs.RequestInformation, claims string, spanForAttributes trace.Span) (*nethttp.Response, error) { + ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getHttpResponseMessage") + defer span.End() + if ctx == nil { + ctx = context.Background() + } + a.setBaseUrlForRequestInformation(requestInfo) + additionalContext := make(map[string]any) + if claims != "" { + additionalContext[claimsKey] = claims + } + err := a.authenticationProvider.AuthenticateRequest(ctx, requestInfo, additionalContext) + if err != nil { + return nil, err + } + request, err := a.getRequestFromRequestInformation(ctx, requestInfo, spanForAttributes) + if err != nil { + return nil, err + } + response, err := (*a.httpClient).Do(request) + if err != nil { + spanForAttributes.RecordError(err) + return nil, err + } + if response != nil { + contentLenHeader := response.Header.Get("Content-Length") + if contentLenHeader != "" { + contentLen, _ := strconv.Atoi(contentLenHeader) + spanForAttributes.SetAttributes(attribute.Int("http.response_content_length", contentLen)) + } + contentTypeHeader := response.Header.Get("Content-Type") + if contentTypeHeader != "" { + spanForAttributes.SetAttributes(attribute.String("http.response_content_type", contentTypeHeader)) + } + spanForAttributes.SetAttributes( + attribute.Int("http.status_code", response.StatusCode), + attribute.String("http.flavor", response.Proto), + ) + } + return a.retryCAEResponseIfRequired(ctx, response, requestInfo, claims, spanForAttributes) +} + +const claimsKey = "claims" + +var reBearer = regexp.MustCompile(`(?i)^Bearer\s`) +var reClaims = regexp.MustCompile(`\"([^\"]*)\"`) + +const AuthenticateChallengedEventKey = "com.microsoft.kiota.authenticate_challenge_received" + +func (a *NetHttpRequestAdapter) retryCAEResponseIfRequired(ctx context.Context, response *nethttp.Response, requestInfo *abs.RequestInformation, claims string, spanForAttributes trace.Span) (*nethttp.Response, error) { + ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "retryCAEResponseIfRequired") + defer span.End() + if response.StatusCode == 401 && + claims == "" { //avoid infinite loop, we only retry once + authenticateHeaderVal := response.Header.Get("WWW-Authenticate") + if authenticateHeaderVal != "" && reBearer.Match([]byte(authenticateHeaderVal)) { + span.AddEvent(AuthenticateChallengedEventKey) + spanForAttributes.SetAttributes(attribute.Int("http.retry_count", 1)) + responseClaims := "" + parametersRaw := string(reBearer.ReplaceAll([]byte(authenticateHeaderVal), []byte(""))) + parameters := strings.Split(parametersRaw, ",") + for _, parameter := range parameters { + if strings.HasPrefix(strings.Trim(parameter, " "), claimsKey) { + responseClaims = reClaims.FindStringSubmatch(parameter)[1] + break + } + } + if responseClaims != "" { + defer a.purge(response) + return a.getHttpResponseMessage(ctx, requestInfo, responseClaims, spanForAttributes) + } + } + } + return response, nil +} + +func (a *NetHttpRequestAdapter) getResponsePrimaryContentType(response *nethttp.Response) string { + if response.Header == nil { + return "" + } + rawType := response.Header.Get("Content-Type") + splat := strings.Split(rawType, ";") + return strings.ToLower(splat[0]) +} + +func (a *NetHttpRequestAdapter) setBaseUrlForRequestInformation(requestInfo *abs.RequestInformation) { + requestInfo.PathParameters["baseurl"] = a.GetBaseUrl() +} + +const requestTimeOutInSeconds = 100 + +func (a *NetHttpRequestAdapter) prepareContext(ctx context.Context, requestInfo *abs.RequestInformation) context.Context { + if ctx == nil { + ctx = context.Background() + } + // set deadline if not set in receiving context + if _, deadlineSet := ctx.Deadline(); !deadlineSet { + ctx, _ = context.WithTimeout(ctx, time.Second*requestTimeOutInSeconds) + } + + for _, value := range requestInfo.GetRequestOptions() { + ctx = context.WithValue(ctx, value.GetKey(), value) + } + obsOptionsSet := false + if reqObsOpt := ctx.Value(observabilityOptionsKeyValue); reqObsOpt != nil { + if _, ok := reqObsOpt.(ObservabilityOptionsInt); ok { + obsOptionsSet = true + } + } + if !obsOptionsSet { + ctx = context.WithValue(ctx, observabilityOptionsKeyValue, &a.observabilityOptions) + } + return ctx +} + +// ConvertToNativeRequest converts the given RequestInformation into a native HTTP request. +func (a *NetHttpRequestAdapter) ConvertToNativeRequest(context context.Context, requestInfo *abs.RequestInformation) (any, error) { + err := a.authenticationProvider.AuthenticateRequest(context, requestInfo, nil) + if err != nil { + return nil, err + } + request, err := a.getRequestFromRequestInformation(context, requestInfo, nil) + if err != nil { + return nil, err + } + return request, nil +} + +func (a *NetHttpRequestAdapter) getRequestFromRequestInformation(ctx context.Context, requestInfo *abs.RequestInformation, spanForAttributes trace.Span) (*nethttp.Request, error) { + ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getRequestFromRequestInformation") + defer span.End() + if spanForAttributes == nil { + spanForAttributes = span + } + spanForAttributes.SetAttributes(attribute.String("http.method", requestInfo.Method.String())) + uri, err := requestInfo.GetUri() + if err != nil { + spanForAttributes.RecordError(err) + return nil, err + } + spanForAttributes.SetAttributes( + attribute.String("http.scheme", uri.Scheme), + attribute.String("http.host", uri.Host), + ) + + if a.observabilityOptions.IncludeEUIIAttributes { + spanForAttributes.SetAttributes(attribute.String("http.uri", uri.String())) + } + + request, err := nethttp.NewRequestWithContext(ctx, requestInfo.Method.String(), uri.String(), nil) + + if err != nil { + spanForAttributes.RecordError(err) + return nil, err + } + if len(requestInfo.Content) > 0 { + reader := bytes.NewReader(requestInfo.Content) + request.Body = NopCloser(reader) + } + if request.Header == nil { + request.Header = make(nethttp.Header) + } + if requestInfo.Headers != nil { + for _, key := range requestInfo.Headers.ListKeys() { + values := requestInfo.Headers.Get(key) + for _, v := range values { + request.Header.Add(key, v) + } + } + if request.Header.Get("Content-Type") != "" { + spanForAttributes.SetAttributes( + attribute.String("http.request_content_type", request.Header.Get("Content-Type")), + ) + } + if request.Header.Get("Content-Length") != "" { + contentLenVal, _ := strconv.Atoi(request.Header.Get("Content-Length")) + spanForAttributes.SetAttributes( + attribute.Int("http.request_content_length", contentLenVal), + ) + } + } + + return request, nil +} + +const EventResponseHandlerInvokedKey = "com.microsoft.kiota.response_handler_invoked" + +var queryParametersCleanupRegex = regexp.MustCompile(`\{\?[^\}]+}`) + +func (a *NetHttpRequestAdapter) startTracingSpan(ctx context.Context, requestInfo *abs.RequestInformation, methodName string) (context.Context, trace.Span) { + decodedUriTemplate := decodeUriEncodedString(requestInfo.UrlTemplate, []byte{'-', '.', '~', '$'}) + telemetryPathValue := queryParametersCleanupRegex.ReplaceAll([]byte(decodedUriTemplate), []byte("")) + ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, methodName+" - "+string(telemetryPathValue)) + span.SetAttributes(attribute.String("http.uri_template", decodedUriTemplate)) + return ctx, span +} + +// Send executes the HTTP request specified by the given RequestInformation and returns the deserialized response model. +func (a *NetHttpRequestAdapter) Send(ctx context.Context, requestInfo *abs.RequestInformation, constructor absser.ParsableFactory, errorMappings abs.ErrorMappings) (absser.Parsable, error) { + if requestInfo == nil { + return nil, errors.New("requestInfo cannot be nil") + } + ctx = a.prepareContext(ctx, requestInfo) + ctx, span := a.startTracingSpan(ctx, requestInfo, "Send") + defer span.End() + response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) + if err != nil { + return nil, err + } + + responseHandler := getResponseHandler(ctx) + if responseHandler != nil { + span.AddEvent(EventResponseHandlerInvokedKey) + result, err := responseHandler(response, errorMappings) + if err != nil { + span.RecordError(err) + return nil, err + } + return result.(absser.Parsable), nil + } else if response != nil { + defer a.purge(response) + err = a.throwIfFailedResponse(ctx, response, errorMappings, span) + if err != nil { + return nil, err + } + if a.shouldReturnNil(response) { + return nil, nil + } + parseNode, _, err := a.getRootParseNode(ctx, response, span) + if err != nil { + return nil, err + } + if parseNode == nil { + return nil, nil + } + _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetObjectValue") + defer deserializeSpan.End() + result, err := parseNode.GetObjectValue(constructor) + a.setResponseType(result, span) + if err != nil { + span.RecordError(err) + } + return result, err + } else { + return nil, errors.New("response is nil") + } +} + +func (a *NetHttpRequestAdapter) setResponseType(result any, span trace.Span) { + if result != nil { + span.SetAttributes(attribute.String("com.microsoft.kiota.response.type", reflect.TypeOf(result).String())) + } +} + +// SendEnum executes the HTTP request specified by the given RequestInformation and returns the deserialized response model. +func (a *NetHttpRequestAdapter) SendEnum(ctx context.Context, requestInfo *abs.RequestInformation, parser absser.EnumFactory, errorMappings abs.ErrorMappings) (any, error) { + if requestInfo == nil { + return nil, errors.New("requestInfo cannot be nil") + } + ctx = a.prepareContext(ctx, requestInfo) + ctx, span := a.startTracingSpan(ctx, requestInfo, "SendEnum") + defer span.End() + response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) + if err != nil { + return nil, err + } + + responseHandler := getResponseHandler(ctx) + if responseHandler != nil { + span.AddEvent(EventResponseHandlerInvokedKey) + result, err := responseHandler(response, errorMappings) + if err != nil { + span.RecordError(err) + return nil, err + } + return result.(absser.Parsable), nil + } else if response != nil { + defer a.purge(response) + err = a.throwIfFailedResponse(ctx, response, errorMappings, span) + if err != nil { + return nil, err + } + if a.shouldReturnNil(response) { + return nil, nil + } + parseNode, _, err := a.getRootParseNode(ctx, response, span) + if err != nil { + return nil, err + } + if parseNode == nil { + return nil, nil + } + _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetEnumValue") + defer deserializeSpan.End() + result, err := parseNode.GetEnumValue(parser) + a.setResponseType(result, span) + if err != nil { + span.RecordError(err) + } + return result, err + } else { + return nil, errors.New("response is nil") + } +} + +// SendCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized response model collection. +func (a *NetHttpRequestAdapter) SendCollection(ctx context.Context, requestInfo *abs.RequestInformation, constructor absser.ParsableFactory, errorMappings abs.ErrorMappings) ([]absser.Parsable, error) { + if requestInfo == nil { + return nil, errors.New("requestInfo cannot be nil") + } + ctx = a.prepareContext(ctx, requestInfo) + ctx, span := a.startTracingSpan(ctx, requestInfo, "SendCollection") + defer span.End() + response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) + if err != nil { + return nil, err + } + + responseHandler := getResponseHandler(ctx) + if responseHandler != nil { + span.AddEvent(EventResponseHandlerInvokedKey) + result, err := responseHandler(response, errorMappings) + if err != nil { + span.RecordError(err) + return nil, err + } + return result.([]absser.Parsable), nil + } else if response != nil { + defer a.purge(response) + err = a.throwIfFailedResponse(ctx, response, errorMappings, span) + if err != nil { + return nil, err + } + if a.shouldReturnNil(response) { + return nil, nil + } + parseNode, _, err := a.getRootParseNode(ctx, response, span) + if err != nil { + return nil, err + } + if parseNode == nil { + return nil, nil + } + _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfObjectValues") + defer deserializeSpan.End() + result, err := parseNode.GetCollectionOfObjectValues(constructor) + a.setResponseType(result, span) + if err != nil { + span.RecordError(err) + } + return result, err + } else { + return nil, errors.New("response is nil") + } +} + +// SendEnumCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized response model collection. +func (a *NetHttpRequestAdapter) SendEnumCollection(ctx context.Context, requestInfo *abs.RequestInformation, parser absser.EnumFactory, errorMappings abs.ErrorMappings) ([]any, error) { + if requestInfo == nil { + return nil, errors.New("requestInfo cannot be nil") + } + ctx = a.prepareContext(ctx, requestInfo) + ctx, span := a.startTracingSpan(ctx, requestInfo, "SendEnumCollection") + defer span.End() + response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) + if err != nil { + return nil, err + } + + responseHandler := getResponseHandler(ctx) + if responseHandler != nil { + span.AddEvent(EventResponseHandlerInvokedKey) + result, err := responseHandler(response, errorMappings) + if err != nil { + span.RecordError(err) + return nil, err + } + return result.([]any), nil + } else if response != nil { + defer a.purge(response) + err = a.throwIfFailedResponse(ctx, response, errorMappings, span) + if err != nil { + return nil, err + } + if a.shouldReturnNil(response) { + return nil, nil + } + parseNode, _, err := a.getRootParseNode(ctx, response, span) + if err != nil { + return nil, err + } + if parseNode == nil { + return nil, nil + } + _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfEnumValues") + defer deserializeSpan.End() + result, err := parseNode.GetCollectionOfEnumValues(parser) + a.setResponseType(result, span) + if err != nil { + span.RecordError(err) + } + return result, err + } else { + return nil, errors.New("response is nil") + } +} + +func getResponseHandler(ctx context.Context) abs.ResponseHandler { + var handlerOption = ctx.Value(abs.ResponseHandlerOptionKey) + if handlerOption != nil { + return handlerOption.(abs.RequestHandlerOption).GetResponseHandler() + } + return nil +} + +// SendPrimitive executes the HTTP request specified by the given RequestInformation and returns the deserialized primitive response model. +func (a *NetHttpRequestAdapter) SendPrimitive(ctx context.Context, requestInfo *abs.RequestInformation, typeName string, errorMappings abs.ErrorMappings) (any, error) { + if requestInfo == nil { + return nil, errors.New("requestInfo cannot be nil") + } + ctx = a.prepareContext(ctx, requestInfo) + ctx, span := a.startTracingSpan(ctx, requestInfo, "SendPrimitive") + defer span.End() + response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) + if err != nil { + return nil, err + } + + responseHandler := getResponseHandler(ctx) + if responseHandler != nil { + span.AddEvent(EventResponseHandlerInvokedKey) + result, err := responseHandler(response, errorMappings) + if err != nil { + span.RecordError(err) + return nil, err + } + return result.(absser.Parsable), nil + } else if response != nil { + defer a.purge(response) + err = a.throwIfFailedResponse(ctx, response, errorMappings, span) + if err != nil { + return nil, err + } + if a.shouldReturnNil(response) { + return nil, nil + } + if typeName == "[]byte" { + res, err := ioutil.ReadAll(response.Body) + if err != nil { + span.RecordError(err) + return nil, err + } else if len(res) == 0 { + return nil, nil + } + return res, nil + } + parseNode, _, err := a.getRootParseNode(ctx, response, span) + if err != nil { + return nil, err + } + if parseNode == nil { + return nil, nil + } + _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "Get"+typeName+"Value") + defer deserializeSpan.End() + var result any + switch typeName { + case "string": + result, err = parseNode.GetStringValue() + case "float32": + result, err = parseNode.GetFloat32Value() + case "float64": + result, err = parseNode.GetFloat64Value() + case "int32": + result, err = parseNode.GetInt32Value() + case "int64": + result, err = parseNode.GetInt64Value() + case "bool": + result, err = parseNode.GetBoolValue() + case "Time": + result, err = parseNode.GetTimeValue() + case "UUID": + result, err = parseNode.GetUUIDValue() + default: + return nil, errors.New("unsupported type") + } + a.setResponseType(result, span) + if err != nil { + span.RecordError(err) + } + return result, err + } else { + return nil, errors.New("response is nil") + } +} + +// SendPrimitiveCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized primitive response model collection. +func (a *NetHttpRequestAdapter) SendPrimitiveCollection(ctx context.Context, requestInfo *abs.RequestInformation, typeName string, errorMappings abs.ErrorMappings) ([]any, error) { + if requestInfo == nil { + return nil, errors.New("requestInfo cannot be nil") + } + ctx = a.prepareContext(ctx, requestInfo) + ctx, span := a.startTracingSpan(ctx, requestInfo, "SendPrimitiveCollection") + defer span.End() + response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) + if err != nil { + return nil, err + } + + responseHandler := getResponseHandler(ctx) + if responseHandler != nil { + span.AddEvent(EventResponseHandlerInvokedKey) + result, err := responseHandler(response, errorMappings) + if err != nil { + span.RecordError(err) + return nil, err + } + return result.([]any), nil + } else if response != nil { + defer a.purge(response) + err = a.throwIfFailedResponse(ctx, response, errorMappings, span) + if err != nil { + return nil, err + } + if a.shouldReturnNil(response) { + return nil, nil + } + parseNode, _, err := a.getRootParseNode(ctx, response, span) + if err != nil { + return nil, err + } + if parseNode == nil { + return nil, nil + } + _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfPrimitiveValues") + defer deserializeSpan.End() + result, err := parseNode.GetCollectionOfPrimitiveValues(typeName) + a.setResponseType(result, span) + if err != nil { + span.RecordError(err) + } + return result, err + } else { + return nil, errors.New("response is nil") + } +} + +// SendNoContent executes the HTTP request specified by the given RequestInformation with no return content. +func (a *NetHttpRequestAdapter) SendNoContent(ctx context.Context, requestInfo *abs.RequestInformation, errorMappings abs.ErrorMappings) error { + if requestInfo == nil { + return errors.New("requestInfo cannot be nil") + } + ctx = a.prepareContext(ctx, requestInfo) + ctx, span := a.startTracingSpan(ctx, requestInfo, "SendNoContent") + defer span.End() + response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) + if err != nil { + return err + } + + responseHandler := getResponseHandler(ctx) + if responseHandler != nil { + span.AddEvent(EventResponseHandlerInvokedKey) + _, err := responseHandler(response, errorMappings) + if err != nil { + span.RecordError(err) + } + return err + } else if response != nil { + defer a.purge(response) + err = a.throwIfFailedResponse(ctx, response, errorMappings, span) + if err != nil { + return err + } + return nil + } else { + return errors.New("response is nil") + } +} + +func (a *NetHttpRequestAdapter) getRootParseNode(ctx context.Context, response *nethttp.Response, spanForAttributes trace.Span) (absser.ParseNode, context.Context, error) { + ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getRootParseNode") + defer span.End() + body, err := ioutil.ReadAll(response.Body) + if err != nil { + spanForAttributes.RecordError(err) + return nil, ctx, err + } + contentType := a.getResponsePrimaryContentType(response) + if contentType == "" { + return nil, ctx, nil + } + rootNode, err := a.parseNodeFactory.GetRootParseNode(contentType, body) + if err != nil { + spanForAttributes.RecordError(err) + } + return rootNode, ctx, err +} +func (a *NetHttpRequestAdapter) purge(response *nethttp.Response) error { + _, _ = ioutil.ReadAll(response.Body) //we don't care about errors comming from reading the body, just trying to purge anything that maybe left + err := response.Body.Close() + if err != nil { + return err + } + return nil +} +func (a *NetHttpRequestAdapter) shouldReturnNil(response *nethttp.Response) bool { + return response.StatusCode == 204 +} + +// ErrorMappingFoundAttributeName is the attribute name used to indicate whether an error code mapping was found. +const ErrorMappingFoundAttributeName = "com.microsoft.kiota.error.mapping_found" + +// ErrorBodyFoundAttributeName is the attribute name used to indicate whether the error response contained a body +const ErrorBodyFoundAttributeName = "com.microsoft.kiota.error.body_found" + +func (a *NetHttpRequestAdapter) throwIfFailedResponse(ctx context.Context, response *nethttp.Response, errorMappings abs.ErrorMappings, spanForAttributes trace.Span) error { + ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "throwIfFailedResponse") + defer span.End() + if response.StatusCode < 400 { + return nil + } + spanForAttributes.SetStatus(codes.Error, "received_error_response") + + statusAsString := strconv.Itoa(response.StatusCode) + var errorCtor absser.ParsableFactory = nil + if len(errorMappings) != 0 { + if errorMappings[statusAsString] != nil { + errorCtor = errorMappings[statusAsString] + } else if response.StatusCode >= 400 && response.StatusCode < 500 && errorMappings["4XX"] != nil { + errorCtor = errorMappings["4XX"] + } else if response.StatusCode >= 500 && response.StatusCode < 600 && errorMappings["5XX"] != nil { + errorCtor = errorMappings["5XX"] + } + } + + if errorCtor == nil { + spanForAttributes.SetAttributes(attribute.Bool(ErrorMappingFoundAttributeName, false)) + err := &abs.ApiError{ + Message: "The server returned an unexpected status code and no error factory is registered for this code: " + statusAsString, + } + spanForAttributes.RecordError(err) + return err + } + spanForAttributes.SetAttributes(attribute.Bool(ErrorMappingFoundAttributeName, true)) + + rootNode, _, err := a.getRootParseNode(ctx, response, spanForAttributes) + if err != nil { + spanForAttributes.RecordError(err) + return err + } + if rootNode == nil { + spanForAttributes.SetAttributes(attribute.Bool(ErrorBodyFoundAttributeName, false)) + err := &abs.ApiError{ + Message: "The server returned an unexpected status code with no response body: " + statusAsString, + } + spanForAttributes.RecordError(err) + return err + } + spanForAttributes.SetAttributes(attribute.Bool(ErrorBodyFoundAttributeName, true)) + + _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetObjectValue") + defer deserializeSpan.End() + errValue, err := rootNode.GetObjectValue(errorCtor) + if err != nil { + spanForAttributes.RecordError(err) + return err + } else if errValue == nil { + return &abs.ApiError{ + Message: "The server returned an unexpected status code but the error could not be deserialized: " + statusAsString, + } + } + + err = errValue.(error) + spanForAttributes.RecordError(err) + return err +} diff --git a/nethttp_request_adapter_test.go b/nethttp_request_adapter_test.go index 855993a..3a38f51 100644 --- a/nethttp_request_adapter_test.go +++ b/nethttp_request_adapter_test.go @@ -1,246 +1,261 @@ -package nethttplibrary - -import ( - "context" - nethttp "net/http" - httptest "net/http/httptest" - "net/url" - "testing" - - abs "github.com/microsoft/kiota-abstractions-go" - absauth "github.com/microsoft/kiota-abstractions-go/authentication" - "github.com/microsoft/kiota-http-go/internal" - - "github.com/stretchr/testify/assert" -) - -func TestItRetriesOnCAEResponse(t *testing.T) { - methodCallCount := 0 - - testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { - if methodCallCount > 0 { - res.WriteHeader(200) - } else { - res.Header().Set("WWW-Authenticate", "Bearer realm=\"\", authorization_uri=\"https://login.microsoftonline.com/common/oauth2/authorize\", client_id=\"00000003-0000-0000-c000-000000000000\", error=\"insufficient_claims\", claims=\"eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTY1MjgxMzUwOCJ9fX0=\"") - res.WriteHeader(401) - } - methodCallCount++ - res.Write([]byte("body")) - })) - defer func() { testServer.Close() }() - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapter(authProvider) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - uri, err := url.Parse(testServer.URL) - assert.Nil(t, err) - assert.NotNil(t, uri) - request := abs.NewRequestInformation() - request.SetUri(*uri) - request.Method = abs.GET - - err2 := adapter.SendNoContent(context.TODO(), request, nil) - assert.Nil(t, err2) - assert.Equal(t, 2, methodCallCount) -} - -func TestImplementationHonoursInterface(t *testing.T) { - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapter(authProvider) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - assert.Implements(t, (*abs.RequestAdapter)(nil), adapter) -} - -func TestItDoesntFailOnEmptyContentType(t *testing.T) { - testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { - res.WriteHeader(201) - })) - defer func() { testServer.Close() }() - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapter(authProvider) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - uri, err := url.Parse(testServer.URL) - assert.Nil(t, err) - assert.NotNil(t, uri) - request := abs.NewRequestInformation() - request.SetUri(*uri) - request.Method = abs.GET - - res, err := adapter.Send(context.Background(), request, nil, nil) - assert.Nil(t, err) - assert.Nil(t, res) -} - -func TestItReturnsUsableStreamOnStream(t *testing.T) { - statusCodes := []int{200, 201, 202, 203, 206} - - for i := 0; i < len(statusCodes); i++ { - - testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { - res.WriteHeader(statusCodes[i]) - res.Write([]byte("test")) - })) - defer func() { testServer.Close() }() - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapter(authProvider) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - uri, err := url.Parse(testServer.URL) - assert.Nil(t, err) - assert.NotNil(t, uri) - request := abs.NewRequestInformation() - request.SetUri(*uri) - request.Method = abs.GET - - res, err2 := adapter.SendPrimitive(context.TODO(), request, "[]byte", nil) - assert.Nil(t, err2) - assert.NotNil(t, res) - assert.Equal(t, 4, len(res.([]byte))) - } -} - -func TestItReturnsNilOnStream(t *testing.T) { - statusCodes := []int{200, 201, 202, 203, 204} - - for i := 0; i < len(statusCodes); i++ { - - testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { - res.WriteHeader(statusCodes[i]) - })) - defer func() { testServer.Close() }() - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapter(authProvider) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - uri, err := url.Parse(testServer.URL) - assert.Nil(t, err) - assert.NotNil(t, uri) - request := abs.NewRequestInformation() - request.SetUri(*uri) - request.Method = abs.GET - - res, err2 := adapter.SendPrimitive(context.TODO(), request, "[]byte", nil) - assert.Nil(t, err2) - assert.Nil(t, res) - } -} - -func TestSendNoContentDoesntFailOnOtherCodes(t *testing.T) { - statusCodes := []int{200, 201, 202, 203, 204, 206} - - for i := 0; i < len(statusCodes); i++ { - - testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { - res.WriteHeader(statusCodes[i]) - })) - defer func() { testServer.Close() }() - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapter(authProvider) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - uri, err := url.Parse(testServer.URL) - assert.Nil(t, err) - assert.NotNil(t, uri) - request := abs.NewRequestInformation() - request.SetUri(*uri) - request.Method = abs.GET - - err2 := adapter.SendNoContent(context.TODO(), request, nil) - assert.Nil(t, err2) - } -} - -func TestSendReturnNilOnNoContent(t *testing.T) { - statusCodes := []int{200, 201, 202, 203, 204, 205} - - for i := 0; i < len(statusCodes); i++ { - - testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { - res.WriteHeader(statusCodes[i]) - })) - defer func() { testServer.Close() }() - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapter(authProvider) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - uri, err := url.Parse(testServer.URL) - assert.Nil(t, err) - assert.NotNil(t, uri) - request := abs.NewRequestInformation() - request.SetUri(*uri) - request.Method = abs.GET - - res, err2 := adapter.Send(context.TODO(), request, internal.MockEntityFactory, nil) - assert.Nil(t, err2) - assert.Nil(t, res) - } -} - -func TestSendReturnsObjectOnContent(t *testing.T) { - statusCodes := []int{200, 201, 202, 203, 204, 205} - - for i := 0; i < len(statusCodes); i++ { - - testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { - res.WriteHeader(statusCodes[i]) - })) - defer func() { testServer.Close() }() - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapterWithParseNodeFactory(authProvider, &internal.MockParseNodeFactory{}) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - uri, err := url.Parse(testServer.URL) - assert.Nil(t, err) - assert.NotNil(t, uri) - request := abs.NewRequestInformation() - request.SetUri(*uri) - request.Method = abs.GET - - res, err2 := adapter.Send(context.TODO(), request, internal.MockEntityFactory, nil) - assert.Nil(t, err2) - assert.Nil(t, res) - } -} - -func TestResponseHandlerIsCalledWhenProvided(t *testing.T) { - testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { - res.WriteHeader(201) - })) - defer func() { testServer.Close() }() - authProvider := &absauth.AnonymousAuthenticationProvider{} - adapter, err := NewNetHttpRequestAdapter(authProvider) - assert.Nil(t, err) - assert.NotNil(t, adapter) - - uri, err := url.Parse(testServer.URL) - assert.Nil(t, err) - assert.NotNil(t, uri) - request := abs.NewRequestInformation() - request.SetUri(*uri) - request.Method = abs.GET - - count := 1 - responseHandler := func(response interface{}, errorMappings abs.ErrorMappings) (interface{}, error) { - count = 2 - return nil, nil - } - - handlerOption := abs.NewRequestHandlerOption() - handlerOption.SetResponseHandler(responseHandler) - - request.AddRequestOptions([]abs.RequestOption{handlerOption}) - - err = adapter.SendNoContent(context.Background(), request, nil) - assert.Nil(t, err) - assert.Equal(t, 2, count) -} +package nethttplibrary + +import ( + "context" + nethttp "net/http" + httptest "net/http/httptest" + "net/url" + "testing" + + abs "github.com/microsoft/kiota-abstractions-go" + absauth "github.com/microsoft/kiota-abstractions-go/authentication" + absstore "github.com/microsoft/kiota-abstractions-go/store" + "github.com/microsoft/kiota-http-go/internal" + + "github.com/stretchr/testify/assert" +) + +func TestItRetriesOnCAEResponse(t *testing.T) { + methodCallCount := 0 + + testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { + if methodCallCount > 0 { + res.WriteHeader(200) + } else { + res.Header().Set("WWW-Authenticate", "Bearer realm=\"\", authorization_uri=\"https://login.microsoftonline.com/common/oauth2/authorize\", client_id=\"00000003-0000-0000-c000-000000000000\", error=\"insufficient_claims\", claims=\"eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTY1MjgxMzUwOCJ9fX0=\"") + res.WriteHeader(401) + } + methodCallCount++ + res.Write([]byte("body")) + })) + defer func() { testServer.Close() }() + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + uri, err := url.Parse(testServer.URL) + assert.Nil(t, err) + assert.NotNil(t, uri) + request := abs.NewRequestInformation() + request.SetUri(*uri) + request.Method = abs.GET + + err2 := adapter.SendNoContent(context.TODO(), request, nil) + assert.Nil(t, err2) + assert.Equal(t, 2, methodCallCount) +} + +func TestImplementationHonoursInterface(t *testing.T) { + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + assert.Implements(t, (*abs.RequestAdapter)(nil), adapter) +} + +func TestItDoesntFailOnEmptyContentType(t *testing.T) { + testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { + res.WriteHeader(201) + })) + defer func() { testServer.Close() }() + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + uri, err := url.Parse(testServer.URL) + assert.Nil(t, err) + assert.NotNil(t, uri) + request := abs.NewRequestInformation() + request.SetUri(*uri) + request.Method = abs.GET + + res, err := adapter.Send(context.Background(), request, nil, nil) + assert.Nil(t, err) + assert.Nil(t, res) +} + +func TestItReturnsUsableStreamOnStream(t *testing.T) { + statusCodes := []int{200, 201, 202, 203, 206} + + for i := 0; i < len(statusCodes); i++ { + + testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { + res.WriteHeader(statusCodes[i]) + res.Write([]byte("test")) + })) + defer func() { testServer.Close() }() + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + uri, err := url.Parse(testServer.URL) + assert.Nil(t, err) + assert.NotNil(t, uri) + request := abs.NewRequestInformation() + request.SetUri(*uri) + request.Method = abs.GET + + res, err2 := adapter.SendPrimitive(context.TODO(), request, "[]byte", nil) + assert.Nil(t, err2) + assert.NotNil(t, res) + assert.Equal(t, 4, len(res.([]byte))) + } +} + +func TestItReturnsNilOnStream(t *testing.T) { + statusCodes := []int{200, 201, 202, 203, 204} + + for i := 0; i < len(statusCodes); i++ { + + testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { + res.WriteHeader(statusCodes[i]) + })) + defer func() { testServer.Close() }() + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + uri, err := url.Parse(testServer.URL) + assert.Nil(t, err) + assert.NotNil(t, uri) + request := abs.NewRequestInformation() + request.SetUri(*uri) + request.Method = abs.GET + + res, err2 := adapter.SendPrimitive(context.TODO(), request, "[]byte", nil) + assert.Nil(t, err2) + assert.Nil(t, res) + } +} + +func TestSendNoContentDoesntFailOnOtherCodes(t *testing.T) { + statusCodes := []int{200, 201, 202, 203, 204, 206} + + for i := 0; i < len(statusCodes); i++ { + + testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { + res.WriteHeader(statusCodes[i]) + })) + defer func() { testServer.Close() }() + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + uri, err := url.Parse(testServer.URL) + assert.Nil(t, err) + assert.NotNil(t, uri) + request := abs.NewRequestInformation() + request.SetUri(*uri) + request.Method = abs.GET + + err2 := adapter.SendNoContent(context.TODO(), request, nil) + assert.Nil(t, err2) + } +} + +func TestSendReturnNilOnNoContent(t *testing.T) { + statusCodes := []int{200, 201, 202, 203, 204, 205} + + for i := 0; i < len(statusCodes); i++ { + + testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { + res.WriteHeader(statusCodes[i]) + })) + defer func() { testServer.Close() }() + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + uri, err := url.Parse(testServer.URL) + assert.Nil(t, err) + assert.NotNil(t, uri) + request := abs.NewRequestInformation() + request.SetUri(*uri) + request.Method = abs.GET + + res, err2 := adapter.Send(context.TODO(), request, internal.MockEntityFactory, nil) + assert.Nil(t, err2) + assert.Nil(t, res) + } +} + +func TestSendReturnsObjectOnContent(t *testing.T) { + statusCodes := []int{200, 201, 202, 203, 204, 205} + + for i := 0; i < len(statusCodes); i++ { + + testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { + res.WriteHeader(statusCodes[i]) + })) + defer func() { testServer.Close() }() + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapterWithParseNodeFactory(authProvider, &internal.MockParseNodeFactory{}) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + uri, err := url.Parse(testServer.URL) + assert.Nil(t, err) + assert.NotNil(t, uri) + request := abs.NewRequestInformation() + request.SetUri(*uri) + request.Method = abs.GET + + res, err2 := adapter.Send(context.TODO(), request, internal.MockEntityFactory, nil) + assert.Nil(t, err2) + assert.Nil(t, res) + } +} + +func TestResponseHandlerIsCalledWhenProvided(t *testing.T) { + testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { + res.WriteHeader(201) + })) + defer func() { testServer.Close() }() + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.Nil(t, err) + assert.NotNil(t, adapter) + + uri, err := url.Parse(testServer.URL) + assert.Nil(t, err) + assert.NotNil(t, uri) + request := abs.NewRequestInformation() + request.SetUri(*uri) + request.Method = abs.GET + + count := 1 + responseHandler := func(response interface{}, errorMappings abs.ErrorMappings) (interface{}, error) { + count = 2 + return nil, nil + } + + handlerOption := abs.NewRequestHandlerOption() + handlerOption.SetResponseHandler(responseHandler) + + request.AddRequestOptions([]abs.RequestOption{handlerOption}) + + err = adapter.SendNoContent(context.Background(), request, nil) + assert.Nil(t, err) + assert.Equal(t, 2, count) +} + +func TestNetHttpRequestAdapter_EnableBackingStore(t *testing.T) { + authProvider := &absauth.AnonymousAuthenticationProvider{} + adapter, err := NewNetHttpRequestAdapter(authProvider) + assert.NoError(t, err) + + var store = func() absstore.BackingStore { + return nil + } + + assert.NotEqual(t, absstore.BackingStoreFactoryInstance(), store()) + adapter.EnableBackingStore(store) + assert.Equal(t, absstore.BackingStoreFactoryInstance(), store()) +}