Skip to content

Commit

Permalink
AWS auth manager: implement all is_authorized_* methods (but `is_au…
Browse files Browse the repository at this point in the history
…thorized_dag`) (#35928)
  • Loading branch information
vincbeck authored Nov 29, 2023
1 parent 3a084b7 commit 985d058
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 8 deletions.
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/auth_manager/avp/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,16 @@ class AvpEntities(Enum):

ACTION = "Action"
ROLE = "Role"
VARIABLE = "Variable"
USER = "User"

# Resource types
CONFIGURATION = "Configuration"
CONNECTION = "Connection"
DATASET = "Dataset"
POOL = "Pool"
VARIABLE = "Variable"
VIEW = "View"


def get_entity_type(resource_type: AvpEntities) -> str:
"""
Expand Down
39 changes: 34 additions & 5 deletions airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,13 @@ def is_authorized_configuration(
details: ConfigurationDetails | None = None,
user: BaseUser | None = None,
) -> bool:
return self.is_logged_in()
config_section = details.section if details else None
return self.avp_facade.is_authorized(
method=method,
entity_type=AvpEntities.CONFIGURATION,
user=user or self.get_user(),
entity_id=config_section,
)

def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool:
return self.is_logged_in()
Expand All @@ -103,7 +109,13 @@ def is_authorized_connection(
details: ConnectionDetails | None = None,
user: BaseUser | None = None,
) -> bool:
return self.is_logged_in()
connection_id = details.conn_id if details else None
return self.avp_facade.is_authorized(
method=method,
entity_type=AvpEntities.CONNECTION,
user=user or self.get_user(),
entity_id=connection_id,
)

def is_authorized_dag(
self,
Expand All @@ -118,12 +130,24 @@ def is_authorized_dag(
def is_authorized_dataset(
self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None
) -> bool:
return self.is_logged_in()
dataset_uri = details.uri if details else None
return self.avp_facade.is_authorized(
method=method,
entity_type=AvpEntities.DATASET,
user=user or self.get_user(),
entity_id=dataset_uri,
)

def is_authorized_pool(
self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None
) -> bool:
return self.is_logged_in()
pool_name = details.name if details else None
return self.avp_facade.is_authorized(
method=method,
entity_type=AvpEntities.POOL,
user=user or self.get_user(),
entity_id=pool_name,
)

def is_authorized_variable(
self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None
Expand All @@ -142,7 +166,12 @@ def is_authorized_view(
access_view: AccessView,
user: BaseUser | None = None,
) -> bool:
return self.is_logged_in()
return self.avp_facade.is_authorized(
method="GET",
entity_type=AvpEntities.VIEW,
user=user or self.get_user(),
entity_id=access_view.value,
)

def get_url_login(self, **kwargs) -> str:
return url_for("AwsAuthManagerAuthenticationViews.login")
Expand Down
135 changes: 133 additions & 2 deletions tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
import pytest
from flask import Flask, session

from airflow.auth.managers.models.resource_details import VariableDetails
from airflow.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
ConnectionDetails,
DatasetDetails,
PoolDetails,
VariableDetails,
)
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager
from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import (
Expand Down Expand Up @@ -110,6 +117,108 @@ def test_is_logged_in_return_false_when_no_user_in_session(self, auth_manager, a

assert result is False

@pytest.mark.parametrize(
"details, user, expected_user, expected_entity_id",
[
(None, None, ANY, None),
(ConfigurationDetails(section="test"), mock, mock, "test"),
],
)
@patch.object(AwsAuthManager, "avp_facade")
@patch.object(AwsAuthManager, "get_user")
def test_is_authorized_configuration(
self, mock_get_user, mock_avp_facade, details, user, expected_user, expected_entity_id, auth_manager
):
is_authorized = Mock()
mock_avp_facade.is_authorized = is_authorized

method: ResourceMethod = "GET"
auth_manager.is_authorized_configuration(method=method, details=details, user=user)

if not user:
mock_get_user.assert_called_once()
is_authorized.assert_called_once_with(
method=method,
entity_type=AvpEntities.CONFIGURATION,
user=expected_user,
entity_id=expected_entity_id,
)

@pytest.mark.parametrize(
"details, user, expected_user, expected_entity_id",
[
(None, None, ANY, None),
(ConnectionDetails(conn_id="conn_id"), mock, mock, "conn_id"),
],
)
@patch.object(AwsAuthManager, "avp_facade")
@patch.object(AwsAuthManager, "get_user")
def test_is_authorized_connection(
self, mock_get_user, mock_avp_facade, details, user, expected_user, expected_entity_id, auth_manager
):
is_authorized = Mock()
mock_avp_facade.is_authorized = is_authorized

method: ResourceMethod = "GET"
auth_manager.is_authorized_connection(method=method, details=details, user=user)

if not user:
mock_get_user.assert_called_once()
is_authorized.assert_called_once_with(
method=method,
entity_type=AvpEntities.CONNECTION,
user=expected_user,
entity_id=expected_entity_id,
)

@pytest.mark.parametrize(
"details, user, expected_user, expected_entity_id",
[
(None, None, ANY, None),
(DatasetDetails(uri="uri"), mock, mock, "uri"),
],
)
@patch.object(AwsAuthManager, "avp_facade")
@patch.object(AwsAuthManager, "get_user")
def test_is_authorized_dataset(
self, mock_get_user, mock_avp_facade, details, user, expected_user, expected_entity_id, auth_manager
):
is_authorized = Mock()
mock_avp_facade.is_authorized = is_authorized

method: ResourceMethod = "GET"
auth_manager.is_authorized_dataset(method=method, details=details, user=user)

if not user:
mock_get_user.assert_called_once()
is_authorized.assert_called_once_with(
method=method, entity_type=AvpEntities.DATASET, user=expected_user, entity_id=expected_entity_id
)

@pytest.mark.parametrize(
"details, user, expected_user, expected_entity_id",
[
(None, None, ANY, None),
(PoolDetails(name="pool1"), mock, mock, "pool1"),
],
)
@patch.object(AwsAuthManager, "avp_facade")
@patch.object(AwsAuthManager, "get_user")
def test_is_authorized_pool(
self, mock_get_user, mock_avp_facade, details, user, expected_user, expected_entity_id, auth_manager
):
is_authorized = Mock()
mock_avp_facade.is_authorized = is_authorized

method: ResourceMethod = "GET"
auth_manager.is_authorized_pool(method=method, details=details, user=user)

if not user:
mock_get_user.assert_called_once()
is_authorized.assert_called_once_with(
method=method, entity_type=AvpEntities.POOL, user=expected_user, entity_id=expected_entity_id
)

@pytest.mark.parametrize(
"details, user, expected_user, expected_entity_id",
[
Expand All @@ -126,7 +235,6 @@ def test_is_authorized_variable(
mock_avp_facade.is_authorized = is_authorized

method: ResourceMethod = "GET"

auth_manager.is_authorized_variable(method=method, details=details, user=user)

if not user:
Expand All @@ -135,6 +243,29 @@ def test_is_authorized_variable(
method=method, entity_type=AvpEntities.VARIABLE, user=expected_user, entity_id=expected_entity_id
)

@pytest.mark.parametrize(
"access_view, user, expected_user",
[
(AccessView.CLUSTER_ACTIVITY, None, ANY),
(AccessView.PLUGINS, mock, mock),
],
)
@patch.object(AwsAuthManager, "avp_facade")
@patch.object(AwsAuthManager, "get_user")
def test_is_authorized_view(
self, mock_get_user, mock_avp_facade, access_view, user, expected_user, auth_manager
):
is_authorized = Mock()
mock_avp_facade.is_authorized = is_authorized

auth_manager.is_authorized_view(access_view=access_view, user=user)

if not user:
mock_get_user.assert_called_once()
is_authorized.assert_called_once_with(
method="GET", entity_type=AvpEntities.VIEW, user=expected_user, entity_id=access_view.value
)

@patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for")
def test_get_url_login(self, mock_url_for, auth_manager):
auth_manager.get_url_login()
Expand Down

0 comments on commit 985d058

Please sign in to comment.