Skip to content

Commit

Permalink
add client session ID tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
UpcraftLP committed May 14, 2024
1 parent 8439519 commit 436fce8
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 32 deletions.
18 changes: 10 additions & 8 deletions api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net/http"

"github.com/pagefaultgames/rogueserver/api/account"
"github.com/pagefaultgames/rogueserver/api/daily"
"github.com/pagefaultgames/rogueserver/db"
"log"
"net/http"
)

func Init(mux *http.ServeMux) error {
Expand All @@ -49,14 +48,17 @@ func Init(mux *http.ServeMux) error {
mux.HandleFunc("GET /game/classicsessioncount", handleGameClassicSessionCount)

// savedata
mux.HandleFunc("GET /savedata/get", handleGetSaveData)
mux.HandleFunc("POST /savedata/update", handleSaveData)
mux.HandleFunc("GET /savedata/delete", handleSaveData) // TODO use deleteSystemSave
mux.HandleFunc("POST /savedata/clear", handleSaveData) // TODO use clearSessionData
mux.HandleFunc("GET /savedata/newclear", handleNewClear)
mux.HandleFunc("GET /savedata/get", legacyHandleGetSaveData)
mux.HandleFunc("POST /savedata/update", legacyHandleSaveData)
mux.HandleFunc("GET /savedata/delete", legacyHandleSaveData) // TODO use deleteSystemSave
mux.HandleFunc("POST /savedata/clear", legacyHandleSaveData) // TODO use clearSessionData
mux.HandleFunc("GET /savedata/newclear", legacyHandleNewClear)

// new session
mux.HandleFunc("POST /savedata/updateall", handleUpdateAll)
mux.HandleFunc("POST /savedata/verify", handleSessionVerify)
mux.HandleFunc("GET /savedata/system", handleGetSystemData)
mux.HandleFunc("GET /savedata/session", handleGetSessionData)

// daily
mux.HandleFunc("GET /daily/seed", handleDailySeed)
Expand Down
185 changes: 172 additions & 13 deletions api/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,53 @@ func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(strconv.Itoa(classicSessionCount)))
}

func handleGetSaveData(w http.ResponseWriter, r *http.Request) {
func handleGetSessionData(w http.ResponseWriter, r *http.Request) {
token, uuid, err := tokenAndUuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}

var slot int
if r.URL.Query().Has("slot") {
slot, err = strconv.Atoi(r.URL.Query().Get("slot"))
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
}

var clientSessionId string
if r.URL.Query().Has("clientSessionId") {
clientSessionId = r.URL.Query().Get("clientSessionId")
} else {
httpError(w, r, fmt.Errorf("missing clientSessionId"), http.StatusBadRequest)
}

err = db.UpdateActiveSession(token, clientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}

var save any
save, err = savedata.Get(uuid, 1, slot)
if errors.Is(err, sql.ErrNoRows) {
http.Error(w, err.Error(), http.StatusNotFound)
return
}

if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}

jsonResponse(w, r, save)
}

const legacyClientSessionId = "LEGACY_CLIENT"

func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) {
token, uuid, err := tokenAndUuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
Expand All @@ -173,7 +219,7 @@ func handleGetSaveData(w http.ResponseWriter, r *http.Request) {

var save any
if datatype == 0 {
err = db.UpdateActiveSession(uuid, token)
err = db.UpdateActiveSession(token, legacyClientSessionId) // we dont have a client id
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
Expand Down Expand Up @@ -222,7 +268,7 @@ func clearSessionData(w http.ResponseWriter, r *http.Request) {
save = session

var active bool
active, err = db.IsActiveSession(token)
active, err = db.IsActiveSession(token, legacyClientSessionId) //TODO unfinished, read token from query
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
Expand Down Expand Up @@ -309,7 +355,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) {
}

var active bool
active, err = db.IsActiveSession(token)
active, err = db.IsActiveSession(token, legacyClientSessionId) //TODO unfinished, read token from query
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -363,7 +409,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

func handleSaveData(w http.ResponseWriter, r *http.Request) {
func legacyHandleSaveData(w http.ResponseWriter, r *http.Request) {
token, uuid, err := tokenAndUuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
Expand All @@ -388,6 +434,14 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) {
}
}

var clientSessionId string
if r.URL.Query().Has("clientSessionId") {
clientSessionId = r.URL.Query().Get("clientSessionId")
}
if clientSessionId == "" {
clientSessionId = legacyClientSessionId
}

var save any
// /savedata/get and /savedata/delete specify datatype, but don't expect data in body
if r.URL.Path != "/savedata/get" && r.URL.Path != "/savedata/delete" {
Expand Down Expand Up @@ -416,14 +470,14 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) {
var active bool
if r.URL.Path == "/savedata/get" {
if datatype == 0 {
err = db.UpdateActiveSession(uuid, token)
err = db.UpdateActiveSession(token, clientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}
}
} else {
active, err = db.IsActiveSession(token)
active, err = db.IsActiveSession(token, clientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
Expand Down Expand Up @@ -517,9 +571,10 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) {
}

type CombinedSaveData struct {
System defs.SystemSaveData `json:"system"`
Session defs.SessionSaveData `json:"session"`
SessionSlotId int `json:"sessionSlotId"`
System defs.SystemSaveData `json:"system"`
Session defs.SessionSaveData `json:"session"`
SessionSlotId int `json:"sessionSlotId"`
ClientSessionId string `json:"clientSessionId"`
}

// TODO wrap this in a transaction
Expand All @@ -531,6 +586,14 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return
}

var clientSessionId string
if r.URL.Query().Has("clientSessionId") {
clientSessionId = r.URL.Query().Get("clientSessionId")
}
if clientSessionId == "" {
clientSessionId = legacyClientSessionId
}

var data CombinedSaveData
err = json.NewDecoder(r.Body).Decode(&data)
if err != nil {
Expand All @@ -539,7 +602,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
}

var active bool
active, err = db.IsActiveSession(token)
active, err = db.IsActiveSession(token, clientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
Expand Down Expand Up @@ -584,7 +647,104 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

func handleNewClear(w http.ResponseWriter, r *http.Request) {
type SessionVerifyResponse struct {
Valid bool `json:"valid"`
SessionData *defs.SessionSaveData `json:"sessionData"`
}

type SessionVerifyRequest struct {
ClientSessionId string `json:"clientSessionId"`
Slot int `json:"slot"`
}

func handleSessionVerify(w http.ResponseWriter, r *http.Request) {
var token []byte
token, err := tokenFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}

var input SessionVerifyRequest
err = json.NewDecoder(r.Body).Decode(&input)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}

var active bool
active, err = db.IsActiveSession(token, input.ClientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
}

response := SessionVerifyResponse{
Valid: active,
}

// not valid, send server state
if !active {
err = db.UpdateActiveSession(token, input.ClientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}

var uuid []byte
uuid, err = db.FetchUUIDFromToken(token)
if err != nil {
httpError(w, r, fmt.Errorf("failed to fetch UUID from token: %s", err), http.StatusInternalServerError)
}
var storedSaveData defs.SessionSaveData
storedSaveData, err = db.ReadSessionSaveData(uuid, input.Slot)
if err != nil {
httpError(w, r, fmt.Errorf("failed to read session save data: %s", err), http.StatusInternalServerError)
return
}

response.SessionData = &storedSaveData
}

jsonResponse(w, r, response)
}

func handleGetSystemData(w http.ResponseWriter, r *http.Request) {
token, uuid, err := tokenAndUuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}

var clientSessionId string
if r.URL.Query().Has("clientSessionId") {
clientSessionId = r.URL.Query().Get("clientSessionId")
} else {
httpError(w, r, fmt.Errorf("missing clientSessionId"), http.StatusBadRequest)
}

err = db.UpdateActiveSession(token, clientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}

var save any //TODO this is always system save data
save, err = savedata.Get(uuid, 0, 0)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
http.Error(w, err.Error(), http.StatusNotFound)
} else {
httpError(w, r, err, http.StatusInternalServerError)
}

return
}

jsonResponse(w, r, save)
}

func legacyHandleNewClear(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
Expand All @@ -610,7 +770,6 @@ func handleNewClear(w http.ResponseWriter, r *http.Request) {
}

// daily

func handleDailySeed(w http.ResponseWriter, r *http.Request) {
seed, err := db.GetDailyRunSeed()
if err != nil {
Expand Down
24 changes: 13 additions & 11 deletions db/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ func AddAccountSession(username string, token []byte) error {
return err
}

_, err = handle.Exec("UPDATE sessions s JOIN accounts a ON a.uuid = s.uuid SET s.active = 1 WHERE a.username = ? AND a.lastLoggedIn IS NULL", username)
if err != nil {
return err
}

_, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username)
if err != nil {
return err
Expand Down Expand Up @@ -213,18 +208,25 @@ func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
return nil
}

func IsActiveSession(token []byte) (bool, error) {
var active int
err := handle.QueryRow("SELECT `active` FROM sessions WHERE token = ?", token).Scan(&active)
func IsActiveSession(token []byte, clientSessionId string) (bool, error) {
var storedId string
err := handle.QueryRow("SELECT `clientSessionId` FROM sessions WHERE token = ?", token).Scan(&storedId)
if err != nil {
return false, err
}
if storedId == "" {
err = UpdateActiveSession(token, clientSessionId)
if err != nil {
return false, err
}
return true, nil
}

return active == 1, nil
return storedId == clientSessionId, nil
}

func UpdateActiveSession(uuid []byte, token []byte) error {
_, err := handle.Exec("UPDATE sessions SET `active` = CASE WHEN token = ? THEN 1 ELSE 0 END WHERE uuid = ?", token, uuid)
func UpdateActiveSession(token []byte, clientSessionId string) error {
_, err := handle.Exec("UPDATE sessions SET clientSessionId = ? WHERE token = ?", clientSessionId, token)
if err != nil {
return err
}
Expand Down
8 changes: 8 additions & 0 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ func Init(username, password, protocol, address, database string) error {

func setupDb(tx *sql.Tx) error {
queries := []string{
// MIGRATION 000

`CREATE TABLE IF NOT EXISTS accounts (uuid BINARY(16) NOT NULL PRIMARY KEY, username VARCHAR(16) UNIQUE NOT NULL, hash BINARY(32) NOT NULL, salt BINARY(16) NOT NULL, registered TIMESTAMP NOT NULL, lastLoggedIn TIMESTAMP DEFAULT NULL, lastActivity TIMESTAMP DEFAULT NULL, banned TINYINT(1) NOT NULL DEFAULT 0, trainerId SMALLINT(5) UNSIGNED DEFAULT 0, secretId SMALLINT(5) UNSIGNED DEFAULT 0)`,
`CREATE INDEX IF NOT EXISTS accountsByActivity ON accounts (lastActivity)`,

Expand All @@ -168,6 +170,12 @@ func setupDb(tx *sql.Tx) error {
`CREATE TABLE IF NOT EXISTS systemSaveData (uuid BINARY(16) PRIMARY KEY, data LONGBLOB, timestamp TIMESTAMP)`,

`CREATE TABLE IF NOT EXISTS sessionSaveData (uuid BINARY(16), slot TINYINT, data LONGBLOB, timestamp TIMESTAMP, PRIMARY KEY (uuid, slot))`,

// ----------------------------------
// MIGRATION 001

`ALTER TABLE sessions DROP COLUMN IF EXISTS active`,
`ALTER TABLE sessions ADD COLUMN IF NOT EXISTS clientSessionId VARCHAR(32)`,
}

for _, q := range queries {
Expand Down

0 comments on commit 436fce8

Please sign in to comment.