Skip to content

Commit

Permalink
adding option to specify client_id for MSI (#748)
Browse files Browse the repository at this point in the history
Co-authored-by: Antonio Jimenez <[email protected]>
  • Loading branch information
aj9411 and Antonio Jimenez authored May 10, 2022
1 parent 1a5430b commit dfb2430
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
6 changes: 4 additions & 2 deletions sdk/identity/src/token_credentials/default_credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl DefaultAzureCredentialBuilder {
}
if self.include_managed_identity_credential {
sources.push(DefaultAzureCredentialEnum::ManagedIdentity(
ImdsManagedIdentityCredential {},
ImdsManagedIdentityCredential::default(),
))
}
if self.include_cli_credential {
Expand Down Expand Up @@ -142,7 +142,9 @@ impl Default for DefaultAzureCredential {
DefaultAzureCredential {
sources: vec![
DefaultAzureCredentialEnum::Environment(EnvironmentCredential::default()),
DefaultAzureCredentialEnum::ManagedIdentity(ImdsManagedIdentityCredential {}),
DefaultAzureCredentialEnum::ManagedIdentity(
ImdsManagedIdentityCredential::default(),
),
DefaultAzureCredentialEnum::AzureCli(AzureCliCredential {}),
],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,53 @@ const MSI_API_VERSION: &str = "2019-08-01";
/// This authentication type works in Azure VMs, App Service and Azure Functions applications, as well as the Azure Cloud Shell
///
/// Built up from docs at [https://docs.microsoft.com/azure/app-service/overview-managed-identity#using-the-rest-protocol](https://docs.microsoft.com/azure/app-service/overview-managed-identity#using-the-rest-protocol)
pub struct ImdsManagedIdentityCredential;
#[derive(Default)]
pub struct ImdsManagedIdentityCredential {
object_id: Option<String>,
client_id: Option<String>,
msi_res_id: Option<String>,
}

impl ImdsManagedIdentityCredential {
/// Specifies the object id associated with a user assigned managed service identity resource that should be used to retrieve the access token.
///
/// The values of client_id and msi_res_id are discarded, as only one id parameter may be set when getting a token.
pub fn with_object_id<A>(mut self, object_id: A) -> Self
where
A: Into<String>,
{
self.object_id = Some(object_id.into());
self.client_id = None;
self.msi_res_id = None;
self
}

/// Specifies the application id (client id) associated with a user assigned managed service identity resource that should be used to retrieve the access token.
///
/// The values of object_id and msi_res_id are discarded, as only one id parameter may be set when getting a token.
pub fn with_client_id<A>(mut self, client_id: A) -> Self
where
A: Into<String>,
{
self.client_id = Some(client_id.into());
self.object_id = None;
self.msi_res_id = None;
self
}

/// Specifies the ARM resource id of the user assigned managed service identity resource that should be used to retrieve the access token.
///
/// The values of object_id and client_id are discarded, as only one id parameter may be set when getting a token.
pub fn with_identity<A>(mut self, msi_res_id: A) -> Self
where
A: Into<String>,
{
self.msi_res_id = Some(msi_res_id.into());
self.object_id = None;
self.client_id = None;
self
}
}

#[allow(missing_docs)]
#[non_exhaustive]
Expand Down Expand Up @@ -49,7 +95,18 @@ impl TokenCredential for ImdsManagedIdentityCredential {
let msi_endpoint = std::env::var(MSI_ENDPOINT_ENV_KEY)
.unwrap_or_else(|_| "http://169.254.169.254/metadata/identity/oauth2/token".to_owned());

let query_items = vec![("api-version", MSI_API_VERSION), ("resource", resource)];
let mut query_items = vec![("api-version", MSI_API_VERSION), ("resource", resource)];

match (
self.object_id.as_ref(),
self.client_id.as_ref(),
self.msi_res_id.as_ref(),
) {
(Some(object_id), None, None) => query_items.push(("object_id", object_id)),
(None, Some(client_id), None) => query_items.push(("client_id", client_id)),
(None, None, Some(msi_res_id)) => query_items.push(("msi_res_id", msi_res_id)),
_ => (),
}

let msi_endpoint_url = Url::parse_with_params(&msi_endpoint, &query_items)
.map_err(ManagedIdentityCredentialError::MsiEndpointParseUrlError)?;
Expand Down

0 comments on commit dfb2430

Please sign in to comment.