diff --git a/go/admin/admin.go b/go/admin/admin.go index aec1632b..ca7e3e1d 100644 --- a/go/admin/admin.go +++ b/go/admin/admin.go @@ -51,7 +51,7 @@ type Admin struct { ghAppId string ghAppSecret string - ghTokenSalt string + auth string dbCfg *psdb.Config dbClient *psdb.Client @@ -64,7 +64,7 @@ func (a *Admin) AddToCommand(cmd *cobra.Command) { cmd.Flags().Var(&a.Mode, flagMode, "Specify the mode on which the server will run") cmd.Flags().StringVar(&a.ghAppId, flagAdminAppId, "", "The ID of the GitHub App") cmd.Flags().StringVar(&a.ghAppSecret, flagAdminAppSecret, "", "The secret of the GitHub App") - cmd.Flags().StringVar(&a.ghTokenSalt, flagGhAuth, "", "The salt string to salt the GitHub Token") + cmd.Flags().StringVar(&a.auth, flagGhAuth, "", "The salt string to salt the GitHub Token") _ = viper.BindPFlag(flagPort, cmd.Flags().Lookup(flagPort)) _ = viper.BindPFlag(flagMode, cmd.Flags().Lookup(flagMode)) @@ -139,7 +139,7 @@ func (a *Admin) Run() error { a.router.GET("/admin", a.login) a.router.GET("/admin/login", a.handleGitHubLogin) a.router.GET("/admin/auth/callback", a.handleGitHubCallback) - a.router.POST("/admin/executions/add", a.handleExecutionsAdd) + a.router.POST("/admin/executions/add", a.authMiddleware(), a.handleExecutionsAdd) a.router.GET("/admin/dashboard", a.authMiddleware(), a.dashboard) return a.router.Run(":" + a.port) diff --git a/go/admin/api.go b/go/admin/api.go index 2f9ce0cd..8aed1533 100644 --- a/go/admin/api.go +++ b/go/admin/api.go @@ -24,6 +24,7 @@ import ( "encoding/json" "net/http" "strings" + "sync" "time" "github.com/gin-gonic/gin" @@ -42,6 +43,10 @@ var ( } oauthStateString = random.String(10) // A random string to protect against CSRF attacks client *goGithub.Client + orgName = "vitessio" + tokens = make(map[string]oauth2.Token) + + mu sync.Mutex ) const ( @@ -65,21 +70,65 @@ func (a *Admin) dashboard(c *gin.Context) { a.render(c, gin.H{}, "base.html") } +func CreateGhClient(token *oauth2.Token) *goGithub.Client { + return goGithub.NewClient(oauthConf.Client(context.Background(), token)) +} + func (a *Admin) authMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - _, err := c.Cookie("ghtoken") + cookie, err := c.Cookie("ghtoken") if err != nil { - // User not authenticated, redirect to login c.Redirect(http.StatusSeeOther, "/admin/login") c.Abort() return } - // User is authenticated, proceed to the next handler + mu.Lock() + defer mu.Unlock() + + token, ok := tokens[cookie] + + if !ok { + c.Redirect(http.StatusSeeOther, "/admin/login") + c.Abort() + return + } + + client := CreateGhClient(&token) + + isMaintainer, err := a.GetUser(client) + + if err != nil { + slog.Error("Error getting user: ", err) + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + if !isMaintainer { + c.String(http.StatusForbidden, "You must be a maintainer in the %s organization to access this page.", orgName) + c.Abort() + return + } + c.Next() } } +func (a *Admin) GetUser(client *goGithub.Client) (bool, error) { + user, _, err := client.Users.Get(context.Background(), "") + if err != nil { + return false, err + } + + isMaintainer, err := a.checkUserOrgMembership(client, user.GetLogin(), orgName) + if err != nil { + slog.Error("Error checking org membership: ", err) + return false, err + } + + return isMaintainer, nil +} + func (a *Admin) handleGitHubLogin(c *gin.Context) { if a.Mode == server.ProductionMode { oauthConf.RedirectURL = "https://benchmark.vitess.io/admin/auth/callback" @@ -105,27 +154,17 @@ func (a *Admin) handleGitHubCallback(c *gin.Context) { return } - client := goGithub.NewClient(oauthConf.Client(context.Background(), token)) - - user, _, err := client.Users.Get(context.Background(), "") - if err != nil { - slog.Error("Failed to get user: ", err) - c.AbortWithStatus(http.StatusInternalServerError) - return - } - - slog.Infof("Authenticated user: %s", user.GetLogin()) + client := CreateGhClient(token) - orgName := "vitessio" - isMaintainer, err := a.checkUserOrgMembership(client, user.GetLogin(), orgName) - if err != nil { - slog.Error("Error checking org membership: ", err) - c.AbortWithStatus(http.StatusInternalServerError) - return - } + isMaintainer, err := a.GetUser(client) if isMaintainer { - c.SetCookie("ghtoken", token.AccessToken, int(token.Expiry.Sub(time.Now()).Seconds()), "/", "localhost", true, true) + mu.Lock() + defer mu.Unlock() + + randomKey := random.String(32) + tokens[randomKey] = *token + c.SetCookie("ghtoken", randomKey, int(time.Hour.Seconds()), "/", "localhost", true, true) c.Redirect(http.StatusSeeOther, "/admin/dashboard") } else { @@ -159,7 +198,6 @@ func (a *Admin) checkUserOrgMembership(client *goGithub.Client, username, orgNam } func (a *Admin) handleExecutionsAdd(c *gin.Context) { - slog.Info("Adding execution") source := c.PostForm("source") sha := c.PostForm("sha") workloads := c.PostFormArray("workloads") @@ -177,7 +215,9 @@ func (a *Admin) handleExecutionsAdd(c *gin.Context) { return } - encryptedToken := server.Encrypt(token, a.ghTokenSalt) + encryptedToken := server.Encrypt(token, a.auth) + + slog.Info("Encrypted token: ", encryptedToken) requestPayload := ExecutionRequest{ Auth: encryptedToken, diff --git a/go/tools/server/utils.go b/go/tools/server/utils.go index 74e7eca2..914d1040 100644 --- a/go/tools/server/utils.go +++ b/go/tools/server/utils.go @@ -25,7 +25,7 @@ import ( "encoding/hex" "fmt" "io" - "log/slog" + "log" ) const ( @@ -33,7 +33,7 @@ const ( ) func Encrypt(stringToEncrypt string, keyString string) (encryptedString string) { - slog.Info(stringToEncrypt, keyString) + log.Println(stringToEncrypt, keyString) key, _ := hex.DecodeString(keyString) plaintext := []byte(stringToEncrypt)