Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow fetching user claims through the userinfo_endpoint for upstream OAuth 2.0 providers #3363

Merged
merged 5 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,12 @@ pub async fn config_sync(
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
userinfo_endpoint_override: provider.userinfo_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
pkce_mode,
fetch_userinfo: provider.fetch_userinfo,
response_mode,
additional_authorization_parameters: provider
.additional_authorization_parameters
Expand Down
14 changes: 14 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,26 @@ pub struct Provider {
#[serde(default, skip_serializing_if = "PkceMethod::is_default")]
pub pkce_method: PkceMethod,

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the `id_token` from the
/// `token_endpoint`.
///
/// Defaults to `false`.
#[serde(default)]
pub fetch_userinfo: bool,

/// The URL to use for the provider's authorization endpoint
///
/// Defaults to the `authorization_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_endpoint: Option<Url>,

/// The URL to use for the provider's userinfo endpoint
///
/// Defaults to the `userinfo_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_endpoint: Option<Url>,

/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
Expand Down
2 changes: 2 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ pub struct UpstreamOAuthProvider {
pub authorization_endpoint_override: Option<Url>,
pub scope: Scope,
pub token_endpoint_override: Option<Url>,
pub userinfo_endpoint_override: Option<Url>,
pub fetch_userinfo: bool,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
Expand Down
25 changes: 22 additions & 3 deletions crates/data-model/src/upstream_oauth2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ pub enum UpstreamOAuthAuthorizationSessionState {
link_id: Ulid,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
},
Consumed {
completed_at: DateTime<Utc>,
consumed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
},
}

Expand All @@ -45,13 +47,15 @@ impl UpstreamOAuthAuthorizationSessionState {
link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Completed {
completed_at,
link_id: link.id,
id_token,
extra_callback_parameters,
userinfo,
}),
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
}
Expand All @@ -72,12 +76,14 @@ impl UpstreamOAuthAuthorizationSessionState {
link_id,
id_token,
extra_callback_parameters,
userinfo,
} => Ok(Self::Consumed {
completed_at,
link_id,
consumed_at,
id_token,
extra_callback_parameters,
userinfo,
}),
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
}
Expand Down Expand Up @@ -151,6 +157,14 @@ impl UpstreamOAuthAuthorizationSessionState {
}
}

#[must_use]
pub fn userinfo(&self) -> Option<&serde_json::Value> {
match self {
Self::Pending => None,
Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
}
}

/// Get the time at which the upstream OAuth 2.0 authorization session was
/// consumed.
///
Expand Down Expand Up @@ -229,10 +243,15 @@ impl UpstreamOAuthAuthorizationSession {
link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<Self, InvalidTransitionError> {
self.state =
self.state
.complete(completed_at, link, id_token, extra_callback_parameters)?;
self.state = self.state.complete(
completed_at,
link,
id_token,
extra_callback_parameters,
userinfo,
)?;
Ok(self)
}

Expand Down
14 changes: 14 additions & 0 deletions crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> {
Ok(self.load().await?.token_endpoint())
}

/// Get the userinfo endpoint for the provider.
///
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set,
/// otherwise uses the one from discovery.
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
return Ok(userinfo_endpoint);
}

Ok(self.load().await?.userinfo_endpoint())
}

/// Get the PKCE methods supported by the provider.
///
/// If the mode is set to auto, it will use the ones from discovery,
Expand Down Expand Up @@ -387,9 +399,11 @@ mod tests {
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
fetch_userinfo: false,
jwks_uri_override: None,
authorization_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
userinfo_endpoint_override: None,
token_endpoint_override: None,
client_id: "client_id".to_owned(),
encrypted_client_secret: None,
Expand Down
36 changes: 31 additions & 5 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use mas_storage::{
use mas_templates::{FormPostContext, Templates};
use oauth2_types::errors::ClientErrorCode;
use serde::{Deserialize, Serialize};
use serde_json::json;
use thiserror::Error;
use ulid::Ulid;

Expand Down Expand Up @@ -117,14 +118,15 @@ pub(crate) enum RouteError {
},

#[error(transparent)]
Internal(Box<dyn std::error::Error>),
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}

impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
impl_from_error_for_route!(super::ProviderCredentialsError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);

Expand Down Expand Up @@ -274,33 +276,56 @@ pub(crate) async fn handler(
redirect_uri,
};

let id_token_verification_data = JwtVerificationData {
let verification_data = JwtVerificationData {
issuer: &provider.issuer,
jwks: &jwks,
// TODO: make that configurable
signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256,
client_id: &provider.client_id,
};

let (response, id_token) =
let (response, id_token_map) =
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
&client,
client_credentials,
lazy_metadata.token_endpoint().await?,
code,
validation_data,
Some(id_token_verification_data),
Some(verification_data),
clock.now(),
&mut rng,
)
.await?;

let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
let (_header, id_token) = id_token_map
.clone()
.ok_or(RouteError::MissingIDToken)?
.into_parts();

let mut context = AttributeMappingContext::new().with_id_token_claims(id_token);
if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
context = context.with_extra_callback_parameters(extra_callback_parameters);
}

let userinfo = if provider.fetch_userinfo {
Some(json!(
mas_oidc_client::requests::userinfo::fetch_userinfo(
&client,
lazy_metadata.userinfo_endpoint().await?,
response.access_token.as_str(),
Some(verification_data),
&id_token_map.ok_or(RouteError::MissingIDToken)?,
)
.await?
))
} else {
None
};

if let Some(userinfo) = userinfo.clone() {
context = context.with_userinfo_claims(userinfo);
}

let context = context.build();

let env = environment();
Expand Down Expand Up @@ -341,6 +366,7 @@ pub(crate) async fn handler(
&link,
response.id_token,
extra_callback_parameters,
userinfo,
)
.await?;

Expand Down
9 changes: 9 additions & 0 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,9 @@ pub(crate) async fn get(
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
}
if let Some(userinfo) = upstream_session.userinfo() {
context = context.with_userinfo_claims(userinfo.clone());
}
let context = context.build();

let ctx = if provider.claims_imports.displayname.ignore() {
Expand Down Expand Up @@ -582,6 +585,9 @@ pub(crate) async fn post(
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
}
if let Some(userinfo) = upstream_session.userinfo() {
context = context.with_userinfo_claims(userinfo.clone());
}
let context = context.build();

// Is the email verified according to the upstream provider?
Expand Down Expand Up @@ -921,6 +927,8 @@ mod tests {
claims_imports,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: false,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down Expand Up @@ -958,6 +966,7 @@ mod tests {
&link,
Some(id_token.into_string()),
None,
None,
)
.await
.unwrap();
Expand Down
49 changes: 38 additions & 11 deletions crates/handlers/src/upstream_oauth2/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use minijinja::{
pub(crate) struct AttributeMappingContext {
id_token_claims: Option<HashMap<String, serde_json::Value>>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo_claims: Option<serde_json::Value>,
}

impl AttributeMappingContext {
Expand All @@ -46,6 +47,11 @@ impl AttributeMappingContext {
self
}

pub fn with_userinfo_claims(mut self, userinfo_claims: serde_json::Value) -> Self {
self.userinfo_claims = Some(userinfo_claims);
self
}

pub fn build(self) -> Value {
Value::from_object(self)
}
Expand All @@ -54,7 +60,25 @@ impl AttributeMappingContext {
impl Object for AttributeMappingContext {
fn get_value(self: &Arc<Self>, name: &Value) -> Option<Value> {
match name.as_str()? {
"user" | "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
"user" => {
if self.id_token_claims.is_none() && self.userinfo_claims.is_none() {
return None;
}
let mut merged_user: HashMap<String, serde_json::Value> = HashMap::new();
if let serde_json::Value::Object(userinfo) = self
.userinfo_claims
.clone()
.unwrap_or(serde_json::Value::Null)
{
merged_user.extend(userinfo);
}
if let Some(id_token) = self.id_token_claims.clone() {
merged_user.extend(id_token);
}
Some(Value::from_serialize(merged_user))
}
"id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
"userinfo_claims" => self.userinfo_claims.as_ref().map(Value::from_serialize),
"extra_callback_parameters" => self
.extra_callback_parameters
.as_ref()
Expand All @@ -64,17 +88,20 @@ impl Object for AttributeMappingContext {
}

fn enumerate(self: &Arc<Self>) -> Enumerator {
match (
self.id_token_claims.is_some(),
self.extra_callback_parameters.is_some(),
) {
(true, true) => {
Enumerator::Str(&["user", "id_token_claims", "extra_callback_parameters"])
}
(true, false) => Enumerator::Str(&["user", "id_token_claims"]),
(false, true) => Enumerator::Str(&["extra_callback_parameters"]),
(false, false) => Enumerator::Str(&["user"]),
let mut attrs = Vec::new();
if self.id_token_claims.is_some() || self.userinfo_claims.is_none() {
attrs.push(minijinja::Value::from("user"));
}
if self.id_token_claims.is_some() {
attrs.push(minijinja::Value::from("id_token_claims"));
}
if self.userinfo_claims.is_some() {
attrs.push(minijinja::Value::from("userinfo_claims"));
}
if self.extra_callback_parameters.is_some() {
attrs.push(minijinja::Value::from("extra_callback_parameters"));
}
Enumerator::Values(attrs)
}
}

Expand Down
4 changes: 4 additions & 0 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,11 +403,13 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
fetch_userinfo: false,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down Expand Up @@ -439,11 +441,13 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
fetch_userinfo: false,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down
Loading
Loading