diff --git a/internal/config/config.go b/internal/config/config.go index 73106fbde..b9f746370 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -30,6 +30,8 @@ type Config struct { NavbarTitle string Env map[string]string TLS *TLS + IsAuthToken bool + AuthToken string } type TLS struct { @@ -71,7 +73,8 @@ func LoadConfig(userHomeDir string) error { _ = viper.BindEnv("navbarTitle", "DAGU_NAVBAR_TITLE") _ = viper.BindEnv("tls.certFile", "DAGU_CERT_FILE") _ = viper.BindEnv("tls.keyFile", "DAGU_KEY_FILE") - + _ = viper.BindEnv("isAuthToken", "DAGU_IS_AUTHTOKEN") + _ = viper.BindEnv("authToken", "DAGU_AUTHTOKEN") command := "dagu" if ex, err := os.Executable(); err == nil { command = ex @@ -93,6 +96,8 @@ func LoadConfig(userHomeDir string) error { viper.SetDefault("adminLogsDir", path.Join(appHome, "logs", "admin")) viper.SetDefault("navbarColor", "") viper.SetDefault("navbarTitle", "Dagu") + viper.SetDefault("isAuthToken", "0") + viper.SetDefault("authToken", "0") viper.AutomaticEnv() diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 8be9f6925..fce54497f 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -3,6 +3,7 @@ package frontend import ( "context" "embed" + "github.com/dagu-dev/dagu/internal/config" "github.com/dagu-dev/dagu/internal/logger" "github.com/dagu-dev/dagu/service/frontend/handlers" @@ -53,6 +54,13 @@ func New(params Params) *server.Server { AssetsFS: assetsFS, } + if params.Config.IsAuthToken { + + serverParams.AuthToken = &server.AuthToken{ + Token: params.Config.AuthToken, + } + } + if params.Config.IsBasicAuth { serverParams.BasicAuth = &server.BasicAuth{ Username: params.Config.BasicAuthUsername, diff --git a/service/frontend/middleware/global.go b/service/frontend/middleware/global.go index 09a9b9cb9..2be887924 100644 --- a/service/frontend/middleware/global.go +++ b/service/frontend/middleware/global.go @@ -1,9 +1,10 @@ package middleware import ( - "github.com/go-chi/chi/v5/middleware" "net/http" "strings" + + "github.com/go-chi/chi/v5/middleware" ) func SetupGlobalMiddleware(handler http.Handler) http.Handler { @@ -12,12 +13,15 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler { next = middleware.Logger(next) next = middleware.Recoverer(next) + if authToken != nil { + next = TokenAuth("restricted", authToken.Token)(next) + } + if basicAuth != nil { next = middleware.BasicAuth( "restricted", map[string]string{basicAuth.Username: basicAuth.Password}, )(next) } - next = prefixChecker(next) return next @@ -26,11 +30,13 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler { var ( defaultHandler http.Handler basicAuth *BasicAuth + authToken *AuthToken ) type Options struct { Handler http.Handler BasicAuth *BasicAuth + AuthToken *AuthToken } type BasicAuth struct { @@ -38,9 +44,14 @@ type BasicAuth struct { Password string } +type AuthToken struct { + Token string +} + func Setup(opts *Options) { defaultHandler = opts.Handler basicAuth = opts.BasicAuth + authToken = opts.AuthToken } func prefixChecker(next http.Handler) http.Handler { diff --git a/service/frontend/middleware/tokenAuth.go b/service/frontend/middleware/tokenAuth.go new file mode 100644 index 000000000..3fad420bd --- /dev/null +++ b/service/frontend/middleware/tokenAuth.go @@ -0,0 +1,40 @@ +package middleware + +import ( + "crypto/subtle" + "fmt" + "net/http" + "strings" +) + +// TokenAuth implements a similar middleware handler like go-chi's BasicAuth middleware but for bearer tokens +func TokenAuth(realm string, token string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := strings.Split(r.Header.Get("Authorization"), " ") + if len(authHeader) < 2 { + tokenAuthFailed(w, realm) + return + } + + bearer := authHeader[1] + if bearer == "" { + tokenAuthFailed(w, realm) + return + } + + if subtle.ConstantTimeCompare([]byte(bearer), []byte(token)) != 1 { + tokenAuthFailed(w, realm) + return + } + + next.ServeHTTP(w, r) + + }) + } +} + +func tokenAuthFailed(w http.ResponseWriter, realm string) { + w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, realm)) + w.WriteHeader(http.StatusUnauthorized) +} diff --git a/service/frontend/server/server.go b/service/frontend/server/server.go index eca1b3c4b..d6eb8bc40 100644 --- a/service/frontend/server/server.go +++ b/service/frontend/server/server.go @@ -3,17 +3,18 @@ package server import ( "context" "errors" + "io/fs" + "net/http" + "os" + "os/signal" + "syscall" + "github.com/dagu-dev/dagu/internal/config" "github.com/dagu-dev/dagu/internal/logger" "github.com/dagu-dev/dagu/internal/logger/tag" "github.com/dagu-dev/dagu/service/frontend/restapi" "github.com/go-openapi/loads" flags "github.com/jessevdk/go-flags" - "io/fs" - "net/http" - "os" - "os/signal" - "syscall" pkgmiddleware "github.com/dagu-dev/dagu/service/frontend/middleware" "github.com/dagu-dev/dagu/service/frontend/restapi/operations" @@ -26,10 +27,15 @@ type BasicAuth struct { Password string } +type AuthToken struct { + Token string +} + type Params struct { Host string Port int BasicAuth *BasicAuth + AuthToken *AuthToken TLS *config.TLS Logger logger.Logger Handlers []New @@ -40,6 +46,7 @@ type Server struct { host string port int basicAuth *BasicAuth + authToken *AuthToken tls *config.TLS logger logger.Logger server *restapi.Server @@ -56,6 +63,7 @@ func NewServer(params Params) *Server { host: params.Host, port: params.Port, basicAuth: params.BasicAuth, + authToken: params.AuthToken, tls: params.TLS, logger: params.Logger, handlers: params.Handlers, @@ -77,6 +85,11 @@ func (svr *Server) Serve(ctx context.Context) (err error) { middlewareOptions := &pkgmiddleware.Options{ Handler: svr.defaultRoutes(chi.NewRouter()), } + if svr.authToken != nil { + middlewareOptions.AuthToken = &pkgmiddleware.AuthToken{ + Token: svr.authToken.Token, + } + } if svr.basicAuth != nil { middlewareOptions.BasicAuth = &pkgmiddleware.BasicAuth{ Username: svr.basicAuth.Username, @@ -90,7 +103,6 @@ func (svr *Server) Serve(ctx context.Context) (err error) { svr.logger.Error("failed to load API spec", tag.Error(err)) return err } - api := operations.NewDaguAPI(swaggerSpec) for _, h := range svr.handlers { h.Configure(api)