Skip to content

Commit

Permalink
Fix use token subsequent requests
Browse files Browse the repository at this point in the history
  • Loading branch information
TimVosch committed May 13, 2024
1 parent da28226 commit f5880f8
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 106 deletions.
5 changes: 4 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ services:
environment:
- STATIC_PATH=services/dashboard/static
- HTTP_BASE=/dashboard
- SB_API=http://caddy
- EP_CORE=http://caddy
- EP_TRACING=http://caddy
- EP_WORKERS=http://caddy
- EP_MEASUREMENTS=http://caddy

core:
build:
Expand Down
24 changes: 21 additions & 3 deletions pkg/auth/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"context"
"encoding/json"
"fmt"
"log"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/golang-jwt/jwt"

"sensorbucket.nl/sensorbucket/internal/web"
"sensorbucket.nl/sensorbucket/pkg/api"
)

type claims struct {
Expand Down Expand Up @@ -62,6 +64,7 @@ func Protect() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := GetTenant(r.Context()); err != nil {
log.Println("[Auth] token is missing tenant!")
web.HTTPError(w, ErrUnauthorized)
return
}
Expand All @@ -70,6 +73,7 @@ func Protect() func(http.Handler) http.Handler {
// return
//}
if _, err := GetPermissions(r.Context()); err != nil {
log.Println("[Auth] token is missing permissions!")
web.HTTPError(w, ErrUnauthorized)
return
}
Expand All @@ -79,6 +83,18 @@ func Protect() func(http.Handler) http.Handler {
}
}

func ForwardRequestAuthentication() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := strings.TrimPrefix(strings.TrimPrefix(r.Header.Get("Authorization"), "bearer "), "Bearer ")
r = r.WithContext(context.WithValue(
r.Context(), api.ContextAccessToken, token,
))
next.ServeHTTP(w, r)
})
}
}

// Authentication middleware for checking the validity of any present JWT
// Checks if the JWT is signed using the given secret
// Serves the next HTTP handler if there is no JWT or if the JWT is OK
Expand All @@ -93,8 +109,10 @@ func Authenticate(keyClient jwksClient) func(http.Handler) http.Handler {
return
}

tokenStr, ok := strings.CutPrefix(auth, "Bearer ")
// Cheating, removes Bearer and bearer case independently
tokenStr, ok := strings.CutPrefix(auth[1:], "earer ")
if !ok {
log.Printf("[Error] authentication failed err because the Authorization header is malformed\n")
web.HTTPError(w, ErrAuthHeaderInvalidFormat)
return
}
Expand All @@ -103,12 +121,12 @@ func Authenticate(keyClient jwksClient) func(http.Handler) http.Handler {
c := claims{}
token, err := jwt.ParseWithClaims(tokenStr, &c, validateJWTFunc(keyClient))
if err != nil {
log.Printf("[Error] authentication failed err: %s", err)
log.Printf("[Error] authentication failed err: %s\n", err)
web.HTTPError(w, ErrUnauthorized)
return
}
if !token.Valid {
log.Printf("[Error] authentication failed err: %s", err)
log.Printf("[Error] authentication failed err: %s\n", err)
web.HTTPError(w, ErrUnauthorized)
return
}
Expand Down
75 changes: 42 additions & 33 deletions services/dashboard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (
"embed"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
"time"

"github.com/go-chi/chi/v5"
Expand All @@ -29,10 +29,13 @@ func main() {
}

var (
HTTP_ADDR = env.Could("HTTP_ADDR", ":3000")
HTTP_BASE = env.Could("HTTP_BASE", "")
AUTH_JWKS_URL = env.Could("AUTH_JWKS_URL", "http://oathkeeper:4456/.well-known/jwks.json")
SB_API = env.Must("SB_API")
HTTP_ADDR = env.Could("HTTP_ADDR", ":3000")
HTTP_BASE = env.Could("HTTP_BASE", "")
AUTH_JWKS_URL = env.Could("AUTH_JWKS_URL", "http://oathkeeper:4456/.well-known/jwks.json")
EP_CORE = env.Must("EP_CORE")
EP_WORKERS = env.Must("EP_WORKERS")
EP_TRACING = env.Must("EP_TRACING")
EP_MEASUREMENTS = env.Must("EP_MEASUREMENTS")
)

//go:embed static/*
Expand All @@ -45,52 +48,46 @@ func Run() error {

router := chi.NewRouter()
jwks := auth.NewJWKSHttpClient(AUTH_JWKS_URL)
router.Use(middleware.Logger, auth.Authenticate(jwks), auth.Protect())
router.Use(
middleware.Logger,
auth.ForwardRequestAuthentication(),
auth.Authenticate(jwks),
auth.Protect(),
)

var baseURL *url.URL
if HTTP_BASE != "" {
baseURL, _ = url.Parse(HTTP_BASE)
views.SetBase(baseURL)
}

// Middleware to pass on basic auth to the client api
router.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := r.Header.Get("Authorization")
key = strings.Join(strings.Split(key, " ")[1:], "")
r = r.WithContext(context.WithValue(
r.Context(), api.ContextAPIKeys, api.APIKey{
Key: key,
Prefix: "Bearer",
}))
next.ServeHTTP(w, r)
})
})

// Serve static files
fileServer := http.FileServer(http.FS(staticFS))
router.Handle("/static/*", fileServer)

sbURL, err := url.Parse(SB_API)
if err != nil {
return fmt.Errorf("could not parse SB_API url: %w", err)
}
cfg := api.NewConfiguration()
cfg.Scheme = sbURL.Scheme
cfg.Host = sbURL.Host
apiClient := api.NewAPIClient(cfg)

router.Get("/", func(w http.ResponseWriter, r *http.Request) {
u := "/overview"
if baseURL != nil {
u = baseURL.JoinPath("overview").String()
}
http.Redirect(w, r, u, http.StatusFound)
})
router.Mount("/overview", routes.CreateOverviewPageHandler(apiClient))
router.Mount("/ingress", routes.CreateIngressPageHandler(apiClient))
router.Mount("/workers", routes.CreateWorkerPageHandler(apiClient))
router.Mount("/pipelines", routes.CreatePipelinePageHandler(apiClient))
router.Mount("/overview", routes.CreateOverviewPageHandler(
createAPIClient(EP_CORE),
createAPIClient(EP_MEASUREMENTS),
))
router.Mount("/ingress", routes.CreateIngressPageHandler(
createAPIClient(EP_CORE),
createAPIClient(EP_TRACING),
createAPIClient(EP_WORKERS),
))
router.Mount("/workers", routes.CreateWorkerPageHandler(
createAPIClient(EP_WORKERS),
))
router.Mount("/pipelines", routes.CreatePipelinePageHandler(
createAPIClient(EP_WORKERS),
createAPIClient(EP_CORE),
))
srv := &http.Server{
Addr: HTTP_ADDR,
WriteTimeout: 5 * time.Second,
Expand All @@ -106,6 +103,7 @@ func Run() error {
fmt.Printf("HTTP Server listening on: %s\n", srv.Addr)

// Wait for fatal error or interrupt signal
var err error
select {
case <-ctx.Done():
case err = <-errC:
Expand All @@ -121,3 +119,14 @@ func Run() error {

return err
}

func createAPIClient(baseurl string) *api.APIClient {
sbURL, err := url.Parse(baseurl)
if err != nil {
log.Fatalf("could not parse APIClient url: %s\n", err)
}
cfg := api.NewConfiguration()
cfg.Scheme = sbURL.Scheme
cfg.Host = sbURL.Host
return api.NewAPIClient(cfg)
}
24 changes: 14 additions & 10 deletions services/dashboard/routes/ingress.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@ import (
)

type IngressPageHandler struct {
router chi.Router
client *api.APIClient
router chi.Router
coreClient *api.APIClient
tracingClient *api.APIClient
workersClient *api.APIClient
}

func CreateIngressPageHandler(client *api.APIClient) *IngressPageHandler {
func CreateIngressPageHandler(core, tracing, workers *api.APIClient) *IngressPageHandler {
handler := &IngressPageHandler{
router: chi.NewRouter(),
client: client,
router: chi.NewRouter(),
coreClient: core,
tracingClient: tracing,
workersClient: workers,
}
handler.SetupRoutes(handler.router)
return handler
Expand All @@ -40,18 +44,18 @@ func (h *IngressPageHandler) SetupRoutes(r chi.Router) {
}

func (h *IngressPageHandler) createViewIngresses(ctx context.Context) ([]views.Ingress, error) {
resIngresses, _, err := h.client.TracingApi.ListIngresses(ctx).Limit(30).Execute()
resIngresses, _, err := h.tracingClient.TracingApi.ListIngresses(ctx).Limit(30).Execute()
if err != nil {
return nil, fmt.Errorf("error listing ingresses: %w", err)
}
resPipelines, _, err := h.client.PipelinesApi.ListPipelines(ctx).Execute()
resPipelines, _, err := h.coreClient.PipelinesApi.ListPipelines(ctx).Execute()
if err != nil {
return nil, fmt.Errorf("error listing pipelines: %w", err)
}

plSteps := lo.FlatMap(resPipelines.Data, func(p api.Pipeline, _ int) []string { return p.Steps })
plSteps = lo.Uniq(plSteps)
resWorkers, _, err := h.client.WorkersApi.ListWorkers(ctx).Id(plSteps).Execute()
resWorkers, _, err := h.workersClient.WorkersApi.ListWorkers(ctx).Id(plSteps).Execute()
if err != nil {
return nil, fmt.Errorf("error listing workers: %w", err)
}
Expand All @@ -61,7 +65,7 @@ func (h *IngressPageHandler) createViewIngresses(ctx context.Context) ([]views.I

traceIDs := lo.Map(resIngresses.Data, func(ing api.ArchivedIngress, _ int) string { return ing.GetTracingId() })
traceIDs = lo.Uniq(traceIDs)
resLogs, _, err := h.client.TracingApi.ListTraces(ctx).TracingId(traceIDs).Execute()
resLogs, _, err := h.tracingClient.TracingApi.ListTraces(ctx).TracingId(traceIDs).Execute()
if err != nil {
return nil, fmt.Errorf("error listing traces: %w", err)
}
Expand All @@ -72,7 +76,7 @@ func (h *IngressPageHandler) createViewIngresses(ctx context.Context) ([]views.I
deviceIDs := lo.FilterMap(resLogs.Data, func(traceLog api.Trace, _ int) (int64, bool) {
return traceLog.DeviceId, traceLog.DeviceId > 0
})
resDevices, _, err := h.client.DevicesApi.ListDevices(ctx).Id(lo.Uniq(deviceIDs)).Execute()
resDevices, _, err := h.coreClient.DevicesApi.ListDevices(ctx).Id(lo.Uniq(deviceIDs)).Execute()
if err != nil {
return nil, fmt.Errorf("error listing devices: %w", err)
}
Expand Down
Loading

0 comments on commit f5880f8

Please sign in to comment.