Skip to content

Commit

Permalink
feat: Move the common middlewares to go-mod-bootstrap
Browse files Browse the repository at this point in the history
Closes edgexfoundry#565.
- Define the common middleware funcs on the router level.
- Skip Ping in the auth handler func.

Signed-off-by: Lindsey Cheng <[email protected]>
  • Loading branch information
lindseysimple committed Jul 26, 2023
1 parent eae20c3 commit 1888942
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 6 deletions.
9 changes: 3 additions & 6 deletions bootstrap/controller/commonapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"strings"

"github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/handlers"
"github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/utils"
"github.com/edgexfoundry/go-mod-bootstrap/v3/di"
Expand Down Expand Up @@ -46,8 +45,6 @@ type config struct {

func NewCommonController(dic *di.Container, r *echo.Echo, serviceName string, serviceVersion string) *CommonController {
lc := container.LoggingClientFrom(dic.Get)
secretProvider := container.SecretProviderExtFrom(dic.Get)
authenticationHook := handlers.AutoConfigAuthenticationFunc(secretProvider, lc)
configuration := container.ConfigurationFrom(dic.Get)
c := CommonController{
dic: dic,
Expand All @@ -62,9 +59,9 @@ func NewCommonController(dic *di.Container, r *echo.Echo, serviceName string, se
},
}
r.GET(common.ApiPingRoute, c.Ping) // Health check is always unauthenticated
r.GET(common.ApiVersionRoute, c.Version, authenticationHook)
r.GET(common.ApiConfigRoute, c.Config, authenticationHook)
r.POST(common.ApiSecretRoute, c.AddSecret, authenticationHook)
r.GET(common.ApiVersionRoute, c.Version)
r.GET(common.ApiConfigRoute, c.Config)
r.POST(common.ApiSecretRoute, c.AddSecret)

return &c
}
Expand Down
6 changes: 6 additions & 0 deletions bootstrap/handlers/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strings"

"github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger"
"github.com/edgexfoundry/go-mod-core-contracts/v3/common"

"github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/secret"
Expand Down Expand Up @@ -50,6 +51,11 @@ import (
func VaultAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc {
return func(inner echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Skip the JWT authorization check for Ping request
if c.Path() == common.ApiPingRoute {
return nil
}

r := c.Request()
w := c.Response()
authHeader := r.Header.Get("Authorization")
Expand Down
95 changes: 95 additions & 0 deletions bootstrap/handlers/common_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//
// Copyright (C) 2023 IOTech Ltd
//
// SPDX-License-Identifier: Apache-2.0

package handlers

import (
"context"
"net/http"
"net/url"
"time"

"github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger"
"github.com/edgexfoundry/go-mod-core-contracts/v3/common"
"github.com/edgexfoundry/go-mod-core-contracts/v3/models"

"github.com/google/uuid"
"github.com/labstack/echo/v4"
)

func ManageHeader(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
r := c.Request()
correlationID := r.Header.Get(common.CorrelationHeader)
if correlationID == "" {
correlationID = uuid.New().String()
}
// lint:ignore SA1029 legacy
// nolint:staticcheck // See golangci-lint #741
ctx := context.WithValue(r.Context(), common.CorrelationHeader, correlationID)

contentType := r.Header.Get(common.ContentType)
// lint:ignore SA1029 legacy
// nolint:staticcheck // See golangci-lint #741
ctx = context.WithValue(ctx, common.ContentType, contentType)

c.SetRequest(r.WithContext(ctx))

return next(c)
}
}

func LoggingMiddleware(lc logger.LoggingClient) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if lc.LogLevel() == models.TraceLog {
r := c.Request()
begin := time.Now()
correlationId := FromContext(r.Context())
lc.Trace("Begin request", common.CorrelationHeader, correlationId, "path", r.URL.Path)
err := next(c)
if err != nil {
lc.Errorf("failed to add the middleware: %v", err)
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
lc.Trace("Response complete", common.CorrelationHeader, correlationId, "duration", time.Since(begin).String())
return nil
}
return next(c)
}
}
}

// UrlDecodeMiddleware decode the path variables
// After invoking the router.UseEncodedPath() func, the path variables needs to decode before passing to the controller
func UrlDecodeMiddleware(lc logger.LoggingClient) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
var unescapedParams []string
// Retrieve all the url path param names
paramNames := c.ParamNames()

// Retrieve all the url path param values and decode
for k, v := range c.ParamValues() {
unescape, err := url.PathUnescape(v)
if err != nil {
lc.Debugf("failed to decode the %s from the value %s", paramNames[k], v)
return err
}
unescapedParams = append(unescapedParams, unescape)
}
c.SetParamValues(unescapedParams...)
return next(c)
}
}
}

func FromContext(ctx context.Context) string {
hdr, ok := ctx.Value(common.CorrelationHeader).(string)
if !ok {
hdr = ""
}
return hdr
}
88 changes: 88 additions & 0 deletions bootstrap/handlers/common_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//
// Copyright (C) 2023 IOTech Ltd
//
// SPDX-License-Identifier: Apache-2.0

package handlers

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger"
"github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger/mocks"
"github.com/edgexfoundry/go-mod-core-contracts/v3/common"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

var expectedCorrelationId = "927e91d3-864c-4c26-852d-b68c39492d14"

var handler = func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
}

func TestManageHeader(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
c.Response().Header().Set(common.CorrelationHeader, c.Request().Context().Value(common.CorrelationHeader).(string))
c.Response().Header().Set(common.ContentType, c.Request().Context().Value(common.ContentType).(string))
c.Response().WriteHeader(http.StatusOK)
return nil
})
e.Use(ManageHeader)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(common.CorrelationHeader, expectedCorrelationId)
expectedContentType := common.ContentTypeJSON
req.Header.Set(common.ContentType, expectedContentType)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)

assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, expectedCorrelationId, res.Header().Get(common.CorrelationHeader))
assert.Equal(t, expectedContentType, res.Header().Get(common.ContentType))
}

func TestLoggingMiddleware(t *testing.T) {
e := echo.New()
e.GET("/", handler)
lcMock := &mocks.LoggingClient{}
lcMock.On("Trace", "Begin request", common.CorrelationHeader, expectedCorrelationId, "path", "/")
lcMock.On("Trace", "Response complete", common.CorrelationHeader, expectedCorrelationId, "duration", mock.Anything)
lcMock.On("LogLevel").Return("TRACE")
e.Use(LoggingMiddleware(lcMock))

req := httptest.NewRequest(http.MethodGet, "/", nil)
// lint:ignore SA1029 legacy
// nolint:staticcheck // See golangci-lint #741
ctx := context.WithValue(req.Context(), common.CorrelationHeader, expectedCorrelationId)
req = req.WithContext(ctx)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)

lcMock.AssertCalled(t, "Trace", "Begin request", common.CorrelationHeader, expectedCorrelationId, "path", "/")
lcMock.AssertCalled(t, "Trace", "Response complete", common.CorrelationHeader, expectedCorrelationId, "duration", mock.Anything)
assert.Equal(t, http.StatusOK, res.Code)
}

func TestUrlDecodeMiddleware(t *testing.T) {
e := echo.New()

req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
c.SetParamNames("foo")
c.SetParamValues("abc%2F123%25") // the decoded value is abc/123%

lc = logger.NewMockClient()
m := UrlDecodeMiddleware(lc)
err := m(handler)(c)

assert.NoError(t, err)
assert.Equal(t, "abc/123%", c.Param("foo"))
}
9 changes: 9 additions & 0 deletions bootstrap/handlers/httpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ func (b *HttpServer) BootstrapHandler(
return false
}

// Use the common middlewares
secretProvider := container.SecretProviderExtFrom(dic.Get)
authenticationHook := AutoConfigAuthenticationFunc(secretProvider, lc)

b.router.Use(authenticationHook)
b.router.Use(ManageHeader)
b.router.Use(LoggingMiddleware(lc))
b.router.Use(UrlDecodeMiddleware(lc))

timeout, err := time.ParseDuration(bootstrapConfig.Service.RequestTimeout)
if err != nil {
lc.Errorf("unable to parse RequestTimeout value of %s to a duration: %v", bootstrapConfig.Service.RequestTimeout, err)
Expand Down

0 comments on commit 1888942

Please sign in to comment.