Skip to content

Commit

Permalink
Merge pull request #92 from dkj/seamless_auth_flow
Browse files Browse the repository at this point in the history
Seamless auth flow
  • Loading branch information
kjsanger authored Feb 12, 2025
2 parents d76a960 + 764f95c commit 6b93ecd
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 73 deletions.
158 changes: 91 additions & 67 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import (
"github.com/rs/zerolog/hlog"
)

const RedirectURIState = "redirect_uri"

// HandlerChain is a function that takes an http.Handler and returns a new http.Handler
// wrapping the input handler. Each handler in the chain should process the request in
// some way, and then call the next handler. Ideally, the functionality of each handler
Expand Down Expand Up @@ -107,45 +109,54 @@ func HandleHomePage(server *SqyrrlServer) http.Handler {
})
}

func HandleLogin(server *SqyrrlServer) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger := server.logger
logger.Trace().Msg("LoginHandler called")

req, _ := httputil.DumpRequest(r, true)
logger.Trace().Str("request", string(req)).Msg("HandleLogin request")
// RedirectToIdentityServer redirects the user to the identity server for use within
// the LoginHandler and iRODSGetHandler on finding authenticaiton required.
func RedirectToIdentityServer(w http.ResponseWriter, r *http.Request, server *SqyrrlServer, redirect_uri string) {
logger := server.logger
logger.Trace().Msg("LoginHandler called")

if !server.sqyrrlConfig.EnableOIDC {
logger.Error().Msg("OIDC is not enabled")
writeErrorResponse(logger, w, http.StatusForbidden)
return
}
req, _ := httputil.DumpRequest(r, true)
logger.Trace().Str("request", string(req)).Msg("HandleLogin request")

w.Header().Add("Cache-Control", "no-cache") // See https://github.com/okta/samples-golang/issues/20
if !server.sqyrrlConfig.EnableOIDC {
logger.Error().Msg("OIDC is not enabled")
writeErrorResponse(logger, w, http.StatusForbidden)
return
}

state, err := cryptoRandString(16) // Minimum 128 bits required
if err != nil {
logger.Err(err).Msg("Failed to generate a random state")
writeErrorResponse(logger, w, http.StatusInternalServerError)
return
}
w.Header().Add("Cache-Control", "no-cache") // See https://github.com/okta/samples-golang/issues/20

// https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/Session_Management_Cheat_Sheet.md#renew-the-session-id-after-any-privilege-level-change
err = server.sessionManager.RenewToken(r.Context())
if err != nil {
logger.Err(err).Msg("Failed to renew session token")
writeErrorResponse(logger, w, http.StatusInternalServerError)
return
}
server.sessionManager.Put(r.Context(), SessionKeyState, state)
state, err := cryptoRandString(16) // Minimum 128 bits required
if err != nil {
logger.Err(err).Msg("Failed to generate a random state")
writeErrorResponse(logger, w, http.StatusInternalServerError)
return
}

authURL := server.oauth2Config.AuthCodeURL(state)
logger.Info().
Str("auth_url", authURL).
Str("state", state).
Msg("Redirecting to auth URL")
// https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/Session_Management_Cheat_Sheet.md#renew-the-session-id-after-any-privilege-level-change
err = server.sessionManager.RenewToken(r.Context())
if err != nil {
logger.Err(err).Msg("Failed to renew session token")
writeErrorResponse(logger, w, http.StatusInternalServerError)
return
}
server.sessionManager.Put(r.Context(), SessionKeyState, state)
// store where to send the user after login
server.sessionManager.Put(r.Context(), RedirectURIState, redirect_uri)

authURL := server.oauth2Config.AuthCodeURL(state)
logger.Info().
Str("auth_redirect_url", authURL).
Str("state", state).
Str("eventual_redirect_uri", redirect_uri).
Msg("Redirecting to auth URL")

http.Redirect(w, r, authURL, http.StatusFound)
}

http.Redirect(w, r, authURL, http.StatusFound)
func HandleLogin(server *SqyrrlServer) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
RedirectToIdentityServer(w, r, server, "")
})
}

Expand Down Expand Up @@ -232,9 +243,12 @@ func HandleAuthCallback(server *SqyrrlServer) http.Handler {
Str("email", claims.Email).
Msg("User logged in")

logger.Debug().Msg("Redirecting logged in user to home page")
// find where to send the user after login - could be the home page or a path requiring auth
redirect_uri := "/" + server.sessionManager.GetString(r.Context(), RedirectURIState)

http.Redirect(w, r, "/", http.StatusFound)
logger.Debug().Str("redirect_uri", redirect_uri).Msg("Redirecting logged in user")

http.Redirect(w, r, redirect_uri, http.StatusFound)
})
}

Expand Down Expand Up @@ -341,45 +355,55 @@ func HandleIRODSGet(server *SqyrrlServer) http.Handler {
localZone := server.iRODSAccount.ClientZone

var isReadable bool
isReadable, err = IsPublicReadable(logger, rodsFs, objPath)
if err != nil {
logger.Err(err).Msg("Failed to check if the object is public readable")
writeErrorResponse(logger, w, http.StatusInternalServerError)
return
}

if server.isAuthenticated(r) {
// The username obtained from the email address does not include the iRODS
// zone. We use the local zone to which the Sqyrrl server is connected as
// the user's zone.
userName := iRODSUsernameFromEmail(logger, server.getSessionUserEmail(r))
userZone := server.sqyrrlConfig.IRODSZoneForOIDC

logger.Debug().Str("user", userName).Msg("User is authenticated")

isReadable, err = IsReadableByUser(logger, rodsFs, localZone, userName, userZone, objPath)
if err != nil {
logger.Err(err).Msg("Failed to check if the object is readable")
writeErrorResponse(logger, w, http.StatusInternalServerError)
return
}
if isReadable {
logger.Debug().
Str("path", objPath).
Msg("Requested path is public readable")
} else {
if server.isAuthenticated(r) {
// The username obtained from the email address does not include the iRODS
// zone. We use the local zone to which the Sqyrrl server is connected as
// the user's zone.
userName := iRODSUsernameFromEmail(logger, server.getSessionUserEmail(r))
userZone := server.sqyrrlConfig.IRODSZoneForOIDC

logger.Debug().Str("user", userName).Msg("User is authenticated")

isReadable, err = IsReadableByUser(logger, rodsFs, localZone, userName, userZone, objPath)
if err != nil {
logger.Err(err).Msg("Failed to check if the object is readable")
writeErrorResponse(logger, w, http.StatusInternalServerError)
return
}

if !isReadable {
logger.Info().
Str("path", objPath).
Str("user", userName).
Str("zone", userZone).
Msg("Requested path is not readable by this user")
writeErrorResponse(logger, w, http.StatusForbidden)
return
}
} else if server.sqyrrlConfig.EnableOIDC {
logger.Debug().Msg("User is not authenticated")

if !isReadable {
logger.Info().
Str("path", objPath).
Str("user", userName).
Str("zone", userZone).
Msg("Requested path is not readable by this user")
writeErrorResponse(logger, w, http.StatusForbidden)
Msg("Requested path is not public readable - redirecting to login")
RedirectToIdentityServer(w, r, server, r.URL.Path)
return
}
} else {
logger.Debug().Msg("User is not authenticated")
isReadable, err = IsPublicReadable(logger, rodsFs, objPath)
if err != nil {
logger.Err(err).Msg("Failed to check if the object is public readable")
writeErrorResponse(logger, w, http.StatusInternalServerError)
return
}

if !isReadable {
} else {
logger.Info().
Str("path", objPath).
Msg("Requested path is not public readable")
Msg("Requested path is not public readable - and no OIDC enabled")
writeErrorResponse(logger, w, http.StatusForbidden)
return
}
Expand Down
5 changes: 3 additions & 2 deletions server/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ var _ = Describe("iRODS Get Handler", func() {
Expect(err).NotTo(HaveOccurred())

r, err = http.NewRequest("GET", getURL, nil)
Expect(err).NotTo(HaveOccurred())
})

When("the user is not in the public group", func() {
Expand Down Expand Up @@ -248,11 +249,11 @@ var _ = Describe("iRODS Get Handler", func() {
Expect(err).NotTo(HaveOccurred())
})

It("should return Forbidden", func(ctx SpecContext) {
It("should return Ok", func(ctx SpecContext) {
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, r)

Expect(rec.Code).To(Equal(http.StatusForbidden))
Expect(rec.Code).To(Equal(http.StatusOK))
}, SpecTimeout(specTimeout))
})
})
Expand Down
8 changes: 5 additions & 3 deletions server/server_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ package server_test
import (
"errors"
"fmt"
"github.com/alexedwards/scs/v2"
"github.com/cyverse/go-irodsclient/config"
"github.com/cyverse/go-irodsclient/irods/connection"
"math/rand"
"os"
"path/filepath"
"slices"
"testing"
"time"

"github.com/alexedwards/scs/v2"
"github.com/cyverse/go-irodsclient/config"
"github.com/cyverse/go-irodsclient/irods/connection"

"github.com/cyverse/go-irodsclient/fs"
ifs "github.com/cyverse/go-irodsclient/irods/fs"
"github.com/cyverse/go-irodsclient/irods/types"
Expand Down Expand Up @@ -195,6 +196,7 @@ var _ = BeforeSuite(func(ctx SpecContext) {
Expect(err).NotTo(HaveOccurred())
if !inGroup {
err = ifs.AddGroupMember(suiteConn, group, userInOthers, testZone)
Expect(err).NotTo(HaveOccurred())
}
}

Expand Down
4 changes: 3 additions & 1 deletion server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package server_test

import (
"crypto/tls"
"github.com/alexedwards/scs/v2"
"net"
"net/http"
"net/url"
"os"
"sync"

"github.com/alexedwards/scs/v2"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

Expand Down Expand Up @@ -112,6 +113,7 @@ var _ = Describe("Server startup and shutdown", func() {
It("returns an error", func() {
config.IRODSEnvFilePath = "nonexistent.json"
err := server.Configure(suiteLogger, &config)
Expect(err).NotTo(HaveOccurred())

_, err = server.NewSqyrrlServer(suiteLogger, &config, scs.New())
Expect(err).To(MatchError("stat nonexistent.json: no such file or directory"))
Expand Down

0 comments on commit 6b93ecd

Please sign in to comment.