Skip to content

Commit

Permalink
Revert reorganisations, so the PR is easier to review. Will add those…
Browse files Browse the repository at this point in the history
… back in a dedicated PR.
  • Loading branch information
ro-tex committed Oct 5, 2021
1 parent 2320c88 commit 1dbb085
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 167 deletions.
136 changes: 104 additions & 32 deletions api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,18 @@ func (api *API) userLimitsGET(w http.ResponseWriter, req *http.Request, _ httpro

// userStatsGET returns statistics about an existing user.
func (api *API) userStatsGET(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
u, ok := api.userFromContext(w, req, false)
if !ok {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
u, err := api.staticDB.UserBySub(req.Context(), sub, false)
if errors.Contains(err, database.ErrUserNotFound) {
api.WriteError(w, err, http.StatusNotFound)
return
}
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
us, err := api.staticDB.UserStats(req.Context(), *u)
Expand Down Expand Up @@ -269,6 +279,12 @@ func (api *API) userPOST(w http.ResponseWriter, req *http.Request, _ httprouter.
// userPUT allows changing some user information.
// This method receives its parameters as a JSON object.
func (api *API) userPUT(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}

// Read and parse the request body.
bodyBytes, err := ioutil.ReadAll(req.Body)
if err != nil {
Expand All @@ -291,8 +307,13 @@ func (api *API) userPUT(w http.ResponseWriter, req *http.Request, _ httprouter.P
}

// Fetch the user from the DB.
u, ok := api.userFromContext(w, req, false)
if !ok {
u, err := api.staticDB.UserBySub(req.Context(), sub, false)
if errors.Contains(err, database.ErrUserNotFound) {
api.WriteError(w, err, http.StatusNotFound)
return
}
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}

Expand All @@ -309,7 +330,7 @@ func (api *API) userPUT(w http.ResponseWriter, req *http.Request, _ httprouter.P
api.WriteError(w, err, http.StatusInternalServerError)
return
}
if err == nil && su.Sub != u.Sub {
if err == nil && su.Sub != sub {
err = errors.New("this stripe customer id belongs to another user")
api.WriteError(w, err, http.StatusBadRequest)
return
Expand All @@ -334,7 +355,7 @@ func (api *API) userPUT(w http.ResponseWriter, req *http.Request, _ httprouter.P
api.WriteError(w, err, http.StatusInternalServerError)
return
}
if err == nil && eu.Sub != u.Sub {
if err == nil && eu.Sub != sub {
err = errors.New("this email is already in use")
api.WriteError(w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -368,17 +389,23 @@ func (api *API) userPUT(w http.ResponseWriter, req *http.Request, _ httprouter.P

// userUploadsGET returns all uploads made by the current user.
func (api *API) userUploadsGET(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
u, ok := api.userFromContext(w, req, false)
if !ok {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
if err := req.ParseForm(); err != nil {
u, err := api.staticDB.UserBySub(req.Context(), sub, true)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
if err = req.ParseForm(); err != nil {
api.WriteError(w, err, http.StatusBadRequest)
return
}
offset, err1 := fetchOffset(req.Form)
pageSize, err2 := fetchPageSize(req.Form)
if err := errors.Compose(err1, err2); err != nil {
if err = errors.Compose(err1, err2); err != nil {
api.WriteError(w, err, http.StatusBadRequest)
return
}
Expand All @@ -398,17 +425,23 @@ func (api *API) userUploadsGET(w http.ResponseWriter, req *http.Request, _ httpr

// userDownloadsGET returns all downloads made by the current user.
func (api *API) userDownloadsGET(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
u, ok := api.userFromContext(w, req, false)
if !ok {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
if err := req.ParseForm(); err != nil {
u, err := api.staticDB.UserBySub(req.Context(), sub, true)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
if err = req.ParseForm(); err != nil {
api.WriteError(w, err, http.StatusBadRequest)
return
}
offset, err1 := fetchOffset(req.Form)
pageSize, err2 := fetchPageSize(req.Form)
if err := errors.Compose(err1, err2); err != nil {
if err = errors.Compose(err1, err2); err != nil {
api.WriteError(w, err, http.StatusBadRequest)
return
}
Expand Down Expand Up @@ -456,11 +489,16 @@ func (api *API) userConfirmGET(w http.ResponseWriter, req *http.Request, _ httpr
// email, in case the previous one didn't arrive for some reason.
// The user needs to be logged in.
func (api *API) userReconfirmPOST(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
u, ok := api.userFromContext(w, req, false)
if !ok {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
u, err := api.staticDB.UserBySub(req.Context(), sub, true)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
var err error
u.EmailConfirmationTokenExpiration = time.Now().UTC().Add(database.EmailConfirmationTokenTTL).Truncate(time.Millisecond)
u.EmailConfirmationToken, err = lib.GenerateUUID()
if err != nil {
Expand Down Expand Up @@ -588,6 +626,11 @@ func (api *API) userRecoverPOST(w http.ResponseWriter, req *http.Request, _ http

// trackUploadPOST registers a new upload in the system.
func (api *API) trackUploadPOST(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
sl := ps.ByName("skylink")
if sl == "" {
api.WriteError(w, errors.New("missing parameter 'skylink'"), http.StatusBadRequest)
Expand All @@ -602,8 +645,9 @@ func (api *API) trackUploadPOST(w http.ResponseWriter, req *http.Request, ps htt
api.WriteError(w, err, http.StatusInternalServerError)
return
}
u, ok := api.userFromContext(w, req, false)
if !ok {
u, err := api.staticDB.UserBySub(req.Context(), sub, true)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
_, err = api.staticDB.UploadCreate(req.Context(), *u, *skylink)
Expand All @@ -630,7 +674,12 @@ func (api *API) trackUploadPOST(w http.ResponseWriter, req *http.Request, ps htt

// trackDownloadPOST registers a new download in the system.
func (api *API) trackDownloadPOST(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
err := req.ParseForm()
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
err = req.ParseForm()
if err != nil {
api.WriteError(w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -665,8 +714,9 @@ func (api *API) trackDownloadPOST(w http.ResponseWriter, req *http.Request, ps h
api.WriteError(w, err, http.StatusInternalServerError)
return
}
u, ok := api.userFromContext(w, req, false)
if !ok {
u, err := api.staticDB.UserBySub(req.Context(), sub, true)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
err = api.staticDB.DownloadCreate(req.Context(), *u, *skylink, downloadedBytes)
Expand All @@ -690,11 +740,17 @@ func (api *API) trackDownloadPOST(w http.ResponseWriter, req *http.Request, ps h

// trackRegistryReadPOST registers a new registry read in the system.
func (api *API) trackRegistryReadPOST(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
u, ok := api.userFromContext(w, req, false)
if !ok {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
u, err := api.staticDB.UserBySub(req.Context(), sub, true)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
_, err := api.staticDB.RegistryReadCreate(req.Context(), *u)
_, err = api.staticDB.RegistryReadCreate(req.Context(), *u)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
Expand All @@ -704,11 +760,17 @@ func (api *API) trackRegistryReadPOST(w http.ResponseWriter, req *http.Request,

// trackRegistryWritePOST registers a new registry write in the system.
func (api *API) trackRegistryWritePOST(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
u, ok := api.userFromContext(w, req, false)
if !ok {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
u, err := api.staticDB.UserBySub(req.Context(), sub, true)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
_, err := api.staticDB.RegistryWriteCreate(req.Context(), *u)
_, err = api.staticDB.RegistryWriteCreate(req.Context(), *u)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
Expand All @@ -718,8 +780,18 @@ func (api *API) trackRegistryWritePOST(w http.ResponseWriter, req *http.Request,

// userUploadsDELETE unpins all uploads of a skylink uploaded by the user.
func (api *API) userUploadsDELETE(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
u, ok := api.userFromContext(w, req, false)
if !ok {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
u, err := api.staticDB.UserBySub(req.Context(), sub, false)
if errors.Contains(err, database.ErrUserNotFound) {
api.WriteError(w, err, http.StatusNotFound)
return
}
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
sl := ps.ByName("skylink")
Expand Down Expand Up @@ -772,13 +844,13 @@ func (api *API) checkUserQuotas(ctx context.Context, u *database.User) {
// context and then fetches the user from the database. If that operation is
// successful it returns the user and `true`, otherwise it writes the respective
// error to the ResponseWriter and returns `nil, false`.
func (api *API) userFromContext(w http.ResponseWriter, req *http.Request, create bool) (*database.User, bool) {
func (api *API) userFromContext(w http.ResponseWriter, req *http.Request) (*database.User, bool) {
sub, _, _, err := jwt.TokenFromContext(req.Context())
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return nil, false
}
u, err := api.staticDB.UserBySub(req.Context(), sub, create)
u, err := api.staticDB.UserBySub(req.Context(), sub, false)
if errors.Contains(err, database.ErrUserNotFound) {
api.WriteError(w, err, http.StatusNotFound)
return nil, false
Expand Down
1 change: 0 additions & 1 deletion api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ func tokenFromRequest(r *http.Request) (string, error) {

// userFromRequest returns a user object based on the JWT within the request.
// Note that this method does not rely on a token being stored in the context.
// TODO Do we need this?
func (api *API) userFromRequest(r *http.Request) *database.User {
t, err := tokenFromRequest(r)
if err != nil {
Expand Down
Loading

0 comments on commit 1dbb085

Please sign in to comment.