diff --git a/http/middleware/chi.go b/http/middleware/chi.go new file mode 100644 index 0000000000..b9387b40e2 --- /dev/null +++ b/http/middleware/chi.go @@ -0,0 +1,59 @@ +package middleware + +import ( + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" +) + +// SmartRedirectSlashes is a middleware that matches the request path with +// patterns added to the router and redirects it. +// +// If a pattern is added to the router with a trailing slash, any matches on +// that pattern without a trailing slash will be redirected to the version with +// the slash. If a pattern does not have a trailing slash, matches on that +// pattern with a trailing slash will be redirected to the version without. +// +// This middleware depends on chi, so it needs to be mounted on chi's router. +// It make the router behavior similar to httptreemux. +func SmartRedirectSlashes(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + rctx := chi.RouteContext(r.Context()) + if rctx != nil { + var path string + if rctx.RoutePath != "" { + path = rctx.RoutePath + } else { + path = r.URL.Path + } + var method string + if rctx.RouteMethod != "" { + method = rctx.RouteMethod + } else { + method = r.Method + } + if len(path) > 1 { + if rctx.Routes != nil { + if !rctx.Routes.Match(chi.NewRouteContext(), method, path) { + if path[len(path)-1] == '/' { + path = path[:len(path)-1] + } else { + path += "/" + } + if rctx.Routes.Match(chi.NewRouteContext(), method, path) { + if r.URL.RawQuery != "" { + path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery) + } + redirectURL := fmt.Sprintf("//%s%s", r.Host, path) + http.Redirect(w, r, redirectURL, http.StatusMovedPermanently) + return + } + } + } + } + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} diff --git a/http/middleware/chi_test.go b/http/middleware/chi_test.go new file mode 100644 index 0000000000..ac502ac9fc --- /dev/null +++ b/http/middleware/chi_test.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + goahttp "goa.design/goa/v3/http" +) + +func TestSmartRedirectSlashes(t *testing.T) { + cases := []struct { + Pattern string + URL string + Status int + Location string + }{ + {"/users", "/users", http.StatusOK, ""}, + {"/users", "/users/", http.StatusMovedPermanently, "/users"}, + {"/users/", "/users/", http.StatusOK, ""}, + {"/users/", "/users", http.StatusMovedPermanently, "/users/"}, + {"/users/{id}", "/users/123", http.StatusOK, ""}, + {"/users/{id}", "/users/123/", http.StatusMovedPermanently, "/users/123"}, + {"/users/{id}/", "/users/123/", http.StatusOK, ""}, + {"/users/{id}/", "/users/123", http.StatusMovedPermanently, "/users/123/"}, + {"/users/{id}/posts/{post_id}", "/users/123/posts/456", http.StatusOK, ""}, + {"/users/{id}/posts/{post_id}", "/users/123/posts/456/", http.StatusMovedPermanently, "/users/123/posts/456"}, + {"/users/{id}/posts/{post_id}/", "/users/123/posts/456/", http.StatusOK, ""}, + {"/users/{id}/posts/{post_id}/", "/users/123/posts/456", http.StatusMovedPermanently, "/users/123/posts/456/"}, + {"/users/{id}/posts/{*post_id}", "/users/123/posts/456/789", http.StatusOK, ""}, + {"/users/{id}/posts/{*post_id}", "/users/123/posts/456/789/", http.StatusOK, ""}, + {"/users", "/users?name=foo", http.StatusOK, ""}, + {"/users", "/users/?name=foo", http.StatusMovedPermanently, "/users?name=foo"}, + {"/users/", "/users/?name=foo", http.StatusOK, ""}, + {"/users/", "/users?name=foo", http.StatusMovedPermanently, "/users/?name=foo"}, + } + + for _, c := range cases { + t.Run(c.Pattern, func(t *testing.T) { + var called bool + mux := goahttp.NewMuxer() + mux.Use(SmartRedirectSlashes) + handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + called = true + }) + mux.Handle("GET", c.Pattern, handler) + req, _ := http.NewRequest("GET", c.URL, nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.Equal(t, c.Status, w.Code) + assert.Equal(t, w.Code == http.StatusOK, called) + if w.Code == http.StatusMovedPermanently { + assert.Equal(t, c.Location, w.Header().Get("Location")) + } + }) + } +}