diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py b/airflow/providers/microsoft/azure/hooks/container_volume.py index 3aa58415a48db..791c08303bc83 100644 --- a/airflow/providers/microsoft/azure/hooks/container_volume.py +++ b/airflow/providers/microsoft/azure/hooks/container_volume.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -19,12 +18,11 @@ from typing import Any -from azure.identity import DefaultAzureCredential from azure.mgmt.containerinstance.models import AzureFileVolume, Volume from azure.mgmt.storage import StorageManagementClient from airflow.hooks.base import BaseHook -from airflow.providers.microsoft.azure.utils import get_field +from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field class AzureContainerVolumeHook(BaseHook): @@ -72,6 +70,12 @@ def get_connection_form_widgets() -> dict[str, Any]: lazy_gettext("Resource group name (optional)"), widget=BS3TextFieldWidget(), ), + "managed_identity_client_id": StringField( + lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget() + ), + "workload_identity_tenant_id": StringField( + lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget() + ), } @staticmethod @@ -89,6 +93,8 @@ def get_ui_field_behaviour() -> dict[str, Any]: "connection_string": "connection string auth", "subscription_id": "Subscription id (required for Azure AD authentication)", "resource_group": "Resource group name (required for Azure AD authentication)", + "managed_identity_client_id": "Managed Identity Client ID", + "workload_identity_tenant_id": "Workload Identity Tenant ID", }, } @@ -106,7 +112,11 @@ def get_storagekey(self, *, storage_account_name: str | None = None) -> str: subscription_id = self._get_field(extras, "subscription_id") resource_group = self._get_field(extras, "resource_group") if subscription_id and storage_account_name and resource_group: - credentials = DefaultAzureCredential() + managed_identity_client_id = self._get_field(extras, "managed_identity_client_id") + workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id") + credentials = get_default_azure_credential( + managed_identity_client_id, workload_identity_tenant_id + ) storage_client = StorageManagementClient(credentials, subscription_id) storage_account_list_keys_result = storage_client.storage_accounts.list_keys( resource_group, storage_account_name diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst b/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst index 15f6d176c7059..408e0e6a51baa 100644 --- a/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst +++ b/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst @@ -27,13 +27,14 @@ The Microsoft Azure Container Volume connection type enables the Azure Container Authenticating to Azure Container Volume ---------------------------------------- -There are three ways to connect to Azure Container Volume using Airflow. +There are four ways to connect to Azure Container Volume using Airflow. 1. Use `token credentials`_ i.e. add specific credentials (client_id, secret) and subscription id to the Airflow connection. 2. Use a `Connection String`_ i.e. add connection string to ``connection_string`` in the Airflow connection. -3. Fallback on DefaultAzureCredential_. +3. Use managed identity by setting ``managed_identity_client_id``, ``workload_identity_tenant_id`` (under the hook, it uses DefaultAzureCredential_ with these arguments) +4. Fallback on DefaultAzureCredential_. This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI, etc. ``subscription_id`` and ``resource_group`` are required in this authentication mechanism. @@ -66,6 +67,8 @@ Extra (optional) * ``connection_string``: Connection string for use with connection string authentication. It can be left out to fall back on DefaultAzureCredential_. * ``subscription_id``: The ID of the subscription used for the initial connection. This is needed for Azure Active Directory (DefaultAzureCredential_) authentication. * ``resource_group``: Azure Resource Group Name under which the desired Azure file volume resides. This is needed for Azure Active Directory (DefaultAzureCredential_) authentication. + * ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_. + * ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_. When specifying the connection in environment variable you should specify it using URI syntax. @@ -82,3 +85,7 @@ For example connect with token credentials: .. _token credentials: https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-token-credentials .. _Connection String: https://docs.microsoft.com/en-us/azure/data-explorer/kusto/api/connection-strings/storage .. _DefaultAzureCredential: https://docs.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#defaultazurecredential + +.. spelling:word-list:: + + Entra diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py index b5fdf2b08c6a1..5eec510f4e22c 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py @@ -86,7 +86,7 @@ def test_get_file_volume_connection_string(self, mocked_connection): indirect=True, ) @mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.StorageManagementClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.DefaultAzureCredential") + @mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.get_default_azure_credential") def test_get_file_volume_default_azure_credential( self, mocked_default_azure_credential, mocked_client, mocked_connection ): @@ -112,4 +112,4 @@ def test_get_file_volume_default_azure_credential( assert volume.azure_file.storage_account_name == "storage" assert volume.azure_file.read_only is True - mocked_default_azure_credential.assert_called_with() + mocked_default_azure_credential.assert_called_with(None, None)