Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add managed identity support to azure container instance hook #35319

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion airflow/providers/microsoft/azure/hooks/base_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ def get_conn(self) -> Any:
)
else:
self.log.info("Using DefaultAzureCredential as credential")
credentials = AzureIdentityCredentialAdapter()
managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
credentials = AzureIdentityCredentialAdapter(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)

return self.sdk_client(
credentials=credentials,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
from airflow.providers.microsoft.azure.utils import get_default_azure_credential

if TYPE_CHECKING:
from azure.mgmt.containerinstance.models import (
Expand Down Expand Up @@ -92,7 +93,9 @@ def get_conn(self) -> Any:
)
else:
self.log.info("Using DefaultAzureCredential as credential")
credential = DefaultAzureCredential()
managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
credential = get_default_azure_credential(managed_identity_client_id, workload_identity_tenant_id)

subscription_id = cast(str, conn.extra_dejson.get("subscriptionId"))
return ContainerInstanceManagementClient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ There are five ways to connect to Azure using Airflow.
2. Use a `JSON file`_
3. Use a `JSON dictionary`_
i.e. add a key config directly into the Airflow connection.
4. Use managed identity through providing ``managed_identity_client_id`` and ``workload_identity_tenant_id``.
5. Fallback on `DefaultAzureCredential`_.
This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI and etc.
``subscriptionId`` is required in this authentication mechanism.
4. Use managed identity by setting ``managed_identity_client_id``, ``workload_identity_tenant_id`` (under the hook, it uses DefaultAzureCredential_ with these arguments)
5. Fallback on `DefaultAzureCredential`_
This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI and etc. ``subscriptionId`` is required in this authentication mechanism.

Only one authorization method can be used at a time. If you need to manage multiple credentials or keys then you should
configure multiple connections.
Expand Down Expand Up @@ -73,8 +72,8 @@ Extra (optional)
It specifies the path to the json file that contains the authentication information.
* ``key_json``: If set, it uses the *JSON dictionary* authentication mechanism.
It specifies the json that contains the authentication information.
* ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with `workload_identity_tenant_id`, they'll pass to ``DefaultAzureCredential``.
* ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with `managed_identity_client_id`, they'll pass to ``DefaultAzureCredential``.
* ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_.
* ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_.

The entire extra column can be left out to fall back on DefaultAzureCredential_.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_connection_failure(self, mock_container_groups_list):
class TestAzureContainerInstanceHookWithoutSetupCredential:
@patch("airflow.providers.microsoft.azure.hooks.container_instance.ContainerInstanceManagementClient")
@patch("azure.common.credentials.ServicePrincipalCredentials")
@patch("airflow.providers.microsoft.azure.hooks.container_instance.DefaultAzureCredential")
@patch("airflow.providers.microsoft.azure.hooks.container_instance.get_default_azure_credential")
def test_get_conn_fallback_to_default_azure_credential(
self,
mock_default_azure_credential,
Expand Down