Skip to content

Commit

Permalink
pkg/env: log MSI data-plane interactions
Browse files Browse the repository at this point in the history
Signed-off-by: Steve Kuznetsov <[email protected]>
  • Loading branch information
stevekuznetsov committed Feb 5, 2025
1 parent 5cf9da0 commit 08f6976
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 3 deletions.
78 changes: 77 additions & 1 deletion pkg/env/prod.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ package env
// Licensed under the Apache License 2.0.

import (
"bytes"
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
Expand All @@ -26,6 +30,7 @@ import (
"k8s.io/utils/ptr"

"github.com/Azure/ARO-RP/pkg/proxy"
"github.com/Azure/ARO-RP/pkg/util/azureclient"
"github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/compute"
"github.com/Azure/ARO-RP/pkg/util/clientauthorizer"
"github.com/Azure/ARO-RP/pkg/util/computeskus"
Expand Down Expand Up @@ -399,12 +404,83 @@ func (p *prod) MsiRpEndpoint() string {
}

func (p *prod) MsiDataplaneClientOptions() (*policy.ClientOptions, error) {
armClientOptions := p.Environment().ArmClientOptions()
armClientOptions := p.Environment().ArmClientOptions(ClientDebugLoggerMiddleware(p.log.WithField("client", "msi-dataplane")))
clientOptions := armClientOptions.ClientOptions

return &clientOptions, nil
}

func ClientDebugLoggerMiddleware(log *logrus.Entry) azureclient.Middleware {
return func(delegate http.RoundTripper) http.RoundTripper {
return azureclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
log := log.WithFields(logrus.Fields{
"method": req.Method,
"url": req.URL,
})
if req.Body != nil {
body, err := io.ReadAll(req.Body)
if err != nil {
log.WithError(err).Error("error reading request body")
}
if err := req.Body.Close(); err != nil {
log.WithError(err).Error("error closing request body")
}
log = log.WithField("body", string(body))
req.Body = io.NopCloser(bytes.NewBuffer(body)) // reset body so the delegate can use it
}
log.Info("Sending request.")
resp, err := delegate.RoundTrip(req)
if err != nil {
log.WithError(err).Error("Request errored.")
} else if resp != nil {
log = log.WithFields(logrus.Fields{
"status": resp.StatusCode,
})
body, err := io.ReadAll(resp.Body)
if err != nil {
log.WithError(err).Error("error reading response body")
}
if err := resp.Body.Close(); err != nil {
log.WithError(err).Error("error closing response body")
}
// n.b.: we only send one request now, this is best-effort but would need to be updated if we use other methods
response := dataplane.ManagedIdentityCredentials{}
if err := json.Unmarshal(body, &response); err != nil {
log.WithError(err).Error("error unmarshalling response body")
} else {
censorCredentials(&response)
log = log.WithField("body", string(body))
}
resp.Body = io.NopCloser(bytes.NewBuffer(body)) // reset body so the upstream round-trippers can use it
}
log.Info("Received response.")

return resp, err
})
}
}

func censorCredentials(input *dataplane.ManagedIdentityCredentials) {
input.ClientSecret = nil
if input.DelegatedResources != nil {
for i := 0; i < len(*input.DelegatedResources); i++ {
if (*input.DelegatedResources)[i].ImplicitIdentity != nil {
(*input.DelegatedResources)[i].ImplicitIdentity.ClientSecret = nil
}
if (*input.DelegatedResources)[i].ExplicitIdentities != nil {
for j := 0; j < len(*(*input.DelegatedResources)[i].ExplicitIdentities); j++ {
(*(*input.DelegatedResources)[i].ExplicitIdentities)[j].ClientSecret = nil
}
}
}
}
if input.ExplicitIdentities != nil {
for i := 0; i < len(*input.ExplicitIdentities); i++ {
(*input.ExplicitIdentities)[i].ClientSecret = nil
}
}
}

func (p *prod) MockMSIResponses(msiResourceId *arm.ResourceID) dataplane.ClientFactory {
return &mockFactory{aadHost: p.Environment().Cloud.ActiveDirectoryAuthorityHost, msiResourceId: msiResourceId.String()}
}
Expand Down
146 changes: 146 additions & 0 deletions pkg/env/prod_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package env

// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.

import (
"testing"

"github.com/Azure/msi-dataplane/pkg/dataplane"
"github.com/google/go-cmp/cmp"
"k8s.io/utils/ptr"
)

func managedIdentityCredentials(censor bool, delegatedResources []dataplane.DelegatedResource, explicitIdentities []dataplane.UserAssignedIdentityCredentials) dataplane.ManagedIdentityCredentials {
return dataplane.ManagedIdentityCredentials{
AuthenticationEndpoint: ptr.To("AuthenticationEndpoint"),
CannotRenewAfter: ptr.To("CannotRenewAfter"),
ClientId: ptr.To("ClientId"),
ClientSecret: func() *string {
if censor {
return nil
}
return ptr.To("ClientSecret")
}(),
ClientSecretUrl: ptr.To("ClientSecretUrl"),
CustomClaims: ptr.To(customClaims()),
DelegatedResources: func() *[]dataplane.DelegatedResource {
if len(delegatedResources) > 0 {
return &delegatedResources
}
return nil
}(),
DelegationUrl: ptr.To("DelegationUrl"),
ExplicitIdentities: func() *[]dataplane.UserAssignedIdentityCredentials {
if len(explicitIdentities) > 0 {
return &explicitIdentities
}
return nil
}(),
InternalId: ptr.To("InternalId"),
MtlsAuthenticationEndpoint: ptr.To("MtlsAuthenticationEndpoint"),
NotAfter: ptr.To("NotAfter"),
NotBefore: ptr.To("NotBefore"),
ObjectId: ptr.To("ObjectId"),
RenewAfter: ptr.To("RenewAfter"),
TenantId: ptr.To("TenantId"),
}
}

func delegatedResource(implicitIdentity *dataplane.UserAssignedIdentityCredentials, explicitIdentities ...dataplane.UserAssignedIdentityCredentials) dataplane.DelegatedResource {
return dataplane.DelegatedResource{
DelegationId: ptr.To("DelegationId"),
DelegationUrl: ptr.To("DelegationUrl"),
ExplicitIdentities: func() *[]dataplane.UserAssignedIdentityCredentials {
if len(explicitIdentities) > 0 {
return &explicitIdentities
}
return nil
}(),
ImplicitIdentity: implicitIdentity,
InternalId: ptr.To("InternalId"),
ResourceId: ptr.To("ResourceId"),
}
}

func userAssignedIdentityCredentials(censor bool) dataplane.UserAssignedIdentityCredentials {
return dataplane.UserAssignedIdentityCredentials{
AuthenticationEndpoint: ptr.To("AuthenticationEndpoint"),
CannotRenewAfter: ptr.To("CannotRenewAfter"),
ClientId: ptr.To("ClientId"),
ClientSecret: func() *string {
if censor {
return nil
}
return ptr.To("ClientSecret")
}(),
ClientSecretUrl: ptr.To("ClientSecretUrl"),
CustomClaims: ptr.To(customClaims()),
MtlsAuthenticationEndpoint: ptr.To("MtlsAuthenticationEndpoint"),
NotAfter: ptr.To("NotAfter"),
NotBefore: ptr.To("NotBefore"),
ObjectId: ptr.To("ObjectId"),
RenewAfter: ptr.To("RenewAfter"),
ResourceId: ptr.To("ResourceId"),
TenantId: ptr.To("TenantId"),
}
}

func customClaims() dataplane.CustomClaims {
return dataplane.CustomClaims{
XmsAzNwperimid: ptr.To([]string{"XmsAzNwperimid"}),
XmsAzTm: ptr.To("XmsAzTm"),
}
}

func TestCensorCredentials(t *testing.T) {
for _, testCase := range []struct {
name string
generateData func(censor bool) (data *dataplane.ManagedIdentityCredentials)
}{
{
name: "no delegated resources, explicit credentials",
generateData: func(censor bool) (data *dataplane.ManagedIdentityCredentials) {
return ptr.To(managedIdentityCredentials(censor, nil, nil))
},
},
{
name: "delegated resource without explicit credentials, no top-level explicit credentials",
generateData: func(censor bool) (data *dataplane.ManagedIdentityCredentials) {
return ptr.To(managedIdentityCredentials(censor, []dataplane.DelegatedResource{
delegatedResource(ptr.To(userAssignedIdentityCredentials(censor))),
delegatedResource(ptr.To(userAssignedIdentityCredentials(censor))),
}, nil))
},
},
{
name: "delegated resource with explicit credentials, no top-level explicit credentials",
generateData: func(censor bool) (data *dataplane.ManagedIdentityCredentials) {
return ptr.To(managedIdentityCredentials(censor, []dataplane.DelegatedResource{
delegatedResource(ptr.To(userAssignedIdentityCredentials(censor)), userAssignedIdentityCredentials(censor), userAssignedIdentityCredentials(censor)),
delegatedResource(ptr.To(userAssignedIdentityCredentials(censor)), userAssignedIdentityCredentials(censor), userAssignedIdentityCredentials(censor)),
}, nil))
},
},
{
name: "delegated resource with explicit credentials, top-level explicit credentials",
generateData: func(censor bool) (data *dataplane.ManagedIdentityCredentials) {
return ptr.To(managedIdentityCredentials(censor, []dataplane.DelegatedResource{
delegatedResource(ptr.To(userAssignedIdentityCredentials(censor)), userAssignedIdentityCredentials(censor), userAssignedIdentityCredentials(censor)),
delegatedResource(ptr.To(userAssignedIdentityCredentials(censor)), userAssignedIdentityCredentials(censor), userAssignedIdentityCredentials(censor)),
}, []dataplane.UserAssignedIdentityCredentials{
userAssignedIdentityCredentials(censor),
userAssignedIdentityCredentials(censor),
}))
},
},
} {
t.Run(testCase.name, func(t *testing.T) {
input, output := testCase.generateData(false), testCase.generateData(true)
censorCredentials(input)
if diff := cmp.Diff(output, input); diff != "" {
t.Errorf("censorCredentials mismatch (-want +got):\n%s", diff)
}
})
}
}
27 changes: 25 additions & 2 deletions pkg/util/azureclient/environments.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,33 @@ func EnvironmentFromName(name string) (AROEnvironment, error) {
return AROEnvironment{}, fmt.Errorf("cloud environment %q is unsupported by ARO", name)
}

// RoundTripperFunc allows a function to implement http.RoundTripper
type RoundTripperFunc func(*http.Request) (*http.Response, error)

func (rt RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return rt(req)
}

// Middleware closes over any client-side middleware
type Middleware func(http.RoundTripper) http.RoundTripper

// Chain is a handy function to wrap a base RoundTripper (optional) with the middlewares.
func Chain(rt http.RoundTripper, middlewares ...Middleware) http.RoundTripper {
if rt == nil {
rt = http.DefaultTransport
}

for _, m := range middlewares {
rt = m(rt)
}

return rt
}

// ArmClientOptions returns an arm.ClientOptions to be passed in when instantiating
// Azure SDK for Go clients.
func (e *AROEnvironment) ArmClientOptions() *arm.ClientOptions {
customRoundTripper := NewCustomRoundTripper(http.DefaultTransport)
func (e *AROEnvironment) ArmClientOptions(middlewares ...Middleware) *arm.ClientOptions {
customRoundTripper := Chain(http.DefaultTransport, append([]Middleware{NewCustomRoundTripper}, middlewares...)...)
return &arm.ClientOptions{
ClientOptions: azcore.ClientOptions{
Cloud: e.Cloud,
Expand Down

0 comments on commit 08f6976

Please sign in to comment.