Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SmartRedirectSlashes to http/middleware #3393

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions http/middleware/chi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package middleware

import (
"fmt"
"net/http"

"github.com/go-chi/chi/v5"
)

// SmartRedirectSlashes is a middleware that matches the request path with
// patterns added to the router and redirects it.
//
// If a pattern is added to the router with a trailing slash, any matches on
// that pattern without a trailing slash will be redirected to the version with
// the slash. If a pattern does not have a trailing slash, matches on that
// pattern with a trailing slash will be redirected to the version without.
//
// This middleware depends on chi, so it needs to be mounted on chi's router.
// It make the router behavior similar to httptreemux.
func SmartRedirectSlashes(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
rctx := chi.RouteContext(r.Context())
if rctx != nil {
var path string
if rctx.RoutePath != "" {
path = rctx.RoutePath
} else {
path = r.URL.Path
}
var method string
if rctx.RouteMethod != "" {
method = rctx.RouteMethod
} else {
method = r.Method
}
if len(path) > 1 {
if rctx.Routes != nil {
if !rctx.Routes.Match(chi.NewRouteContext(), method, path) {
if path[len(path)-1] == '/' {
path = path[:len(path)-1]
} else {
path += "/"
}
if rctx.Routes.Match(chi.NewRouteContext(), method, path) {
if r.URL.RawQuery != "" {
path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery)
}
redirectURL := fmt.Sprintf("//%s%s", r.Host, path)
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
return
}
}
}
}
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
58 changes: 58 additions & 0 deletions http/middleware/chi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
goahttp "goa.design/goa/v3/http"
)

func TestSmartRedirectSlashes(t *testing.T) {
cases := []struct {
Pattern string
URL string
Status int
Location string
}{
{"/users", "/users", http.StatusOK, ""},
{"/users", "/users/", http.StatusMovedPermanently, "/users"},
{"/users/", "/users/", http.StatusOK, ""},
{"/users/", "/users", http.StatusMovedPermanently, "/users/"},
{"/users/{id}", "/users/123", http.StatusOK, ""},
{"/users/{id}", "/users/123/", http.StatusMovedPermanently, "/users/123"},
{"/users/{id}/", "/users/123/", http.StatusOK, ""},
{"/users/{id}/", "/users/123", http.StatusMovedPermanently, "/users/123/"},
{"/users/{id}/posts/{post_id}", "/users/123/posts/456", http.StatusOK, ""},
{"/users/{id}/posts/{post_id}", "/users/123/posts/456/", http.StatusMovedPermanently, "/users/123/posts/456"},
{"/users/{id}/posts/{post_id}/", "/users/123/posts/456/", http.StatusOK, ""},
{"/users/{id}/posts/{post_id}/", "/users/123/posts/456", http.StatusMovedPermanently, "/users/123/posts/456/"},
{"/users/{id}/posts/{*post_id}", "/users/123/posts/456/789", http.StatusOK, ""},
{"/users/{id}/posts/{*post_id}", "/users/123/posts/456/789/", http.StatusOK, ""},
{"/users", "/users?name=foo", http.StatusOK, ""},
{"/users", "/users/?name=foo", http.StatusMovedPermanently, "/users?name=foo"},
{"/users/", "/users/?name=foo", http.StatusOK, ""},
{"/users/", "/users?name=foo", http.StatusMovedPermanently, "/users/?name=foo"},
}

for _, c := range cases {
t.Run(c.Pattern, func(t *testing.T) {
var called bool
mux := goahttp.NewMuxer()
mux.Use(SmartRedirectSlashes)
handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
called = true
})
mux.Handle("GET", c.Pattern, handler)
req, _ := http.NewRequest("GET", c.URL, nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.Equal(t, c.Status, w.Code)
assert.Equal(t, w.Code == http.StatusOK, called)
if w.Code == http.StatusMovedPermanently {
assert.Equal(t, c.Location, w.Header().Get("Location"))
}
})
}
}