From 2bb059089efc0d0d8f70b800225fa3c381545d97 Mon Sep 17 00:00:00 2001 From: Johnny Steenbergen Date: Wed, 17 Jan 2024 22:16:19 -0600 Subject: [PATCH] chore(allsrv): add v2 Server create API Here we're leaning into a more structured API. This utilizes versioned URLs so that our endpoints can evolve in a meaningful way. We also make use of the [JSON-API spec](https://jsonapi.org/). This provides a common structure for our consumers. JSON-API is more opinionated than other API specs, but it has client libs that are widely available across most languages. We've chosen not to implement the entire spec, but enough to show off the core benefits of using a spec (not limited to JSON-API spec). The JSON-API spec is VERY structured (for better or worse), and would make this a [level 3 RMM](https://www.crummy.com/writing/speaking/2008-QCon/) compliant service, when the links/relationships are included. That can be incredibly powerful. As maintainers/developers we get the following from using an API Spec: * Standardized API shape, provides for strong abstractions * With a spec/standardization you can now remove the boilerplate altogether and potentially generate :allthetransportthings: with simple tooling * We eliminate some bike-shedding about API design. Kind of like `gofmt`, the API is no one's favorite, yet the API is everyone's favorite Consumers benefit in the following ways: * A surprised consumer is an unhappy consumer, following a Spec (even a bad one), helps inform consumers and becomes simpler over time to reason about. * Consumers may not require any SDK/client lib and can traverse the API on their own. This is part of the salespitch for RMM lvl 3/JSON-API, though I'm not in 100% agreement that is worth the effort. We've introduced a naive URI versioning scheme. There are a lot of ways to slice this bread. The simplest is arguably the URI versioning scheme, which is why we're using it here. However, there are a number of other options available as well. Versioning is a tough pill to swallow for most orgs. There are many strategies, and every strategy has 1000x opinions about why THIS IS THE WAY. Explore the links below yourself, determine what's important to your organization and go from there. Take note, there are many conflicting opinions in the resources above :hidethepain:. Another thing to take note of here is our use of middleware has increased to include some additional checks. In this case we have some additional checks, that all return the same response (via the API spec), and creates a one stop shop for these orthogonal concerns. For flavor, we've made use of generics to adhere to not only the JSON-API spec, but also the reduce the boilerplate in dealing with handlers. We'll expand on this in a bit. Next we'll take a look at making our tests more flexible so that we can extend our testcases without having to duplicate the entire test. Refs: [Intro to Versioning a Rest API](https://www.freecodecamp.org/news/how-to-version-a-rest-api/) Refs: [Versioning Rest Web API Best Practices - MSFT](https://learn.microsoft.com/en-us/azure/architecture/best-practices/api-design#versioning-a-restful-web-api) Refs: [API Design Cheat Sheet](https://github.com/RestCheatSheet/api-cheat-sheet#api-design-cheat-sheet) --- allsrv/errors.go | 9 +- allsrv/server.go | 52 ++++-- allsrv/server_v2.go | 381 +++++++++++++++++++++++++++++++++++++++ allsrv/server_v2_test.go | 128 +++++++++++++ 4 files changed, 550 insertions(+), 20 deletions(-) create mode 100644 allsrv/server_v2.go create mode 100644 allsrv/server_v2_test.go diff --git a/allsrv/errors.go b/allsrv/errors.go index 9d75e35..d303578 100644 --- a/allsrv/errors.go +++ b/allsrv/errors.go @@ -1,15 +1,18 @@ package allsrv const ( - errTypeExists = "exists" - errTypeNotFound = "not found" + errTypeUnknown = iota + errTypeExists + errTypeInvalid + errTypeNotFound + errTypeUnAuthed ) // Err provides a lightly structured error that we can attach behavior. Additionally, // the use of fields makes it possible for us to enrich our logging infra without // blowing up the message cardinality. type Err struct { - Type string + Type int Msg string Fields []any } diff --git a/allsrv/server.go b/allsrv/server.go index 7e2fd8d..aebd3f0 100644 --- a/allsrv/server.go +++ b/allsrv/server.go @@ -5,8 +5,10 @@ import ( "encoding/json" "log" "net/http" + "time" "github.com/gofrs/uuid" + "github.com/hashicorp/go-metrics" ) /* @@ -59,33 +61,40 @@ type ( } ) -type Server struct { - db DB // 1) - mux *http.ServeMux // 4) +type serverOpts struct { + authFn func(http.Handler) http.Handler + idFn func() string + nowFn func() time.Time - authFn func(http.Handler) http.Handler // 3) - idFn func() string // 11) + met *metrics.Metrics + mux *http.ServeMux } // WithBasicAuth sets the authorization fn for the server to basic auth. // 3) -func WithBasicAuth(user, pass string) func(*Server) { - return func(s *Server) { +func WithBasicAuth(user, pass string) func(*serverOpts) { + return func(s *serverOpts) { s.authFn = basicAuth(user, pass) } } // WithIDFn sets the id generation fn for the server. -func WithIDFn(fn func() string) func(*Server) { - return func(s *Server) { +func WithIDFn(fn func() string) func(*serverOpts) { + return func(s *serverOpts) { s.idFn = fn } } -func NewServer(db DB, opts ...func(*Server)) *Server { - s := Server{ - db: db, - mux: http.NewServeMux(), // 4) +type Server struct { + db DB // 1) + mux *http.ServeMux // 4) + + authFn func(http.Handler) http.Handler // 3) + idFn func() string // 11) +} + +func NewServer(db DB, opts ...func(*serverOpts)) *Server { + opt := serverOpts{ authFn: func(next http.Handler) http.Handler { // 3) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // defaults to no auth @@ -96,9 +105,17 @@ func NewServer(db DB, opts ...func(*Server)) *Server { // defaults to using a uuid return uuid.Must(uuid.NewV4()).String() }, + mux: http.NewServeMux(), } for _, o := range opts { - o(&s) + o(&opt) + } + + s := Server{ + db: db, + mux: opt.mux, // 4) + authFn: opt.authFn, + idFn: opt.idFn, } s.routes() @@ -122,9 +139,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { type Foo struct { // 6) - ID string `json:"id" gorm:"id"` - Name string `json:"name" gorm:"name"` - Note string `json:"note" gorm:"note"` + ID string `json:"id" gorm:"id"` + Name string `json:"name" gorm:"name"` + Note string `json:"note" gorm:"note"` + CreatedAt time.Time `json:"-" gorm:"created_at"` } func (s *Server) createFoo(w http.ResponseWriter, r *http.Request) { diff --git a/allsrv/server_v2.go b/allsrv/server_v2.go new file mode 100644 index 0000000..68177b3 --- /dev/null +++ b/allsrv/server_v2.go @@ -0,0 +1,381 @@ +package allsrv + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/gofrs/uuid" + "github.com/hashicorp/go-metrics" +) + +func WithMetrics(mets *metrics.Metrics) func(*serverOpts) { + return func(o *serverOpts) { + o.met = mets + } +} + +func WithMux(mux *http.ServeMux) func(*serverOpts) { + return func(o *serverOpts) { + o.mux = mux + } +} + +func WithNowFn(fn func() time.Time) func(*serverOpts) { + return func(o *serverOpts) { + o.nowFn = fn + } +} + +type ServerV2 struct { + db DB // 1) + + mux *http.ServeMux + mw func(next http.Handler) http.Handler + idFn func() string // 11) + nowFn func() time.Time +} + +func NewServerV2(db DB, opts ...func(*serverOpts)) *ServerV2 { + opt := serverOpts{ + mux: http.NewServeMux(), + idFn: func() string { return uuid.Must(uuid.NewV4()).String() }, + nowFn: func() time.Time { return time.Now().UTC() }, + } + for _, o := range opts { + o(&opt) + } + + s := ServerV2{ + db: db, + mux: opt.mux, + idFn: opt.idFn, + nowFn: opt.nowFn, + } + + var mw []func(http.Handler) http.Handler + if opt.authFn != nil { + mw = append(mw, opt.authFn) + } + mw = append(mw, withTraceID, withStartTime) + if opt.met != nil { // put metrics last since these are executed LIFO + mw = append(mw, ObserveHandler("v2", opt.met)) + } + mw = append(mw, recoverer) + + s.mw = applyMW(mw...) + + s.routes() + + return &s +} + +func (s *ServerV2) routes() { + withContentTypeJSON := applyMW(contentTypeJSON, s.mw) + + // 9) + s.mux.Handle("POST /v1/foos", withContentTypeJSON(jsonIn(http.StatusCreated, s.createFooV1))) +} + +func (s *ServerV2) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // 4) + s.mux.ServeHTTP(w, r) +} + +// API envelope types +type ( + // RespResourceBody represents a JSON-API response body. + // https://jsonapi.org/format/#document-top-level + // + // note: data can be either an array or a single resource object. This allows for both. + RespResourceBody[Attrs any | []any] struct { + Meta RespMeta `json:"meta"` + Errs []RespErr `json:"errors,omitempty"` + Data *RespData[Attrs] `json:"data,omitempty"` + } + + // RespData represents a JSON-API data response. + // https://jsonapi.org/format/#document-top-level + RespData[Attr any | []Attr] struct { + Type string `json:"type"` + ID string `json:"id"` + Attributes Attr `json:"attributes"` + + // omitting the relationships here for brevity not at lvl 3 RMM + } + + // RespMeta represents a JSON-API meta object. The data here is + // useful for our example service. You can add whatever non-standard + // context that is relevant to your domain here. + // https://jsonapi.org/format/#document-meta + RespMeta struct { + TookMilli int `json:"took_ms"` + TraceID string `json:"trace_id"` + } + + // RespErr represents a JSON-API error object. Do note that we + // aren't implementing the entire error type. Just the most impactful + // bits for this workshop. Mainly, skipping Title & description separation. + // https://jsonapi.org/format/#error-objects + RespErr struct { + Status int `json:"status,string"` + Code int `json:"code"` + Msg string `json:"message"` + Source *RespErrSource `json:"source"` + } + + // RespErrSource represents a JSON-API err source. + // https://jsonapi.org/format/#error-objects + RespErrSource struct { + Pointer string `json:"pointer"` + Parameter string `json:"parameter,omitempty"` + Header string `json:"header,omitempty"` + } +) + +type ( + // ReqCreateFooV1 represents the request body for the create foo API. + ReqCreateFooV1 struct { + Name string `json:"name"` + Note string `json:"note"` + } + + // FooAttrs are the attributes for foo data. + FooAttrs struct { + Name string `json:"name"` + Note string `json:"note"` + CreatedAt string `json:"created_at"` + } +) + +func (s *ServerV2) createFooV1(ctx context.Context, req ReqCreateFooV1) (RespData[FooAttrs], []RespErr) { + newFoo := Foo{ + ID: s.idFn(), + Name: req.Name, + Note: req.Note, + CreatedAt: s.nowFn(), + } + if err := s.db.CreateFoo(ctx, newFoo); err != nil { + return RespData[FooAttrs]{}, toRespErrs(err) + } + + out := newFooData(newFoo.ID, FooAttrs{ + Name: newFoo.Name, + Note: newFoo.Note, + CreatedAt: toTimestamp(newFoo.CreatedAt), + }) + return out, nil +} + +func newFooData(id string, attrs FooAttrs) RespData[FooAttrs] { + return RespData[FooAttrs]{ + Type: "foo", + ID: id, + Attributes: attrs, + } +} + +func toTimestamp(t time.Time) string { + return t.Format(time.RFC3339) +} + +func jsonIn[ReqBody, Attr any](successCode int, fn func(context.Context, ReqBody) (RespData[Attr], []RespErr)) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ( + reqBody ReqBody + errs []RespErr + out *RespData[Attr] + ) + if respErr := decodeReq(r, &reqBody); respErr != nil { + errs = append(errs, *respErr) + } else { + var data RespData[Attr] + data, errs = fn(r.Context(), reqBody) + if len(errs) == 0 { + out = &data + } + } + + status := successCode + for _, e := range errs { + if e.Status > status { + status = e.Status + } + } + + w.WriteHeader(status) + json.NewEncoder(w).Encode(RespResourceBody[Attr]{ + Meta: getMeta(r.Context()), + Errs: errs, + Data: out, + }) // 10.b) + }) +} + +func decodeReq(r *http.Request, v any) *RespErr { + if err := json.NewDecoder(r.Body).Decode(v); err != nil { + respErr := RespErr{ + Status: http.StatusBadRequest, + Msg: "failed to decode request body: " + err.Error(), + Source: &RespErrSource{ + Pointer: "/data", + }, + Code: errTypeInvalid, + } + if unmarshErr := new(json.UnmarshalTypeError); errors.As(err, &unmarshErr) { + respErr.Source.Pointer += "/attributes/" + unmarshErr.Field + } + return &respErr + } + + return nil +} + +func toRespErrs(err error) []RespErr { + if e := new(Err); errors.As(err, e) { + return []RespErr{{ + Code: errCode(e), + Msg: e.Msg, + }} + } + + errs, ok := err.(interface{ Unwrap() []error }) + if !ok { + return nil + } + + var out []RespErr + for _, e := range errs.Unwrap() { + out = append(out, toRespErrs(e)...) + } + + return out +} + +func errCode(err *Err) int { + switch err.Type { + case errTypeExists: + return http.StatusConflict + case errTypeNotFound: + return http.StatusNotFound + default: + return http.StatusInternalServerError + } +} + +// WithBasicAuthV2 sets the authorization fn for the server to basic auth. +// 3) +func WithBasicAuthV2(adminUser, adminPass string) func(*serverOpts) { + return func(s *serverOpts) { + s.authFn = func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user, pass, ok := r.BasicAuth(); !(ok && user == adminUser && pass == adminPass) { + w.WriteHeader(http.StatusUnauthorized) // 9) + json.NewEncoder(w).Encode(RespResourceBody[any]{ + Meta: getMeta(r.Context()), + Errs: []RespErr{{ + Status: http.StatusUnauthorized, + Code: errTypeUnAuthed, + Msg: "unauthorized access", + Source: &RespErrSource{ + Header: "Authorization", + }, + }}, + }) + return + } + next.ServeHTTP(w, r) + }) + } + } +} + +func contentTypeJSON(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ct := r.Header.Get("Content-Type") + if ct != "application/json" { + w.WriteHeader(http.StatusUnsupportedMediaType) + json.NewEncoder(w).Encode(RespResourceBody[any]{ + Meta: getMeta(r.Context()), + Errs: []RespErr{{ + Code: http.StatusUnsupportedMediaType, + Msg: "received invalid media type", + }}, + }) + return + } + next.ServeHTTP(w, r) + }) +} + +func getMeta(ctx context.Context) RespMeta { + return RespMeta{ + TookMilli: int(took(ctx).Milliseconds()), + TraceID: getTraceID(ctx), + } +} + +func recoverer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + rvr := recover() + if rvr == nil { + return + } + + if rvr == http.ErrAbortHandler { + // we don't recover http.ErrAbortHandler so the response + // to the client is aborted, this should not be logged + panic(rvr) + } + + w.WriteHeader(http.StatusInternalServerError) + }() + + next.ServeHTTP(w, r) + }) +} + +const ( + ctxStartTime = "start" + ctxTraceID = "trace-id" +) + +func withTraceID(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + traceID := r.Header.Get("X-Mess-Trace-Id") + if traceID == "" { + traceID = uuid.Must(uuid.NewV4()).String() + } + ctx := context.WithValue(r.Context(), ctxTraceID, traceID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func getTraceID(ctx context.Context) string { + traceID, _ := ctx.Value(ctxTraceID).(string) + return traceID +} + +func withStartTime(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), ctxStartTime, time.Now()) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func took(ctx context.Context) time.Duration { + t, _ := ctx.Value(ctxStartTime).(time.Time) + return time.Since(t) +} + +func applyMW[T any](fns ...func(T) T) func(T) T { + return func(v T) T { + for i := len(fns) - 1; i >= 0; i-- { + v = fns[i](v) + } + return v + } +} diff --git a/allsrv/server_v2_test.go b/allsrv/server_v2_test.go new file mode 100644 index 0000000..f35de81 --- /dev/null +++ b/allsrv/server_v2_test.go @@ -0,0 +1,128 @@ +package allsrv_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/jsteenb2/mess/allsrv" +) + +func TestServerV2(t *testing.T) { + t.Run("foo create", func(t *testing.T) { + t.Run("when provided a valid foo should pass", func(t *testing.T) { + db := new(allsrv.InmemDB) + + var svr http.Handler = allsrv.NewServerV2( + db, + allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd"), + allsrv.WithMetrics(newTestMetrics(t)), + allsrv.WithIDFn(func() string { + return "id1" + }), + allsrv.WithNowFn(func() time.Time { + return time.Time{}.UTC().Add(time.Hour) + }), + ) + + req := httptest.NewRequest("POST", "/v1/foos", newJSONBody(t, allsrv.ReqCreateFooV1{ + Name: "first-foo", + Note: "some note", + })) + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("dodgers@stink.com", "PaSsWoRd") + rec := httptest.NewRecorder() + + svr.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusCreated, rec.Code) + expectData[allsrv.FooAttrs](t, rec.Body, allsrv.RespData[allsrv.FooAttrs]{ + Type: "foo", + ID: "id1", + Attributes: allsrv.FooAttrs{ + Name: "first-foo", + Note: "some note", + CreatedAt: time.Time{}.UTC().Add(time.Hour).Format(time.RFC3339), + }, + }) + + got, err := db.ReadFoo(context.TODO(), "id1") + require.NoError(t, err) + + want := allsrv.Foo{ + ID: "id1", + Name: "first-foo", + Note: "some note", + CreatedAt: time.Time{}.UTC().Add(time.Hour), + } + assert.Equal(t, want, got) + }) + + t.Run("when missing required auth should fail", func(t *testing.T) { + var svr http.Handler = allsrv.NewServerV2( + new(allsrv.InmemDB), + allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd"), + allsrv.WithMetrics(newTestMetrics(t)), + allsrv.WithIDFn(func() string { + return "id1" + }), + allsrv.WithNowFn(func() time.Time { + return time.Time{}.UTC().Add(time.Hour) + }), + ) + + req := httptest.NewRequest("POST", "/v1/foos", newJSONBody(t, allsrv.ReqCreateFooV1{ + Name: "first-foo", + Note: "some note", + })) + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("dodgers@stink.com", "WRONGO") + rec := httptest.NewRecorder() + + svr.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) + expectErrs(t, rec.Body, func(t *testing.T, got []allsrv.RespErr) { + require.Len(t, got, 1) + + want := allsrv.RespErr{ + Status: 401, + Code: 4, + Msg: "unauthorized access", + Source: &allsrv.RespErrSource{ + Header: "Authorization", + }, + } + assert.Equal(t, want, got[0]) + }) + }) + }) +} + +func expectErrs(t *testing.T, r io.Reader, fn func(t *testing.T, got []allsrv.RespErr)) { + t.Helper() + + expectJSONBody(t, r, func(t *testing.T, got allsrv.RespResourceBody[any]) { + require.Nil(t, got.Data) + require.NotEmpty(t, got.Errs) + + fn(t, got.Errs) + }) +} + +func expectData[Attrs any | []any](t *testing.T, r io.Reader, want allsrv.RespData[Attrs]) { + t.Helper() + + expectJSONBody(t, r, func(t *testing.T, got allsrv.RespResourceBody[Attrs]) { + require.Empty(t, got.Errs) + require.NotNil(t, got.Data) + + assert.Equal(t, want, *got.Data) + }) +}