Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth flow account creation/login #460

Merged
merged 7 commits into from
Feb 11, 2024
32 changes: 28 additions & 4 deletions backend/prisma-cli/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ model User {
degree_name String?
degree_starting_year Int?
created_at DateTime @default(now())
updated_at DateTime
updated_at DateTime @default(now())
role UserRole @default(User)
applications Application[]
OrganisationAdmins OrganisationAdmins[]
Ratings Ratings[]
Ratings Rating[]

@@map("users")
}

enum UserRole {
Expand All @@ -31,6 +33,8 @@ model Organisation {
updated_at DateTime @default(now())
campaigns Campaign[]
OrganisationAdmins OrganisationAdmins[]

@@map("organisations")
}

model OrganisationAdmins {
Expand All @@ -40,6 +44,8 @@ model OrganisationAdmins {
user_id BigInt

@@id([organisation_id, user_id])

@@map("organisation_admins")
}

model Campaign {
Expand All @@ -56,6 +62,8 @@ model Campaign {
organisation_id BigInt
roles CampaignRole[]
questions Question[]

@@map("campaigns")
}

model CampaignRole {
Expand All @@ -70,6 +78,8 @@ model CampaignRole {
created_at DateTime @default(now())
updated_at DateTime @default(now())
Application Application[]

@@map("campaign_roles")
}

model Question {
Expand All @@ -84,6 +94,8 @@ model Question {
campaignId BigInt?
MultiOptionQuestion MultiOptionQuestion[]
Answer Answer[]

@@map("questions")
}

enum QuestionType {
Expand All @@ -98,6 +110,8 @@ model MultiOptionQuestion {
text String
question Question @relation(fields: [question_id], references: [id])
question_id BigInt

@@map("multi_option_questions")
}

model Application {
Expand All @@ -111,7 +125,9 @@ model Application {
answers Answer[]
created_at DateTime @default(now())
updated_at DateTime @default(now())
ratings Ratings[]
ratings Rating[]

@@map("applications")
}

enum ApplicationStatus {
Expand All @@ -128,23 +144,29 @@ model Answer {
question_id BigInt
shortAnswerAnswers ShortAnswerAnswer[]
multiOptionAnswers MultiOptionAnswer[]

@@map("answers")
}

model ShortAnswerAnswer {
id BigInt @id @default(autoincrement())
text String
Answer Answer @relation(fields: [answer_id], references: [id])
answer_id BigInt

@@map("short_answer_answers")
}

model MultiOptionAnswer {
id BigInt @id @default(autoincrement())
option Int
Answer Answer @relation(fields: [answer_id], references: [id])
answer_id BigInt

@@map("multi_option_answers")
}

model Ratings {
model Rating {
id BigInt @id @default(autoincrement())
Application Application @relation(fields: [application_id], references: [id])
application_id BigInt
Expand All @@ -153,4 +175,6 @@ model Ratings {
rating Int
created_at DateTime @default(now())
updated_at DateTime @default(now())

@@map("ratings")
}
1 change: 1 addition & 0 deletions backend/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ chrono = { version = "0.4.26", features = ["serde"] }
oauth2 = "4.4.1"
log = "0.4.20"
uuid = { version = "1.5.0", features = ["serde", "v4"] }
rs-snowflake = "0.6.0"
jsonwebtoken = "9.1.0"
6 changes: 3 additions & 3 deletions backend/server/src/handler/auth.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::models::app::AppState;
use crate::models::auth::{AuthRequest, UserProfile};
use crate::models::auth::{AuthRequest, GoogleUserProfile};
use crate::service::auth::create_or_get_user_id;
use axum::extract::{Query, State};
use axum::http::StatusCode;
Expand Down Expand Up @@ -41,9 +41,9 @@ pub async fn google_callback(
Err(e) => return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
};

// let profile = profile.json::<UserProfile>().await?;
let profile = profile.json::<GoogleUserProfile>().await?;

// let user_id = create_or_get_user_id(profile.email, state.db).await?;
let user_id = create_or_get_user_id(profile.email, profile.name, state.db).await?;

// TODO: Create a JWT from this user_id and return to the user.
Ok("woohoo")
Expand Down
5 changes: 5 additions & 0 deletions backend/server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::env;
use axum::{routing::get, Router};
use jsonwebtoken::{DecodingKey, EncodingKey};
use snowflake::SnowflakeIdGenerator;
use sqlx::postgres::PgPoolOptions;
use models::app::AppState;
mod handler;
Expand Down Expand Up @@ -28,12 +29,16 @@ async fn main() {
// Initialise reqwest client
let ctx = reqwest::Client::new();

// Initialise Snowflake Generator
let snowflake_generator = SnowflakeIdGenerator::new(1, 1);

// Add all data to AppState
let state = AppState {
db: pool,
ctx,
encoding_key,
decoding_key,
snowflake_generator,
};

let app = Router::new()
Expand Down
2 changes: 2 additions & 0 deletions backend/server/src/models/app.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use jsonwebtoken::{DecodingKey, EncodingKey};
use reqwest::Client as ReqwestClient;
use snowflake::SnowflakeIdGenerator;
use sqlx::{Pool, Postgres};

#[derive(Clone)]
Expand All @@ -8,4 +9,5 @@ pub struct AppState {
pub ctx: ReqwestClient,
pub decoding_key: DecodingKey,
pub encoding_key: EncodingKey,
pub snowflake_generator: SnowflakeIdGenerator,
}
3 changes: 2 additions & 1 deletion backend/server/src/models/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ pub struct AuthRequest {
}

#[derive(Deserialize, Serialize)]
pub struct UserProfile {
pub struct GoogleUserProfile {
pub name: String,
pub email: String,
}

Expand Down
2 changes: 1 addition & 1 deletion backend/server/src/models/user.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, sqlx::Type, Clone)]
#[sqlx(type_name = "user_role", rename_all = "PascalCase")]
#[sqlx(type_name = "UserRole", rename_all = "PascalCase")]
pub enum UserRole {
User,
SuperUser,
Expand Down
28 changes: 24 additions & 4 deletions backend/server/src/service/auth.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
use anyhow::Result;
use jsonwebtoken::{DecodingKey, EncodingKey};
use snowflake::SnowflakeIdGenerator;
use sqlx::{Pool, Postgres};
use crate::models::user::UserRole;

/// Checks if a user exists in DB based on given email address. If so, their user_id is returned.
/// Otherwise, a new user is created in the DB, and the new id is returned.
/// This function is used in OAuth flows to login/signup users when they click the
/// "Sign in with ___" buttons. The returned user_id will be used to generate a JWT to be
/// used as a token for the user's browser.
pub async fn create_or_get_user_id(email: String, pool: Pool<Postgres>) -> Result<i64> {
// TODO: See if user (by email) exists in the database and return their id. If not, insert them, and return the new id.
pub async fn create_or_get_user_id(email: String, name: String, pool: Pool<Postgres>, mut snowflake_generator: SnowflakeIdGenerator) -> Result<i64> {
let possible_user_id = sqlx::query!("SELECT id FROM users WHERE email = $1", email)
.fetch_optional(&pool)
.await?;

if let Some(result) = possible_user_id {
return Ok(result.id);
}

let user_id = snowflake_generator.real_time_generate();

let response = sqlx::query!(
"INSERT INTO users (id, email, name) VALUES ($1, $2, $3)",
user_id, email, name
)
.execute(&pool)
.await?;

let user_id = 1;
return Ok(user_id);
}

pub async fn is_super_user(user_id: i64, pool: &Pool<Postgres>) -> Result<bool> {
let is_super_user = sqlx::query!("SELECT EXISTS(SELECT 1 FROM users WHERE id = $1 AND role = $2)", user_id, UserRole::SuperUser)
let is_super_user = sqlx::query!(
"SELECT EXISTS(SELECT 1 FROM users WHERE id = $1 AND role = $2)",
user_id,
UserRole::SuperUser as UserRole
)
.fetch_one(pool)
.await?;

Expand Down
Loading