From f41c697d0a516ef217e3f824c350ba80929602ba Mon Sep 17 00:00:00 2001 From: Nico Bistolfi <4433590+nicobistolfi@users.noreply.github.com> Date: Sun, 22 Sep 2024 13:41:37 -0700 Subject: [PATCH] Codecoverage improve middlewares (#19) * codecov * codecove middleware * testing code coverage * deleting token cache test for now --- .github/workflows/codecov.yml | 3 +- go.mod | 7 +- go.sum | 11 +- .../api/middleware/auth_middleware_test.go | 51 ++++++ internal/api/middleware/cors.go | 30 ++-- .../api/middleware/cors_middleware_test.go | 77 ++++++++ .../api/middleware/logging_middleware_test.go | 46 +++++ internal/api/middleware/rate_limiter.go | 17 +- internal/api/middleware/rate_limiter_test.go | 166 ++++++++++++++++++ internal/api/middleware/token.go | 30 +++- internal/api/middleware/token_test.go | 78 ++++++++ 11 files changed, 491 insertions(+), 25 deletions(-) create mode 100644 internal/api/middleware/auth_middleware_test.go create mode 100644 internal/api/middleware/cors_middleware_test.go create mode 100644 internal/api/middleware/logging_middleware_test.go create mode 100644 internal/api/middleware/rate_limiter_test.go create mode 100644 internal/api/middleware/token_test.go diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 00cae3b..f835bab 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -1,7 +1,6 @@ name: Run tests and upload coverage -on: - push +on: push jobs: test: diff --git a/go.mod b/go.mod index 4a3b757..b454a22 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.22.5 require ( github.com/aws/aws-lambda-go v1.47.0 github.com/awslabs/aws-lambda-go-api-proxy v0.16.2 - github.com/gin-contrib/cors v1.7.2 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/joho/godotenv v1.5.1 @@ -14,6 +13,12 @@ require ( golang.org/x/time v0.5.0 ) +require ( + github.com/kr/pretty v0.3.0 // indirect + github.com/rogpeppe/go-internal v1.8.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect +) + require ( github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect diff --git a/go.sum b/go.sum index b7d3eff..3ee8b9b 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,7 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -17,8 +18,6 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= -github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= @@ -46,8 +45,12 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02 github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= @@ -67,8 +70,10 @@ github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -111,8 +116,10 @@ golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= google.golang.org/protobuf v1.34.0 h1:Qo/qEd2RZPCf2nKuorzksSknv0d3ERwp1vFG38gSmH4= google.golang.org/protobuf v1.34.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/api/middleware/auth_middleware_test.go b/internal/api/middleware/auth_middleware_test.go new file mode 100644 index 0000000..6c9d99c --- /dev/null +++ b/internal/api/middleware/auth_middleware_test.go @@ -0,0 +1,51 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestAuthMiddlewareFunc(t *testing.T) { + // Set Gin to Test Mode + gin.SetMode(gin.TestMode) + + // Create a new Gin engine + r := gin.New() + + // Use the AuthMiddleware + protected := r.Group("/") + protected.Use(AuthMiddleware()) + // Define a test route + protected.GET("/protected", func(c *gin.Context) { + c.String(http.StatusOK, "protected") + }) + + tests := []struct { + name string + token string + expectedStatus int + }{ + {"Valid Token", "Bearer valid_token", http.StatusOK}, + {"Invalid Token", "Bearer invalid_token", http.StatusOK}, // This is expected to pass, as the auth only checks for the token presence + {"No Token", "", http.StatusUnauthorized}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "/protected", nil) + if tt.token != "" { + req.Header.Set("Authorization", tt.token) + } + resp := httptest.NewRecorder() + + r.ServeHTTP(resp, req) + + if resp.Code != tt.expectedStatus { + t.Errorf("Expected status %d; got %d", tt.expectedStatus, resp.Code) + } + }) + } +} diff --git a/internal/api/middleware/cors.go b/internal/api/middleware/cors.go index 5f0c77f..a41d625 100644 --- a/internal/api/middleware/cors.go +++ b/internal/api/middleware/cors.go @@ -1,27 +1,33 @@ package middleware import ( + "net/http" "os" "strings" - "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" ) func CORSMiddleware() gin.HandlerFunc { - config := cors.DefaultConfig() + return func(c *gin.Context) { + origin := c.Request.Header.Get("Origin") + allowedOrigins := os.Getenv("ALLOWED_ORIGINS") - allowedOrigins := os.Getenv("ALLOWED_ORIGINS") - if allowedOrigins != "" { - config.AllowOrigins = strings.Split(allowedOrigins, ",") - } else { - config.AllowOriginFunc = func(origin string) bool { - return strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "https://localhost") + if origin != "" && (allowedOrigins == "*" || strings.Contains(allowedOrigins, origin)) { + c.Header("Access-Control-Allow-Origin", origin) + } else { + c.Header("Access-Control-Allow-Origin", "*") } - } - config.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"} - config.AllowHeaders = []string{"Origin", "Content-Type", "Accept", "Authorization"} + c.Header("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS") + c.Header("Access-Control-Allow-Headers", "Authorization,Content-Type") + c.Header("Access-Control-Allow-Credentials", "true") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(http.StatusOK) + return + } - return cors.New(config) + c.Next() + } } diff --git a/internal/api/middleware/cors_middleware_test.go b/internal/api/middleware/cors_middleware_test.go new file mode 100644 index 0000000..ea39d23 --- /dev/null +++ b/internal/api/middleware/cors_middleware_test.go @@ -0,0 +1,77 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestCORSMiddlewareFunc(t *testing.T) { + // Set Gin to Test Mode + gin.SetMode(gin.TestMode) + + // Create a new Gin engine + r := gin.New() + + // Use the CORSMiddleware + r.Use(CORSMiddleware()) + + // Define a test route + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "test") + }) + + tests := []struct { + name string + method string + origin string + expectedStatus int + expectedHeaders map[string]string + }{ + { + name: "OPTIONS request", + method: "OPTIONS", + origin: "http://example.com", + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{ + "Access-Control-Allow-Origin": "http://example.com", + "Access-Control-Allow-Methods": "GET,POST,PUT,PATCH,DELETE,OPTIONS", + "Access-Control-Allow-Headers": "Authorization,Content-Type", + "Access-Control-Allow-Credentials": "true", + }, + }, + { + name: "GET request", + method: "GET", + origin: "http://example.com", + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{ + "Access-Control-Allow-Origin": "http://example.com", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv("ALLOWED_ORIGINS", tt.origin) + req, _ := http.NewRequest(tt.method, "/test", nil) + req.Header.Set("Origin", tt.origin) + resp := httptest.NewRecorder() + + r.ServeHTTP(resp, req) + + if resp.Code != tt.expectedStatus { + t.Errorf("Expected status %d; got %d", tt.expectedStatus, resp.Code) + } + + for key, value := range tt.expectedHeaders { + if resp.Header().Get(key) != value { + t.Errorf("Expected header %s to be %s; got %s", key, value, resp.Header().Get(key)) + } + } + }) + } +} diff --git a/internal/api/middleware/logging_middleware_test.go b/internal/api/middleware/logging_middleware_test.go new file mode 100644 index 0000000..24603b9 --- /dev/null +++ b/internal/api/middleware/logging_middleware_test.go @@ -0,0 +1,46 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + customLogger "github.com/nicobistolfi/go-rest-api/pkg" +) + +func TestLoggerMiddlewareFunc(t *testing.T) { + customLogger.Init() + // Set Gin to Test Mode + gin.SetMode(gin.TestMode) + + // Create a new Gin engine + r := gin.New() + + // Create a custom logger + logger := customLogger.Log + + // Use the LoggerMiddleware + r.Use(LoggerMiddleware(logger)) + + // Define a test route + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "test") + }) + + // Create a test request + req, _ := http.NewRequest("GET", "/test", nil) + resp := httptest.NewRecorder() + + // Serve the request + r.ServeHTTP(resp, req) + + // Check the status code + if resp.Code != http.StatusOK { + t.Errorf("Expected status %d; got %d", http.StatusOK, resp.Code) + } + + // Add more assertions here to check logging behavior + // For example, you could use a custom io.Writer to capture log output + // and assert on its contents +} diff --git a/internal/api/middleware/rate_limiter.go b/internal/api/middleware/rate_limiter.go index 1ac6494..ed58639 100644 --- a/internal/api/middleware/rate_limiter.go +++ b/internal/api/middleware/rate_limiter.go @@ -9,7 +9,7 @@ import ( "golang.org/x/time/rate" ) -func RateLimiter(r rate.Limit, b int) gin.HandlerFunc { +func RateLimiter(r rate.Limit, b int, keyPrefixes ...string) gin.HandlerFunc { type client struct { limiter *rate.Limiter lastSeen int64 //lint:ignore U1000 This field is currently unused but may be used in future implementations @@ -23,10 +23,25 @@ func RateLimiter(r rate.Limit, b int) gin.HandlerFunc { return func(c *gin.Context) { // Use only IP if Authorization header is empty key := c.ClientIP() + // Check if c.ClientIP() is empty + if key == "" { + // get the first X-Real-Ip header + key = c.GetHeader("X-Real-Ip") + } + if key == "" { + // get the first X-Forwarded-For header + key = c.GetHeader("X-Forwarded-For") + } + if auth := c.GetHeader("Authorization"); auth != "" { key += ":" + auth } + // Add the optional keyPrefix to the key if provided + if len(keyPrefixes) > 0 && keyPrefixes[0] != "" { + key = keyPrefixes[0] + ":" + key + } + mu.Lock() if _, found := clients[key]; !found { clients[key] = &client{limiter: rate.NewLimiter(r, b)} diff --git a/internal/api/middleware/rate_limiter_test.go b/internal/api/middleware/rate_limiter_test.go new file mode 100644 index 0000000..cfc32b8 --- /dev/null +++ b/internal/api/middleware/rate_limiter_test.go @@ -0,0 +1,166 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +func TestRateLimiter(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + limit rate.Limit + burst int + requests int + expectedStatus int + }{ + {"Under limit", rate.Limit(10), 5, 5, http.StatusOK}, + {"At limit", rate.Limit(10), 10, 10, http.StatusOK}, + {"Over limit", rate.Limit(10), 10, 11, http.StatusTooManyRequests}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := gin.New() + r.Use(RateLimiter(tt.limit, tt.burst, tt.name)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "test") + }) + + for i := 0; i < tt.requests; i++ { + req, _ := http.NewRequest("GET", "/test", nil) + resp := httptest.NewRecorder() + r.ServeHTTP(resp, req) + + t.Logf("Request %d: Status Code %d", i+1, resp.Code) + + if i == tt.requests-1 { + if resp.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.Code) + } + } else if resp.Code != http.StatusOK { + t.Errorf("Request %d: Expected status %d, got %d", i+1, http.StatusOK, resp.Code) + } + } + }) + } +} + +func TestRateLimiterWithDifferentClients(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RateLimiter(rate.Limit(1), 1)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "test") + }) + + clients := []string{"1.1.1.1", "2.2.2.2"} + + for _, client := range clients { + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", client) + resp := httptest.NewRecorder() + r.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Errorf("Expected status %d for client %s, got %d", http.StatusOK, client, resp.Code) + } + } +} + +func TestRateLimiterWithAuthHeader(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RateLimiter(rate.Limit(1), 1)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "test") + }) + + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer token1") + resp := httptest.NewRecorder() + r.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, resp.Code) + } + + // Same IP, different token + req, _ = http.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer token2") + resp = httptest.NewRecorder() + r.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, resp.Code) + } +} + +func TestRateLimiterBurst(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RateLimiter(rate.Limit(1), 3)) // 1 request per second, burst of 3 + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "test") + }) + + for i := 0; i < 4; i++ { + req, _ := http.NewRequest("GET", "/test", nil) + resp := httptest.NewRecorder() + r.ServeHTTP(resp, req) + + if i < 3 { + if resp.Code != http.StatusOK { + t.Errorf("Expected status %d for request %d, got %d", http.StatusOK, i+1, resp.Code) + } + } else { + if resp.Code != http.StatusTooManyRequests { + t.Errorf("Expected status %d for request %d, got %d", http.StatusTooManyRequests, i+1, resp.Code) + } + } + } +} + +func TestRateLimiterRecovery(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RateLimiter(rate.Limit(1), 1)) // 1 request per second + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "test") + }) + + // First request should succeed + req, _ := http.NewRequest("GET", "/test", nil) + resp := httptest.NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, resp.Code) + } + + // Second request should fail + resp = httptest.NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != http.StatusTooManyRequests { + t.Errorf("Expected status %d, got %d", http.StatusTooManyRequests, resp.Code) + } + + // Wait for rate limit to reset + time.Sleep(1100 * time.Millisecond) + + // Third request should succeed + resp = httptest.NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, resp.Code) + } +} diff --git a/internal/api/middleware/token.go b/internal/api/middleware/token.go index 44e6f94..ab200d9 100644 --- a/internal/api/middleware/token.go +++ b/internal/api/middleware/token.go @@ -50,8 +50,19 @@ var ( cacheExpiry time.Duration ) -func VerifyToken() gin.HandlerFunc { +func VerifyToken(customCacheExpiry ...string) gin.HandlerFunc { return func(c *gin.Context) { + verifyCacheExpiry := cacheExpiry + if len(customCacheExpiry) > 0 { + verifyCacheExpiryParsed, err := time.ParseDuration(customCacheExpiry[0]) + if err != nil { + logger.Warn("Invalid cache expiry, using default", zap.Error(err)) + } else { + fmt.Printf("Using custom cache expiry: %v\n", verifyCacheExpiryParsed) + verifyCacheExpiry = verifyCacheExpiryParsed + fmt.Printf("verifyCacheExpiry: %v\n", verifyCacheExpiry) + } + } // Get the token from the context set by AuthMiddleware token, exists := c.Get("auth_token") authHeader, authHeaderExists := c.Get("auth_header") @@ -75,6 +86,7 @@ func VerifyToken() gin.HandlerFunc { } tokenURL := os.Getenv("TOKEN_URL") + fmt.Printf("tokenURL: %v\n", tokenURL) if tokenURL == "" { c.JSON(http.StatusInternalServerError, gin.H{"error": "TOKEN_URL not set"}) @@ -87,17 +99,21 @@ func VerifyToken() gin.HandlerFunc { entry, found := tokenCache[tokenString] cacheMutex.RUnlock() - if found && time.Now().Before(entry.expiry) { - // add a header to the response to indicate that the token is cached + valid := time.Now().Before(entry.expiry) + fmt.Printf("Valid: %v\n", valid) + + if found && valid { + fmt.Println("Token is still valid in cache") + // Token is still valid in cache c.Header("X-Token-Cache", "HIT") c.Set("user", entry.profile) c.Next() return - } else { - // add a header to the response to indicate that the token is not cached - c.Header("X-Token-Cache", "MISS") } + // Token not found in cache or expired + c.Header("X-Token-Cache", "MISS") + // Validate token using the TOKEN_URL req, err := http.NewRequest("GET", tokenURL, nil) if err != nil { @@ -160,7 +176,7 @@ func VerifyToken() gin.HandlerFunc { cacheMutex.Lock() tokenCache[tokenString] = cacheEntry{ profile: profile, - expiry: time.Now().Add(cacheExpiry), + expiry: time.Now().Add(verifyCacheExpiry), } cacheMutex.Unlock() diff --git a/internal/api/middleware/token_test.go b/internal/api/middleware/token_test.go new file mode 100644 index 0000000..b2f394d --- /dev/null +++ b/internal/api/middleware/token_test.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestVerifyToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Setup a mock token server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") == "Bearer valid_token" { + profile := Profile{ + ID: "123", + Email: "test@example.com", + Name: "Test User", + } + json.NewEncoder(w).Encode(profile) + } else { + w.WriteHeader(http.StatusUnauthorized) + } + })) + defer mockServer.Close() + + os.Setenv("TOKEN_URL", mockServer.URL) + defer os.Unsetenv("TOKEN_URL") + + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("auth_token", c.GetHeader("Authorization")) + c.Set("auth_header", "Authorization") + c.Next() + }) + r.Use(VerifyToken()) + r.GET("/test", func(c *gin.Context) { + user, _ := c.Get("user") + c.JSON(http.StatusOK, user) + }) + + tests := []struct { + name string + token string + expectedStatus int + }{ + {"Valid token", "Bearer valid_token", http.StatusOK}, + {"Invalid token", "Bearer invalid_token", http.StatusUnauthorized}, + {"Missing token", "", http.StatusUnauthorized}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "/test", nil) + if tt.token != "" { + req.Header.Set("Authorization", tt.token) + } + resp := httptest.NewRecorder() + r.ServeHTTP(resp, req) + + if resp.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.Code) + } + + if tt.expectedStatus == http.StatusOK { + var profile Profile + json.NewDecoder(resp.Body).Decode(&profile) + if profile.ID != "123" || profile.Email != "test@example.com" || profile.Name != "Test User" { + t.Errorf("Unexpected profile data") + } + } + }) + } +}