Skip to content

Commit

Permalink
Token auth (#508)
Browse files Browse the repository at this point in the history
* WIP basic auth token

* remove output.diff

* implemented reviewed changes

* clean up config.go
  • Loading branch information
smekuria1 authored Dec 20, 2023
1 parent 7c408b4 commit 966a34d
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 9 deletions.
7 changes: 6 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type Config struct {
NavbarTitle string
Env map[string]string
TLS *TLS
IsAuthToken bool
AuthToken string
}

type TLS struct {
Expand Down Expand Up @@ -71,7 +73,8 @@ func LoadConfig(userHomeDir string) error {
_ = viper.BindEnv("navbarTitle", "DAGU_NAVBAR_TITLE")
_ = viper.BindEnv("tls.certFile", "DAGU_CERT_FILE")
_ = viper.BindEnv("tls.keyFile", "DAGU_KEY_FILE")

_ = viper.BindEnv("isAuthToken", "DAGU_IS_AUTHTOKEN")
_ = viper.BindEnv("authToken", "DAGU_AUTHTOKEN")
command := "dagu"
if ex, err := os.Executable(); err == nil {
command = ex
Expand All @@ -93,6 +96,8 @@ func LoadConfig(userHomeDir string) error {
viper.SetDefault("adminLogsDir", path.Join(appHome, "logs", "admin"))
viper.SetDefault("navbarColor", "")
viper.SetDefault("navbarTitle", "Dagu")
viper.SetDefault("isAuthToken", "0")
viper.SetDefault("authToken", "0")

viper.AutomaticEnv()

Expand Down
8 changes: 8 additions & 0 deletions service/frontend/fx.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package frontend
import (
"context"
"embed"

"github.com/dagu-dev/dagu/internal/config"
"github.com/dagu-dev/dagu/internal/logger"
"github.com/dagu-dev/dagu/service/frontend/handlers"
Expand Down Expand Up @@ -53,6 +54,13 @@ func New(params Params) *server.Server {
AssetsFS: assetsFS,
}

if params.Config.IsAuthToken {

serverParams.AuthToken = &server.AuthToken{
Token: params.Config.AuthToken,
}
}

if params.Config.IsBasicAuth {
serverParams.BasicAuth = &server.BasicAuth{
Username: params.Config.BasicAuthUsername,
Expand Down
15 changes: 13 additions & 2 deletions service/frontend/middleware/global.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package middleware

import (
"github.com/go-chi/chi/v5/middleware"
"net/http"
"strings"

"github.com/go-chi/chi/v5/middleware"
)

func SetupGlobalMiddleware(handler http.Handler) http.Handler {
Expand All @@ -12,12 +13,15 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler {
next = middleware.Logger(next)
next = middleware.Recoverer(next)

if authToken != nil {
next = TokenAuth("restricted", authToken.Token)(next)
}

if basicAuth != nil {
next = middleware.BasicAuth(
"restricted", map[string]string{basicAuth.Username: basicAuth.Password},
)(next)
}

next = prefixChecker(next)

return next
Expand All @@ -26,21 +30,28 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler {
var (
defaultHandler http.Handler
basicAuth *BasicAuth
authToken *AuthToken
)

type Options struct {
Handler http.Handler
BasicAuth *BasicAuth
AuthToken *AuthToken
}

type BasicAuth struct {
Username string
Password string
}

type AuthToken struct {
Token string
}

func Setup(opts *Options) {
defaultHandler = opts.Handler
basicAuth = opts.BasicAuth
authToken = opts.AuthToken
}

func prefixChecker(next http.Handler) http.Handler {
Expand Down
40 changes: 40 additions & 0 deletions service/frontend/middleware/tokenAuth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package middleware

import (
"crypto/subtle"
"fmt"
"net/http"
"strings"
)

// TokenAuth implements a similar middleware handler like go-chi's BasicAuth middleware but for bearer tokens
func TokenAuth(realm string, token string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
if len(authHeader) < 2 {
tokenAuthFailed(w, realm)
return
}

bearer := authHeader[1]
if bearer == "" {
tokenAuthFailed(w, realm)
return
}

if subtle.ConstantTimeCompare([]byte(bearer), []byte(token)) != 1 {
tokenAuthFailed(w, realm)
return
}

next.ServeHTTP(w, r)

})
}
}

func tokenAuthFailed(w http.ResponseWriter, realm string) {
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, realm))
w.WriteHeader(http.StatusUnauthorized)
}
24 changes: 18 additions & 6 deletions service/frontend/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@ package server
import (
"context"
"errors"
"io/fs"
"net/http"
"os"
"os/signal"
"syscall"

"github.com/dagu-dev/dagu/internal/config"
"github.com/dagu-dev/dagu/internal/logger"
"github.com/dagu-dev/dagu/internal/logger/tag"
"github.com/dagu-dev/dagu/service/frontend/restapi"
"github.com/go-openapi/loads"
flags "github.com/jessevdk/go-flags"
"io/fs"
"net/http"
"os"
"os/signal"
"syscall"

pkgmiddleware "github.com/dagu-dev/dagu/service/frontend/middleware"
"github.com/dagu-dev/dagu/service/frontend/restapi/operations"
Expand All @@ -26,10 +27,15 @@ type BasicAuth struct {
Password string
}

type AuthToken struct {
Token string
}

type Params struct {
Host string
Port int
BasicAuth *BasicAuth
AuthToken *AuthToken
TLS *config.TLS
Logger logger.Logger
Handlers []New
Expand All @@ -40,6 +46,7 @@ type Server struct {
host string
port int
basicAuth *BasicAuth
authToken *AuthToken
tls *config.TLS
logger logger.Logger
server *restapi.Server
Expand All @@ -56,6 +63,7 @@ func NewServer(params Params) *Server {
host: params.Host,
port: params.Port,
basicAuth: params.BasicAuth,
authToken: params.AuthToken,
tls: params.TLS,
logger: params.Logger,
handlers: params.Handlers,
Expand All @@ -77,6 +85,11 @@ func (svr *Server) Serve(ctx context.Context) (err error) {
middlewareOptions := &pkgmiddleware.Options{
Handler: svr.defaultRoutes(chi.NewRouter()),
}
if svr.authToken != nil {
middlewareOptions.AuthToken = &pkgmiddleware.AuthToken{
Token: svr.authToken.Token,
}
}
if svr.basicAuth != nil {
middlewareOptions.BasicAuth = &pkgmiddleware.BasicAuth{
Username: svr.basicAuth.Username,
Expand All @@ -90,7 +103,6 @@ func (svr *Server) Serve(ctx context.Context) (err error) {
svr.logger.Error("failed to load API spec", tag.Error(err))
return err
}

api := operations.NewDaguAPI(swaggerSpec)
for _, h := range svr.handlers {
h.Configure(api)
Expand Down

0 comments on commit 966a34d

Please sign in to comment.