Skip to content

Commit

Permalink
Add user_profile_method to upstream SSO provider
Browse files Browse the repository at this point in the history
  • Loading branch information
MatMaul committed Nov 14, 2024
1 parent 0327883 commit 63e680d
Show file tree
Hide file tree
Showing 30 changed files with 521 additions and 153 deletions.
11 changes: 11 additions & 0 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ pub async fn config_sync(
}
};

let user_profile_method = match provider.user_profile_method {
mas_config::UpstreamOAuth2UserProfileMethod::Auto => {
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto
}
mas_config::UpstreamOAuth2UserProfileMethod::UserinfoEndpoint => {
mas_data_model::UpstreamOAuthProviderUserProfileMethod::UserinfoEndpoint
}
};

repo.upstream_oauth_provider()
.upsert(
clock,
Expand All @@ -241,13 +250,15 @@ pub async fn config_sync(
brand_name: provider.brand_name,
scope: provider.scope.parse()?,
token_endpoint_auth_method: provider.token_endpoint_auth_method.into(),
user_profile_method,
token_endpoint_signing_alg: provider
.token_endpoint_auth_signing_alg
.clone(),
client_id: provider.client_id,
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
userinfo_endpoint_override: provider.userinfo_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
Expand Down
1 change: 1 addition & 0 deletions crates/config/src/sections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub use self::{
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
ImportAction as UpstreamOAuth2ImportAction, PkceMethod as UpstreamOAuth2PkceMethod,
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
UserProfileMethod as UpstreamOAuth2UserProfileMethod,
},
};
use crate::util::ConfigurationSection;
Expand Down
34 changes: 34 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,26 @@ impl From<TokenAuthMethod> for OAuthClientAuthenticationMethod {
}
}

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum UserProfileMethod {
/// Use the userinfo endpoint if `openid` is not included in `scopes`
#[default]
Auto,

/// Always use the userinfo endpoint
UserinfoEndpoint,
}

impl UserProfileMethod {
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default(&self) -> bool {
matches!(self, UserProfileMethod::Auto)
}
}

/// How to handle a claim
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
Expand Down Expand Up @@ -401,6 +421,14 @@ pub struct Provider {
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint.
///
/// Defaults to `auto`, which uses the userinfo endpoint if `openid` is not
/// included in `scopes`, and the ID token otherwise.
#[serde(default, skip_serializing_if = "UserProfileMethod::is_default")]
pub user_profile_method: UserProfileMethod,

/// The scopes to request from the provider
pub scope: String,

Expand All @@ -424,6 +452,12 @@ pub struct Provider {
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_endpoint: Option<Url>,

/// The URL to use for the provider's userinfo endpoint
///
/// Defaults to the `userinfo_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_endpoint: Option<Url>,

/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
Expand Down
1 change: 1 addition & 0 deletions crates/data-model/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub use self::{
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference,
UpstreamOAuthProviderUserProfileMethod,
},
user_agent::{DeviceType, UserAgent},
users::{
Expand Down
1 change: 1 addition & 0 deletions crates/data-model/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub use self::{
PkceMode as UpstreamOAuthProviderPkceMode,
SetEmailVerification as UpsreamOAuthProviderSetEmailVerification,
SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider,
UserProfileMethod as UpstreamOAuthProviderUserProfileMethod,
},
session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState},
};
47 changes: 47 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,51 @@ impl std::fmt::Display for PkceMode {
}
}

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum UserProfileMethod {
/// Use the userinfo endpoint if `openid` is not included in `scopes`
#[default]
Auto,

/// Always use the userinfo endpoint
UserinfoEndpoint,
}

#[derive(Debug, Clone, Error)]
#[error("Invalid user profile method {0:?}")]
pub struct InvalidUserProfileMethodError(String);

impl std::str::FromStr for UserProfileMethod {
type Err = InvalidUserProfileMethodError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"auto" => Ok(Self::Auto),
"userinfo_endpoint" => Ok(Self::UserinfoEndpoint),
s => Err(InvalidUserProfileMethodError(s.to_owned())),
}
}
}

impl UserProfileMethod {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::UserinfoEndpoint => "userinfo_endpoint",
}
}
}

impl std::fmt::Display for UserProfileMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
Expand All @@ -127,11 +172,13 @@ pub struct UpstreamOAuthProvider {
pub jwks_uri_override: Option<Url>,
pub authorization_endpoint_override: Option<Url>,
pub token_endpoint_override: Option<Url>,
pub userinfo_endpoint_override: Option<Url>,
pub scope: Scope,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
pub user_profile_method: UserProfileMethod,
pub created_at: DateTime<Utc>,
pub disabled_at: Option<DateTime<Utc>>,
pub claims_imports: ClaimsImports,
Expand Down
21 changes: 20 additions & 1 deletion crates/data-model/src/upstream_oauth2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ pub enum UpstreamOAuthAuthorizationSessionState {
completed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
userinfo: Option<String>,
},
Consumed {
completed_at: DateTime<Utc>,
consumed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
userinfo: Option<String>,
},
}

Expand All @@ -42,12 +44,14 @@ impl UpstreamOAuthAuthorizationSessionState {
completed_at: DateTime<Utc>,
link: &UpstreamOAuthLink,
id_token: Option<String>,
userinfo: Option<String>,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Completed {
completed_at,
link_id: link.id,
id_token,
userinfo,
}),
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
}
Expand All @@ -67,11 +71,13 @@ impl UpstreamOAuthAuthorizationSessionState {
completed_at,
link_id,
id_token,
userinfo,
} => Ok(Self::Consumed {
completed_at,
link_id,
consumed_at,
id_token,
userinfo,
}),
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
}
Expand Down Expand Up @@ -124,6 +130,16 @@ impl UpstreamOAuthAuthorizationSessionState {
}
}

#[must_use]
pub fn userinfo(&self) -> Option<&str> {
match self {
Self::Pending => None,
Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => {
userinfo.as_deref()
}
}
}

/// Get the time at which the upstream OAuth 2.0 authorization session was
/// consumed.
///
Expand Down Expand Up @@ -201,8 +217,11 @@ impl UpstreamOAuthAuthorizationSession {
completed_at: DateTime<Utc>,
link: &UpstreamOAuthLink,
id_token: Option<String>,
userinfo: Option<String>,
) -> Result<Self, InvalidTransitionError> {
self.state = self.state.complete(completed_at, link, id_token)?;
self.state = self
.state
.complete(completed_at, link, id_token, userinfo)?;
Ok(self)
}

Expand Down
18 changes: 17 additions & 1 deletion crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> {
Ok(self.load().await?.token_endpoint())
}

/// Get the userinfo endpoint for the provider.
///
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set,
/// otherwise uses the one from discovery.
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
return Ok(userinfo_endpoint);
}

Ok(self.load().await?.userinfo_endpoint())
}

/// Get the PKCE methods supported by the provider.
///
/// If the mode is set to auto, it will use the ones from discovery,
Expand Down Expand Up @@ -274,7 +286,9 @@ mod tests {
// XXX: sadly, we can't test HTTPS requests with wiremock, so we can only test
// 'insecure' discovery

use mas_data_model::UpstreamOAuthProviderClaimsImports;
use mas_data_model::{
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderUserProfileMethod,
};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_storage::{clock::MockClock, Clock};
use oauth2_types::scope::{Scope, OPENID};
Expand Down Expand Up @@ -386,8 +400,10 @@ mod tests {
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
user_profile_method: UpstreamOAuthProviderUserProfileMethod::Auto,
jwks_uri_override: None,
authorization_endpoint_override: None,
userinfo_endpoint_override: None,
token_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
client_id: "client_id".to_owned(),
Expand Down
Loading

0 comments on commit 63e680d

Please sign in to comment.