Skip to content

Commit

Permalink
codecove middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
nicobistolfi committed Sep 22, 2024
1 parent 9d15736 commit 4dfebe4
Show file tree
Hide file tree
Showing 8 changed files with 538 additions and 20 deletions.
51 changes: 51 additions & 0 deletions internal/api/middleware/auth_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/gin-gonic/gin"
)

func TestAuthMiddlewareFunc(t *testing.T) {
// Set Gin to Test Mode
gin.SetMode(gin.TestMode)

// Create a new Gin engine
r := gin.New()

// Use the AuthMiddleware
protected := r.Group("/")
protected.Use(AuthMiddleware())
// Define a test route
protected.GET("/protected", func(c *gin.Context) {
c.String(http.StatusOK, "protected")
})

tests := []struct {
name string
token string
expectedStatus int
}{
{"Valid Token", "Bearer valid_token", http.StatusOK},
{"Invalid Token", "Bearer invalid_token", http.StatusOK}, // This is expected to pass, as the auth only checks for the token presence
{"No Token", "", http.StatusUnauthorized},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "/protected", nil)
if tt.token != "" {
req.Header.Set("Authorization", tt.token)
}
resp := httptest.NewRecorder()

r.ServeHTTP(resp, req)

if resp.Code != tt.expectedStatus {
t.Errorf("Expected status %d; got %d", tt.expectedStatus, resp.Code)
}
})
}
}
30 changes: 18 additions & 12 deletions internal/api/middleware/cors.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
package middleware

import (
"net/http"
"os"
"strings"

"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)

func CORSMiddleware() gin.HandlerFunc {
config := cors.DefaultConfig()
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")

allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
if allowedOrigins != "" {
config.AllowOrigins = strings.Split(allowedOrigins, ",")
} else {
config.AllowOriginFunc = func(origin string) bool {
return strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "https://localhost")
if origin != "" && (allowedOrigins == "*" || strings.Contains(allowedOrigins, origin)) {
c.Header("Access-Control-Allow-Origin", origin)
} else {
c.Header("Access-Control-Allow-Origin", "*")
}
}

config.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}
config.AllowHeaders = []string{"Origin", "Content-Type", "Accept", "Authorization"}
c.Header("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
c.Header("Access-Control-Allow-Headers", "Authorization,Content-Type")
c.Header("Access-Control-Allow-Credentials", "true")

if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusOK)
return
}

return cors.New(config)
c.Next()
}
}
77 changes: 77 additions & 0 deletions internal/api/middleware/cors_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package middleware

import (
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/gin-gonic/gin"
)

func TestCORSMiddlewareFunc(t *testing.T) {
// Set Gin to Test Mode
gin.SetMode(gin.TestMode)

// Create a new Gin engine
r := gin.New()

// Use the CORSMiddleware
r.Use(CORSMiddleware())

// Define a test route
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "test")
})

tests := []struct {
name string
method string
origin string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS request",
method: "OPTIONS",
origin: "http://example.com",
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "http://example.com",
"Access-Control-Allow-Methods": "GET,POST,PUT,PATCH,DELETE,OPTIONS",
"Access-Control-Allow-Headers": "Authorization,Content-Type",
"Access-Control-Allow-Credentials": "true",
},
},
{
name: "GET request",
method: "GET",
origin: "http://example.com",
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "http://example.com",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv("ALLOWED_ORIGINS", tt.origin)
req, _ := http.NewRequest(tt.method, "/test", nil)
req.Header.Set("Origin", tt.origin)
resp := httptest.NewRecorder()

r.ServeHTTP(resp, req)

if resp.Code != tt.expectedStatus {
t.Errorf("Expected status %d; got %d", tt.expectedStatus, resp.Code)
}

for key, value := range tt.expectedHeaders {
if resp.Header().Get(key) != value {
t.Errorf("Expected header %s to be %s; got %s", key, value, resp.Header().Get(key))
}
}
})
}
}
46 changes: 46 additions & 0 deletions internal/api/middleware/logging_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/gin-gonic/gin"
customLogger "github.com/nicobistolfi/go-rest-api/pkg"
)

func TestLoggerMiddlewareFunc(t *testing.T) {
customLogger.Init()
// Set Gin to Test Mode
gin.SetMode(gin.TestMode)

// Create a new Gin engine
r := gin.New()

// Create a custom logger
logger := customLogger.Log

// Use the LoggerMiddleware
r.Use(LoggerMiddleware(logger))

// Define a test route
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "test")
})

// Create a test request
req, _ := http.NewRequest("GET", "/test", nil)
resp := httptest.NewRecorder()

// Serve the request
r.ServeHTTP(resp, req)

// Check the status code
if resp.Code != http.StatusOK {
t.Errorf("Expected status %d; got %d", http.StatusOK, resp.Code)
}

// Add more assertions here to check logging behavior
// For example, you could use a custom io.Writer to capture log output
// and assert on its contents
}
19 changes: 18 additions & 1 deletion internal/api/middleware/rate_limiter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"fmt"
"net/http"
"sync"

Expand All @@ -9,7 +10,7 @@ import (
"golang.org/x/time/rate"
)

func RateLimiter(r rate.Limit, b int) gin.HandlerFunc {
func RateLimiter(r rate.Limit, b int, keyPrefixes ...string) gin.HandlerFunc {
type client struct {
limiter *rate.Limiter
lastSeen int64 //lint:ignore U1000 This field is currently unused but may be used in future implementations
Expand All @@ -23,10 +24,26 @@ func RateLimiter(r rate.Limit, b int) gin.HandlerFunc {
return func(c *gin.Context) {
// Use only IP if Authorization header is empty
key := c.ClientIP()
// Check if c.ClientIP() is empty
if key == "" {
// get the first X-Real-Ip header
key = c.GetHeader("X-Real-Ip")
}
if key == "" {
// get the first X-Forwarded-For header
key = c.GetHeader("X-Forwarded-For")
}

fmt.Println("key", key)
if auth := c.GetHeader("Authorization"); auth != "" {
key += ":" + auth
}

// Add the optional keyPrefix to the key if provided
if len(keyPrefixes) > 0 && keyPrefixes[0] != "" {
key = keyPrefixes[0] + ":" + key
}

mu.Lock()
if _, found := clients[key]; !found {
clients[key] = &client{limiter: rate.NewLimiter(r, b)}
Expand Down
Loading

0 comments on commit 4dfebe4

Please sign in to comment.