Skip to content

Commit

Permalink
chore(allsrv): replace http.DefaultServeMux with isolated `*http.Se…
Browse files Browse the repository at this point in the history
…rveMux` dependency

This resolves the panic adding routes with the same pattern multiple
times. Now each `Server`, has its own `*http.ServeMux`. Now tests run independent
of one another and we avoid the pain of GLOBALS!

The tests should now pass :-)
  • Loading branch information
jsteenb2 committed Jul 5, 2024
1 parent efa38f4 commit 526c39c
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions allsrv/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"errors"
"log"
"net/http"

"github.com/gofrs/uuid"
)

Expand All @@ -18,7 +18,7 @@ import (
a) what happens if we forget that copy pasta?
3) auth is hardcoded to basic auth
a) what happens if we want to adapt some other means of auth?
4) router being used is the GLOBAL http.DefaultServeMux
4) router being used is the GLOBAL http.DefaultServeMux
a) should avoid globals
b) what happens if you have multiple servers in this go module who reference default serve mux?
5) no tests
Expand All @@ -45,10 +45,11 @@ import (
*/

type Server struct {
db *InmemDB // 1)

db *InmemDB // 1)
mux *http.ServeMux // 4)

user, pass string // 3)

idFn func() string // 11)
}

Expand All @@ -62,6 +63,7 @@ func WithIDFn(fn func() string) func(*Server) {
func NewServer(db *InmemDB, user, pass string, opts ...func(*Server)) *Server {
s := Server{
db: db,
mux: http.NewServeMux(), // 4)
user: user,
pass: pass,
idFn: func() string {
Expand All @@ -72,22 +74,22 @@ func NewServer(db *InmemDB, user, pass string, opts ...func(*Server)) *Server {
for _, o := range opts {
o(&s)
}

s.routes()
return &s
}

func (s *Server) routes() {
// 4) 7) 9) 10)
http.Handle("POST /foo", http.HandlerFunc(s.createFoo))
http.Handle("GET /foo", http.HandlerFunc(s.readFoo))
http.Handle("PUT /foo", http.HandlerFunc(s.updateFoo))
http.Handle("DELETE /foo", http.HandlerFunc(s.delFoo))
s.mux.Handle("POST /foo", http.HandlerFunc(s.createFoo))
s.mux.Handle("GET /foo", http.HandlerFunc(s.readFoo))
s.mux.Handle("PUT /foo", http.HandlerFunc(s.updateFoo))
s.mux.Handle("DELETE /foo", http.HandlerFunc(s.delFoo))
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 4)
http.DefaultServeMux.ServeHTTP(w, r)
s.mux.ServeHTTP(w, r)
}

type Foo struct {
Expand All @@ -103,20 +105,20 @@ func (s *Server) createFoo(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) // 9)
return
}

var f Foo
if err := json.NewDecoder(r.Body).Decode(&f); err != nil {
w.WriteHeader(http.StatusForbidden) // 9)
return
}

f.ID = s.idFn() // 11)

if err := s.db.CreateFoo(f); err != nil {
w.WriteHeader(http.StatusInternalServerError) // 9)
return
}

w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(f); err != nil {
log.Printf("unexpected error writing json value to response body: " + err.Error()) // 8) 10)
Expand All @@ -129,13 +131,13 @@ func (s *Server) readFoo(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) // 9)
return
}

f, err := s.db.readFoo(r.URL.Query().Get("id"))
if err != nil {
w.WriteHeader(http.StatusNotFound) // 9)
return
}

if err := json.NewEncoder(w).Encode(f); err != nil {
log.Printf("unexpected error writing json value to response body: " + err.Error()) // 8) 10)
}
Expand All @@ -147,13 +149,13 @@ func (s *Server) updateFoo(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) // 9)
return
}

var f Foo
if err := json.NewDecoder(r.Body).Decode(&f); err != nil {
w.WriteHeader(http.StatusForbidden) // 9)
return
}

if err := s.db.updateFoo(f); err != nil {
w.WriteHeader(http.StatusInternalServerError) // 9)
return
Expand All @@ -166,7 +168,7 @@ func (s *Server) delFoo(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) // 9)
return
}

if err := s.db.delFoo(r.URL.Query().Get("id")); err != nil {
w.WriteHeader(http.StatusNotFound) // 9)
return
Expand All @@ -184,9 +186,9 @@ func (db *InmemDB) CreateFoo(f Foo) error {
return errors.New("foo " + f.Name + " exists") // 8)
}
}

db.m = append(db.m, f)

return nil
}

Expand Down

0 comments on commit 526c39c

Please sign in to comment.