Skip to content

Commit

Permalink
Add fetch_userinfo to upstream SSO provider
Browse files Browse the repository at this point in the history
  • Loading branch information
MatMaul committed Nov 25, 2024
1 parent 84776a4 commit 35ecaf6
Show file tree
Hide file tree
Showing 27 changed files with 389 additions and 137 deletions.
2 changes: 2 additions & 0 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,12 @@ pub async fn config_sync(
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
userinfo_endpoint_override: provider.userinfo_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
pkce_mode,
fetch_userinfo: provider.fetch_userinfo,
response_mode,
additional_authorization_parameters: provider
.additional_authorization_parameters
Expand Down
12 changes: 12 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,24 @@ pub struct Provider {
#[serde(default, skip_serializing_if = "PkceMethod::is_default")]
pub pkce_method: PkceMethod,

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint.
///
/// Defaults to `false`.
pub fetch_userinfo: bool,

/// The URL to use for the provider's authorization endpoint
///
/// Defaults to the `authorization_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_endpoint: Option<Url>,

/// The URL to use for the provider's userinfo endpoint
///
/// Defaults to the `userinfo_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_endpoint: Option<Url>,

/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
Expand Down
2 changes: 2 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ pub struct UpstreamOAuthProvider {
pub authorization_endpoint_override: Option<Url>,
pub scope: Scope,
pub token_endpoint_override: Option<Url>,
pub userinfo_endpoint_override: Option<Url>,
pub fetch_userinfo: bool,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
Expand Down
25 changes: 22 additions & 3 deletions crates/data-model/src/upstream_oauth2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ pub enum UpstreamOAuthAuthorizationSessionState {
link_id: Ulid,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
},
Consumed {
completed_at: DateTime<Utc>,
consumed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
},
}

Expand All @@ -45,13 +47,15 @@ impl UpstreamOAuthAuthorizationSessionState {
link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Completed {
completed_at,
link_id: link.id,
id_token,
extra_callback_parameters,
userinfo,
}),
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
}
Expand All @@ -72,12 +76,14 @@ impl UpstreamOAuthAuthorizationSessionState {
link_id,
id_token,
extra_callback_parameters,
userinfo,
} => Ok(Self::Consumed {
completed_at,
link_id,
consumed_at,
id_token,
extra_callback_parameters,
userinfo,
}),
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
}
Expand Down Expand Up @@ -151,6 +157,14 @@ impl UpstreamOAuthAuthorizationSessionState {
}
}

#[must_use]
pub fn userinfo(&self) -> Option<&serde_json::Value> {
match self {
Self::Pending => None,
Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
}
}

/// Get the time at which the upstream OAuth 2.0 authorization session was
/// consumed.
///
Expand Down Expand Up @@ -229,10 +243,15 @@ impl UpstreamOAuthAuthorizationSession {
link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<Self, InvalidTransitionError> {
self.state =
self.state
.complete(completed_at, link, id_token, extra_callback_parameters)?;
self.state = self.state.complete(
completed_at,
link,
id_token,
extra_callback_parameters,
userinfo,
)?;
Ok(self)
}

Expand Down
16 changes: 15 additions & 1 deletion crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> {
Ok(self.load().await?.token_endpoint())
}

/// Get the userinfo endpoint for the provider.
///
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set,
/// otherwise uses the one from discovery.
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
return Ok(userinfo_endpoint);
}

Ok(self.load().await?.userinfo_endpoint())
}

/// Get the PKCE methods supported by the provider.
///
/// If the mode is set to auto, it will use the ones from discovery,
Expand Down Expand Up @@ -387,10 +399,12 @@ mod tests {
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
fetch_userinfo: false,
jwks_uri_override: None,
authorization_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
userinfo_endpoint_override: None,
token_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
client_id: "client_id".to_owned(),
encrypted_client_secret: None,
token_endpoint_signing_alg: None,
Expand Down
36 changes: 31 additions & 5 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use mas_storage::{
use mas_templates::{FormPostContext, Templates};
use oauth2_types::errors::ClientErrorCode;
use serde::{Deserialize, Serialize};
use serde_json::json;
use thiserror::Error;
use ulid::Ulid;

Expand Down Expand Up @@ -117,14 +118,15 @@ pub(crate) enum RouteError {
},

#[error(transparent)]
Internal(Box<dyn std::error::Error>),
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}

impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
impl_from_error_for_route!(super::ProviderCredentialsError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);

Expand Down Expand Up @@ -274,33 +276,56 @@ pub(crate) async fn handler(
redirect_uri,
};

let id_token_verification_data = JwtVerificationData {
let verification_data = JwtVerificationData {
issuer: &provider.issuer,
jwks: &jwks,
// TODO: make that configurable
signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256,
client_id: &provider.client_id,
};

let (response, id_token) =
let (response, id_token_map) =
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
&client,
client_credentials,
lazy_metadata.token_endpoint().await?,
code,
validation_data,
Some(id_token_verification_data),
Some(verification_data),
clock.now(),
&mut rng,
)
.await?;

let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
let (_header, id_token) = id_token_map
.clone()
.ok_or(RouteError::MissingIDToken)?
.into_parts();

let mut context = AttributeMappingContext::new().with_id_token_claims(id_token);
if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
context = context.with_extra_callback_parameters(extra_callback_parameters);
}

let userinfo = if provider.fetch_userinfo {
Some(json!(
mas_oidc_client::requests::userinfo::fetch_userinfo(
&client,
lazy_metadata.userinfo_endpoint().await?,
response.access_token.as_str(),
Some(verification_data),
&id_token_map.ok_or(RouteError::MissingIDToken)?,
)
.await?
))
} else {
None
};

if let Some(userinfo) = userinfo.clone() {
context = context.with_userinfo_claims(userinfo);
}

let context = context.build();

let env = environment();
Expand Down Expand Up @@ -341,6 +366,7 @@ pub(crate) async fn handler(
&link,
response.id_token,
extra_callback_parameters,
userinfo,
)
.await?;

Expand Down
9 changes: 9 additions & 0 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,9 @@ pub(crate) async fn get(
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
}
if let Some(userinfo) = upstream_session.userinfo() {
context = context.with_userinfo_claims(userinfo.clone());
}
let context = context.build();

let ctx = if provider.claims_imports.displayname.ignore() {
Expand Down Expand Up @@ -582,6 +585,9 @@ pub(crate) async fn post(
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
}
if let Some(userinfo) = upstream_session.userinfo() {
context = context.with_userinfo_claims(userinfo.clone());
}
let context = context.build();

// Is the email verified according to the upstream provider?
Expand Down Expand Up @@ -921,6 +927,8 @@ mod tests {
claims_imports,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: false,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down Expand Up @@ -958,6 +966,7 @@ mod tests {
&link,
Some(id_token.into_string()),
None,
None,
)
.await
.unwrap();
Expand Down
49 changes: 38 additions & 11 deletions crates/handlers/src/upstream_oauth2/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use minijinja::{
pub(crate) struct AttributeMappingContext {
id_token_claims: Option<HashMap<String, serde_json::Value>>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo_claims: Option<serde_json::Value>,
}

impl AttributeMappingContext {
Expand All @@ -46,6 +47,11 @@ impl AttributeMappingContext {
self
}

pub fn with_userinfo_claims(mut self, userinfo_claims: serde_json::Value) -> Self {
self.userinfo_claims = Some(userinfo_claims);
self
}

pub fn build(self) -> Value {
Value::from_object(self)
}
Expand All @@ -54,7 +60,25 @@ impl AttributeMappingContext {
impl Object for AttributeMappingContext {
fn get_value(self: &Arc<Self>, name: &Value) -> Option<Value> {
match name.as_str()? {
"user" | "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
"user" => {
if self.id_token_claims.is_none() && self.userinfo_claims.is_none() {
return None;
}
let mut merged_user: HashMap<String, serde_json::Value> = HashMap::new();
if let serde_json::Value::Object(userinfo) = self
.userinfo_claims
.clone()
.unwrap_or(serde_json::Value::Null)
{
merged_user.extend(userinfo);
}
if let Some(id_token) = self.id_token_claims.clone() {
merged_user.extend(id_token);
}
Some(Value::from_serialize(merged_user))
}
"id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
"userinfo_claims" => self.userinfo_claims.as_ref().map(Value::from_serialize),
"extra_callback_parameters" => self
.extra_callback_parameters
.as_ref()
Expand All @@ -64,17 +88,20 @@ impl Object for AttributeMappingContext {
}

fn enumerate(self: &Arc<Self>) -> Enumerator {
match (
self.id_token_claims.is_some(),
self.extra_callback_parameters.is_some(),
) {
(true, true) => {
Enumerator::Str(&["user", "id_token_claims", "extra_callback_parameters"])
}
(true, false) => Enumerator::Str(&["user", "id_token_claims"]),
(false, true) => Enumerator::Str(&["extra_callback_parameters"]),
(false, false) => Enumerator::Str(&["user"]),
let mut attrs = Vec::new();
if self.id_token_claims.is_some() || self.userinfo_claims.is_none() {
attrs.push(minijinja::Value::from("user"));
}
if self.id_token_claims.is_some() {
attrs.push(minijinja::Value::from("id_token_claims"));
}
if self.userinfo_claims.is_some() {
attrs.push(minijinja::Value::from("userinfo_claims"));
}
if self.extra_callback_parameters.is_some() {
attrs.push(minijinja::Value::from("extra_callback_parameters"));
}
Enumerator::Values(attrs)
}
}

Expand Down
4 changes: 4 additions & 0 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,11 +403,13 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
fetch_userinfo: false,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down Expand Up @@ -439,11 +441,13 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
fetch_userinfo: false,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down
Loading

0 comments on commit 35ecaf6

Please sign in to comment.