From eda979222eaf3b72873070604dcb0a8bb7dc4692 Mon Sep 17 00:00:00 2001 From: Tzu-Ting Date: Fri, 27 Dec 2024 00:03:09 +0800 Subject: [PATCH] refactor: user pass auth (#229) --- secrets_env/providers/vault/auth/__init__.py | 12 +- secrets_env/providers/vault/auth/userpass.py | 68 +++--- .../auth/{test_vault_auth.py => test_auth.py} | 4 +- tests/providers/vault/auth/test_userpass.py | 222 ++++++++++++++++++ .../vault/auth/test_vault_userpass.py | 182 -------------- 5 files changed, 264 insertions(+), 224 deletions(-) rename tests/providers/vault/auth/{test_vault_auth.py => test_auth.py} (97%) create mode 100644 tests/providers/vault/auth/test_userpass.py delete mode 100644 tests/providers/vault/auth/test_vault_userpass.py diff --git a/secrets_env/providers/vault/auth/__init__.py b/secrets_env/providers/vault/auth/__init__.py index 8c17d47b..117a9c0d 100644 --- a/secrets_env/providers/vault/auth/__init__.py +++ b/secrets_env/providers/vault/auth/__init__.py @@ -8,13 +8,13 @@ from secrets_env.providers.vault.auth.base import Auth if typing.TYPE_CHECKING: - from pydantic_core import Url + from pydantic import AnyUrl logger = logging.getLogger(__name__) -def create_auth(*, url: Url, method: str, **config) -> Auth: +def create_auth(*, url: AnyUrl, method: str, **config) -> Auth: """ Factory function to create an instance of the authenticator class by the method name. """ @@ -23,8 +23,8 @@ def create_auth(*, url: Url, method: str, **config) -> Auth: from secrets_env.providers.vault.auth.kubernetes import KubernetesAuth return KubernetesAuth.create(url, config) if method == "ldap": - from secrets_env.providers.vault.auth.userpass import LDAPAuth - return LDAPAuth.create(url, config) + from secrets_env.providers.vault.auth.userpass import LdapAuth + return LdapAuth.create(url, config) if method == "null": from secrets_env.providers.vault.auth.base import NoAuth return NoAuth.create(url, config) @@ -35,8 +35,8 @@ def create_auth(*, url: Url, method: str, **config) -> Auth: from secrets_env.providers.vault.auth.userpass import OktaAuth return OktaAuth.create(url, config) if method == "radius": - from secrets_env.providers.vault.auth.userpass import RADIUSAuth - return RADIUSAuth.create(url, config) + from secrets_env.providers.vault.auth.userpass import RadiusAuth + return RadiusAuth.create(url, config) if method == "token": from secrets_env.providers.vault.auth.token import TokenAuth return TokenAuth.create(url, config) diff --git a/secrets_env/providers/vault/auth/userpass.py b/secrets_env/providers/vault/auth/userpass.py index 1abd72e6..e96d8466 100644 --- a/secrets_env/providers/vault/auth/userpass.py +++ b/secrets_env/providers/vault/auth/userpass.py @@ -20,7 +20,7 @@ from typing import Self import httpx - from pydantic_core import Url + from pydantic import AnyUrl logger = logging.getLogger(__name__) @@ -41,12 +41,12 @@ class UserPasswordAuth(Auth): """Request timeout.""" @classmethod - def create(cls, url: Url, config: dict) -> Self: - username = cls._get_username(config, url) + def create(cls, url: AnyUrl, config: dict) -> Self: + username = get_username(url, config) if not username: raise ValueError(f"Missing username for {cls.method} auth") - password = cls._get_password(url, username) + password = get_password(url, username) if not password: raise ValueError(f"Missing password for {cls.method} auth") @@ -55,34 +55,6 @@ def create(cls, url: Url, config: dict) -> Self: password=cast(SecretStr, password), ) - @classmethod - def _get_username(cls, config: dict, url: Url) -> str | None: - if username := get_env_var("SECRETS_ENV_USERNAME"): - logger.debug("Found username from environment variable.") - return username - - if username := config.get("username"): - return username - - user_config = load_user_config(url) - if username := user_config.get("auth", {}).get("username"): - logger.debug("Found username in user config.") - return username - - return prompt(f"Username for {cls.method} auth") - - @classmethod - def _get_password(cls, url: Url, username: str) -> str | None: - if password := get_env_var("SECRETS_ENV_PASSWORD"): - logger.debug("Found password from environment variable.") - return password - - if password := read_keyring(create_keyring_login_key(url, username)): - logger.debug("Found password in keyring") - return password - - return prompt(f"Password for {username}", hide_input=True) - def login(self, client: httpx.Client) -> str: username = urllib.parse.quote(self.username) resp = client.post( @@ -106,7 +78,35 @@ def login(self, client: httpx.Client) -> str: return resp.json()["auth"]["client_token"] -class LDAPAuth(UserPasswordAuth): +def get_username(url: AnyUrl, config: dict) -> str | None: + if username := get_env_var("SECRETS_ENV_USERNAME"): + logger.debug("Found username from environment variable.") + return username + + if username := config.get("username"): + return username + + user_config = load_user_config(url) + if username := user_config.get("auth", {}).get("username"): + logger.debug("Found username in user config.") + return username + + return prompt(f"Username for {url.host}") + + +def get_password(url: AnyUrl, username: str) -> str | None: + if password := get_env_var("SECRETS_ENV_PASSWORD"): + logger.debug("Found password from environment variable.") + return password + + if password := read_keyring(create_keyring_login_key(url, username)): + logger.debug("Found password in keyring") + return password + + return prompt(f"Password for {username}", hide_input=True) + + +class LdapAuth(UserPasswordAuth): """Login with LDAP credentials.""" method = "LDAP" @@ -123,7 +123,7 @@ class OktaAuth(UserPasswordAuth): _timeout: float | None = PrivateAttr(60.0) -class RADIUSAuth(UserPasswordAuth): +class RadiusAuth(UserPasswordAuth): """RADIUS authentication with PAP authentication scheme.""" method = "RADIUS" diff --git a/tests/providers/vault/auth/test_vault_auth.py b/tests/providers/vault/auth/test_auth.py similarity index 97% rename from tests/providers/vault/auth/test_vault_auth.py rename to tests/providers/vault/auth/test_auth.py index aaec9938..ffb6230c 100644 --- a/tests/providers/vault/auth/test_vault_auth.py +++ b/tests/providers/vault/auth/test_auth.py @@ -15,11 +15,11 @@ class TestCreateAuth: "kubernetes", "secrets_env.providers.vault.auth.kubernetes.KubernetesAuth.create", ), - ("ldap", "secrets_env.providers.vault.auth.userpass.LDAPAuth.create"), + ("ldap", "secrets_env.providers.vault.auth.userpass.LdapAuth.create"), ("null", "secrets_env.providers.vault.auth.base.NoAuth.create"), ("oidc", "secrets_env.providers.vault.auth.oidc.OpenIDConnectAuth.create"), ("okta", "secrets_env.providers.vault.auth.userpass.OktaAuth.create"), - ("radius", "secrets_env.providers.vault.auth.userpass.RADIUSAuth.create"), + ("radius", "secrets_env.providers.vault.auth.userpass.RadiusAuth.create"), ("token", "secrets_env.providers.vault.auth.token.TokenAuth.create"), ( "userpass", diff --git a/tests/providers/vault/auth/test_userpass.py b/tests/providers/vault/auth/test_userpass.py new file mode 100644 index 00000000..bb9f65c4 --- /dev/null +++ b/tests/providers/vault/auth/test_userpass.py @@ -0,0 +1,222 @@ +import re +from unittest.mock import Mock + +import httpx +import pytest +import respx +from pydantic import HttpUrl + +from secrets_env.exceptions import AuthenticationError +from secrets_env.providers.vault.auth.userpass import ( + LdapAuth, + OktaAuth, + RadiusAuth, + UserPassAuth, + UserPasswordAuth, + get_password, + get_username, +) + + +@pytest.fixture +def login_success_response() -> httpx.Response: + return httpx.Response( + 200, + json={ + "lease_id": "", + "renewable": False, + "lease_duration": 0, + "data": None, + "warnings": None, + "auth": { + "client_token": "client-token", + "accessor": "accessor-token", + "policies": ["default"], + "metadata": {"username": "fred", "policies": "default"}, + "lease_duration": 7200, + "renewable": True, + }, + }, + ) + + +class TestUserPasswordAuth: + + def test_create_success(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.get_username", + lambda _1, _2: "user", + ) + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.get_password", + lambda _1, _2: "P@ssw0rd", + ) + + obj = UserPasswordAuth.create(HttpUrl("https://example.com/"), {}) + assert obj == UserPasswordAuth(username="user", password="P@ssw0rd") + + @pytest.mark.parametrize( + ("username", "password", "err_message"), + [ + ("user@example.com", "", "Missing password for MOCK auth"), + ("", "P@ssw0rd", "Missing username for MOCK auth"), + ("user@example.com", None, "Missing password for MOCK auth"), + (None, "P@ssw0rd", "Missing username for MOCK auth"), + ], + ) + def test_load_fail( + self, + monkeypatch: pytest.MonkeyPatch, + username: str, + password: str, + err_message: str, + ): + class MockAuth(UserPasswordAuth): + method = "MOCK" + + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.get_username", + lambda _1, _2: username, + ) + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.get_password", + lambda _1, _2: password, + ) + + with pytest.raises(ValueError, match=re.escape(err_message)): + assert MockAuth.create(HttpUrl("https://example.com/"), {}) is None + + def test_login_success( + self, + unittest_respx: respx.MockRouter, + unittest_client: httpx.Client, + login_success_response: httpx.Response, + ): + unittest_respx.post("/v1/auth/mock/login/user%40example.com").mock( + return_value=login_success_response + ) + + class MockAuth(UserPasswordAuth): + method = "MOCK" + vault_name = "mock" + + auth_obj = MockAuth(username="user@example.com", password="password") + assert auth_obj.login(unittest_client) == "client-token" + + def test_login_fail( + self, unittest_respx: respx.MockRouter, unittest_client: httpx.Client + ): + unittest_respx.post("/v1/auth/mock/login/user%40example.com").mock( + return_value=httpx.Response(400) + ) + + class MockAuth(UserPasswordAuth): + method = "MOCK" + vault_name = "mock" + + auth_obj = MockAuth(username="user@example.com", password="password") + + with pytest.raises(AuthenticationError): + assert auth_obj.login(unittest_client) is None + + +class TestGetUsername: + + def test_config(self): + assert ( + get_username(HttpUrl("https://example.com/"), {"username": "foo"}) == "foo" + ) + + def test_env_var(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SECRETS_ENV_USERNAME", "foo") + assert get_username(HttpUrl("https://example.com/"), {}) == "foo" + + def test_user_config(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.load_user_config", + lambda _: {"auth": {"username": "foo"}}, + ) + assert get_username(HttpUrl("https://example.com/"), {}) == "foo" + + def test_prompt(self, monkeypatch: pytest.MonkeyPatch): + mock_prompt = Mock(return_value="foo") + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.prompt", mock_prompt + ) + + assert get_username(HttpUrl("https://example.com/"), {}) == "foo" + mock_prompt.assert_any_call("Username for example.com") + + def test__load_username(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.load_user_config", + lambda _: {}, + ) + + mock_prompt = Mock(return_value="foo") + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.prompt", mock_prompt + ) + + assert get_username(HttpUrl("https://example.com/"), {}) == "foo" + mock_prompt.assert_called_once_with("Username for example.com") + + +class TestGetPassword: + + def test_env_var(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SECRETS_ENV_PASSWORD", "bar") + assert get_password(HttpUrl("https://example.com/"), "foo") == "bar" + + def test_keyring(self, monkeypatch: pytest.MonkeyPatch): + mock_read_keyring = Mock(return_value="bar") + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.read_keyring", mock_read_keyring + ) + + assert get_password(HttpUrl("https://example.com/"), "foo") == "bar" + mock_read_keyring.assert_any_call( + '{"host": "example.com", "type": "login", "user": "foo"}' + ) + + def test_prompt(self, monkeypatch: pytest.MonkeyPatch): + mock_prompt = Mock(return_value="bar") + monkeypatch.setattr( + "secrets_env.providers.vault.auth.userpass.prompt", mock_prompt + ) + + assert get_password(HttpUrl("https://example.com/"), "foo") == "bar" + mock_prompt.assert_called_once_with("Password for foo", hide_input=True) + + +@pytest.mark.parametrize( + ("method_class", "login_path"), + [ + (LdapAuth, "/v1/auth/ldap/login/user"), + (OktaAuth, "/v1/auth/okta/login/user"), + (RadiusAuth, "/v1/auth/radius/login/user"), + (UserPassAuth, "/v1/auth/userpass/login/user"), + ], +) +def test_auth_methods( + monkeypatch: pytest.MonkeyPatch, + method_class: type[UserPasswordAuth], + unittest_respx: respx.MockRouter, + login_path: str, + login_success_response: httpx.Response, + unittest_client: httpx.Client, +): + # no exception is enough + assert isinstance(method_class.method, str) + + # test creation + monkeypatch.setenv("SECRETS_ENV_USERNAME", "user") + monkeypatch.setenv("SECRETS_ENV_PASSWORD", "pass") + + auth = method_class.create(HttpUrl("https://example.com/"), {}) + assert auth + + # test login + unittest_respx.post(login_path).mock(return_value=login_success_response) + + assert auth.login(unittest_client) == "client-token" diff --git a/tests/providers/vault/auth/test_vault_userpass.py b/tests/providers/vault/auth/test_vault_userpass.py deleted file mode 100644 index 0997d951..00000000 --- a/tests/providers/vault/auth/test_vault_userpass.py +++ /dev/null @@ -1,182 +0,0 @@ -import re -from unittest.mock import patch - -import httpx -import pytest -import respx -from pydantic_core import Url - -import secrets_env.providers.vault.auth.userpass as t -from secrets_env.exceptions import AuthenticationError -from secrets_env.providers.vault.auth.userpass import UserPasswordAuth - - -@pytest.fixture -def login_success_response() -> httpx.Response: - return httpx.Response( - 200, - json={ - "lease_id": "", - "renewable": False, - "lease_duration": 0, - "data": None, - "warnings": None, - "auth": { - "client_token": "client-token", - "accessor": "accessor-token", - "policies": ["default"], - "metadata": {"username": "fred", "policies": "default"}, - "lease_duration": 7200, - "renewable": True, - }, - }, - ) - - -class TestUserPasswordAuth: - def test_create_success(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(UserPasswordAuth, "_get_username", lambda _1, _2: "user") - monkeypatch.setattr( - UserPasswordAuth, "_get_password", lambda _1, _2: "P@ssw0rd" - ) - - obj = UserPasswordAuth.create(Url("https://example.com/"), {}) - assert obj == UserPasswordAuth(username="user", password="P@ssw0rd") - - @pytest.mark.parametrize( - ("username", "password", "err_message"), - [ - ("user@example.com", "", "Missing password for MOCK auth"), - ("", "P@ssw0rd", "Missing username for MOCK auth"), - ("user@example.com", None, "Missing password for MOCK auth"), - (None, "P@ssw0rd", "Missing username for MOCK auth"), - ], - ) - def test_load_fail( - self, - monkeypatch: pytest.MonkeyPatch, - username: str, - password: str, - err_message: str, - ): - class MockAuth(UserPasswordAuth): - method = "MOCK" - - monkeypatch.setattr(MockAuth, "_get_username", lambda _1, _2: username) - monkeypatch.setattr(MockAuth, "_get_password", lambda _1, _2: password) - - with pytest.raises(ValueError, match=re.escape(err_message)): - assert MockAuth.create(Url("https://example.com/"), {}) is None - - def test__load_username(self, monkeypatch: pytest.MonkeyPatch): - class MockAuth(UserPasswordAuth): - method = "MOCK" - - url = Url("https://example.com/") - - # config - assert MockAuth._get_username({"username": "foo"}, url) == "foo" - - # env var - with monkeypatch.context() as m: - m.setenv("SECRETS_ENV_USERNAME", "foo") - assert MockAuth._get_username({}, url) == "foo" - - # user config - with patch.object( - t, "load_user_config", return_value={"auth": {"username": "foo"}} - ): - assert MockAuth._get_username({}, url) == "foo" - - # prompt - with ( - patch.object(t, "load_user_config", return_value={}), - patch.object(t, "prompt", return_value="foo") as p, - ): - assert MockAuth._get_username({}, url) == "foo" - p.assert_any_call("Username for MOCK auth") - - def test__load_password(self, monkeypatch: pytest.MonkeyPatch): - # env var - with monkeypatch.context() as m: - m.setenv("SECRETS_ENV_PASSWORD", "bar") - out = UserPasswordAuth._get_password(Url("https://example.com/"), "foo") - assert out == "bar" - - # prompt - with patch.object(t, "prompt", return_value="bar") as p: - out = UserPasswordAuth._get_password(Url("https://example.com/"), "foo") - assert out == "bar" - p.assert_any_call("Password for foo", hide_input=True) - - # keyring - with patch.object(t, "read_keyring", return_value="bar") as r: - out = UserPasswordAuth._get_password(Url("https://example.com/"), "foo") - assert out == "bar" - r.assert_any_call('{"host": "example.com", "type": "login", "user": "foo"}') - - def test_login_success( - self, - unittest_respx: respx.MockRouter, - unittest_client: httpx.Client, - login_success_response: httpx.Response, - ): - unittest_respx.post("/v1/auth/mock/login/user%40example.com").mock( - return_value=login_success_response - ) - - class MockAuth(UserPasswordAuth): - method = "MOCK" - vault_name = "mock" - - auth_obj = MockAuth(username="user@example.com", password="password") - assert auth_obj.login(unittest_client) == "client-token" - - def test_login_fail( - self, unittest_respx: respx.MockRouter, unittest_client: httpx.Client - ): - unittest_respx.post("/v1/auth/mock/login/user%40example.com").mock( - return_value=httpx.Response(400) - ) - - class MockAuth(UserPasswordAuth): - method = "MOCK" - vault_name = "mock" - - auth_obj = MockAuth(username="user@example.com", password="password") - - with pytest.raises(AuthenticationError): - assert auth_obj.login(unittest_client) is None - - -@pytest.mark.parametrize( - ("method_class", "login_path"), - [ - (t.LDAPAuth, "/v1/auth/ldap/login/user"), - (t.OktaAuth, "/v1/auth/okta/login/user"), - (t.RADIUSAuth, "/v1/auth/radius/login/user"), - (t.UserPassAuth, "/v1/auth/userpass/login/user"), - ], -) -def test_auth_methods( - monkeypatch: pytest.MonkeyPatch, - method_class: type[UserPasswordAuth], - unittest_respx: respx.MockRouter, - login_path: str, - login_success_response: httpx.Response, - unittest_client: httpx.Client, -): - # no exception is enough - assert isinstance(method_class.method, str) - - # test creation - monkeypatch.setenv("SECRETS_ENV_USERNAME", "user") - monkeypatch.setenv("SECRETS_ENV_PASSWORD", "pass") - - auth = method_class.create(Url("https://example.com/"), {}) - assert auth - - # test login - unittest_respx.post(login_path).mock(return_value=login_success_response) - - assert auth.login(unittest_client) == "client-token"