Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add managed identity support to azure container volume hook #35321

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions airflow/providers/microsoft/azure/hooks/container_volume.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
Expand All @@ -19,12 +18,11 @@

from typing import Any

from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
from azure.mgmt.storage import StorageManagementClient

from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field
from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field


class AzureContainerVolumeHook(BaseHook):
Expand Down Expand Up @@ -72,6 +70,12 @@ def get_connection_form_widgets() -> dict[str, Any]:
lazy_gettext("Resource group name (optional)"),
widget=BS3TextFieldWidget(),
),
"managed_identity_client_id": StringField(
lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
),
"workload_identity_tenant_id": StringField(
lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
),
}

@staticmethod
Expand All @@ -89,6 +93,8 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"connection_string": "connection string auth",
"subscription_id": "Subscription id (required for Azure AD authentication)",
"resource_group": "Resource group name (required for Azure AD authentication)",
"managed_identity_client_id": "Managed Identity Client ID",
"workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}

Expand All @@ -106,7 +112,11 @@ def get_storagekey(self, *, storage_account_name: str | None = None) -> str:
subscription_id = self._get_field(extras, "subscription_id")
resource_group = self._get_field(extras, "resource_group")
if subscription_id and storage_account_name and resource_group:
credentials = DefaultAzureCredential()
managed_identity_client_id = self._get_field(extras, "managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id")
credentials = get_default_azure_credential(
managed_identity_client_id, workload_identity_tenant_id
)
storage_client = StorageManagementClient(credentials, subscription_id)
storage_account_list_keys_result = storage_client.storage_accounts.list_keys(
resource_group, storage_account_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ The Microsoft Azure Container Volume connection type enables the Azure Container
Authenticating to Azure Container Volume
----------------------------------------

There are three ways to connect to Azure Container Volume using Airflow.
There are four ways to connect to Azure Container Volume using Airflow.

1. Use `token credentials`_
i.e. add specific credentials (client_id, secret) and subscription id to the Airflow connection.
2. Use a `Connection String`_
i.e. add connection string to ``connection_string`` in the Airflow connection.
3. Fallback on DefaultAzureCredential_.
3. Use managed identity by setting ``managed_identity_client_id``, ``workload_identity_tenant_id`` (under the hook, it uses DefaultAzureCredential_ with these arguments)
4. Fallback on DefaultAzureCredential_.
This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI, etc.
``subscription_id`` and ``resource_group`` are required in this authentication mechanism.

Expand Down Expand Up @@ -66,6 +67,8 @@ Extra (optional)
* ``connection_string``: Connection string for use with connection string authentication. It can be left out to fall back on DefaultAzureCredential_.
* ``subscription_id``: The ID of the subscription used for the initial connection. This is needed for Azure Active Directory (DefaultAzureCredential_) authentication.
* ``resource_group``: Azure Resource Group Name under which the desired Azure file volume resides. This is needed for Azure Active Directory (DefaultAzureCredential_) authentication.
* ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_.
* ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_.

When specifying the connection in environment variable you should specify
it using URI syntax.
Expand All @@ -82,3 +85,7 @@ For example connect with token credentials:
.. _token credentials: https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-token-credentials
.. _Connection String: https://docs.microsoft.com/en-us/azure/data-explorer/kusto/api/connection-strings/storage
.. _DefaultAzureCredential: https://docs.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#defaultazurecredential

.. spelling:word-list::

Entra
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_get_file_volume_connection_string(self, mocked_connection):
indirect=True,
)
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.StorageManagementClient")
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.DefaultAzureCredential")
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.get_default_azure_credential")
def test_get_file_volume_default_azure_credential(
self, mocked_default_azure_credential, mocked_client, mocked_connection
):
Expand All @@ -112,4 +112,4 @@ def test_get_file_volume_default_azure_credential(
assert volume.azure_file.storage_account_name == "storage"
assert volume.azure_file.read_only is True

mocked_default_azure_credential.assert_called_with()
mocked_default_azure_credential.assert_called_with(None, None)