Skip to content

Commit

Permalink
refactor: user pass auth (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
tzing authored Dec 26, 2024
1 parent 33df640 commit eda9792
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 224 deletions.
12 changes: 6 additions & 6 deletions secrets_env/providers/vault/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from secrets_env.providers.vault.auth.base import Auth

if typing.TYPE_CHECKING:
from pydantic_core import Url
from pydantic import AnyUrl


logger = logging.getLogger(__name__)


def create_auth(*, url: Url, method: str, **config) -> Auth:
def create_auth(*, url: AnyUrl, method: str, **config) -> Auth:
"""
Factory function to create an instance of the authenticator class by the method name.
"""
Expand All @@ -23,8 +23,8 @@ def create_auth(*, url: Url, method: str, **config) -> Auth:
from secrets_env.providers.vault.auth.kubernetes import KubernetesAuth
return KubernetesAuth.create(url, config)
if method == "ldap":
from secrets_env.providers.vault.auth.userpass import LDAPAuth
return LDAPAuth.create(url, config)
from secrets_env.providers.vault.auth.userpass import LdapAuth
return LdapAuth.create(url, config)
if method == "null":
from secrets_env.providers.vault.auth.base import NoAuth
return NoAuth.create(url, config)
Expand All @@ -35,8 +35,8 @@ def create_auth(*, url: Url, method: str, **config) -> Auth:
from secrets_env.providers.vault.auth.userpass import OktaAuth
return OktaAuth.create(url, config)
if method == "radius":
from secrets_env.providers.vault.auth.userpass import RADIUSAuth
return RADIUSAuth.create(url, config)
from secrets_env.providers.vault.auth.userpass import RadiusAuth
return RadiusAuth.create(url, config)
if method == "token":
from secrets_env.providers.vault.auth.token import TokenAuth
return TokenAuth.create(url, config)
Expand Down
68 changes: 34 additions & 34 deletions secrets_env/providers/vault/auth/userpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Self

import httpx
from pydantic_core import Url
from pydantic import AnyUrl

logger = logging.getLogger(__name__)

Expand All @@ -41,12 +41,12 @@ class UserPasswordAuth(Auth):
"""Request timeout."""

@classmethod
def create(cls, url: Url, config: dict) -> Self:
username = cls._get_username(config, url)
def create(cls, url: AnyUrl, config: dict) -> Self:
username = get_username(url, config)
if not username:
raise ValueError(f"Missing username for {cls.method} auth")

password = cls._get_password(url, username)
password = get_password(url, username)
if not password:
raise ValueError(f"Missing password for {cls.method} auth")

Expand All @@ -55,34 +55,6 @@ def create(cls, url: Url, config: dict) -> Self:
password=cast(SecretStr, password),
)

@classmethod
def _get_username(cls, config: dict, url: Url) -> str | None:
if username := get_env_var("SECRETS_ENV_USERNAME"):
logger.debug("Found username from environment variable.")
return username

if username := config.get("username"):
return username

user_config = load_user_config(url)
if username := user_config.get("auth", {}).get("username"):
logger.debug("Found username in user config.")
return username

return prompt(f"Username for {cls.method} auth")

@classmethod
def _get_password(cls, url: Url, username: str) -> str | None:
if password := get_env_var("SECRETS_ENV_PASSWORD"):
logger.debug("Found password from environment variable.")
return password

if password := read_keyring(create_keyring_login_key(url, username)):
logger.debug("Found password in keyring")
return password

return prompt(f"Password for {username}", hide_input=True)

def login(self, client: httpx.Client) -> str:
username = urllib.parse.quote(self.username)
resp = client.post(
Expand All @@ -106,7 +78,35 @@ def login(self, client: httpx.Client) -> str:
return resp.json()["auth"]["client_token"]


class LDAPAuth(UserPasswordAuth):
def get_username(url: AnyUrl, config: dict) -> str | None:
if username := get_env_var("SECRETS_ENV_USERNAME"):
logger.debug("Found username from environment variable.")
return username

if username := config.get("username"):
return username

user_config = load_user_config(url)
if username := user_config.get("auth", {}).get("username"):
logger.debug("Found username in user config.")
return username

return prompt(f"Username for {url.host}")


def get_password(url: AnyUrl, username: str) -> str | None:
if password := get_env_var("SECRETS_ENV_PASSWORD"):
logger.debug("Found password from environment variable.")
return password

if password := read_keyring(create_keyring_login_key(url, username)):
logger.debug("Found password in keyring")
return password

return prompt(f"Password for {username}", hide_input=True)


class LdapAuth(UserPasswordAuth):
"""Login with LDAP credentials."""

method = "LDAP"
Expand All @@ -123,7 +123,7 @@ class OktaAuth(UserPasswordAuth):
_timeout: float | None = PrivateAttr(60.0)


class RADIUSAuth(UserPasswordAuth):
class RadiusAuth(UserPasswordAuth):
"""RADIUS authentication with PAP authentication scheme."""

method = "RADIUS"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ class TestCreateAuth:
"kubernetes",
"secrets_env.providers.vault.auth.kubernetes.KubernetesAuth.create",
),
("ldap", "secrets_env.providers.vault.auth.userpass.LDAPAuth.create"),
("ldap", "secrets_env.providers.vault.auth.userpass.LdapAuth.create"),
("null", "secrets_env.providers.vault.auth.base.NoAuth.create"),
("oidc", "secrets_env.providers.vault.auth.oidc.OpenIDConnectAuth.create"),
("okta", "secrets_env.providers.vault.auth.userpass.OktaAuth.create"),
("radius", "secrets_env.providers.vault.auth.userpass.RADIUSAuth.create"),
("radius", "secrets_env.providers.vault.auth.userpass.RadiusAuth.create"),
("token", "secrets_env.providers.vault.auth.token.TokenAuth.create"),
(
"userpass",
Expand Down
222 changes: 222 additions & 0 deletions tests/providers/vault/auth/test_userpass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import re
from unittest.mock import Mock

import httpx
import pytest
import respx
from pydantic import HttpUrl

from secrets_env.exceptions import AuthenticationError
from secrets_env.providers.vault.auth.userpass import (
LdapAuth,
OktaAuth,
RadiusAuth,
UserPassAuth,
UserPasswordAuth,
get_password,
get_username,
)


@pytest.fixture
def login_success_response() -> httpx.Response:
return httpx.Response(
200,
json={
"lease_id": "",
"renewable": False,
"lease_duration": 0,
"data": None,
"warnings": None,
"auth": {
"client_token": "client-token",
"accessor": "accessor-token",
"policies": ["default"],
"metadata": {"username": "fred", "policies": "default"},
"lease_duration": 7200,
"renewable": True,
},
},
)


class TestUserPasswordAuth:

def test_create_success(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.get_username",
lambda _1, _2: "user",
)
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.get_password",
lambda _1, _2: "P@ssw0rd",
)

obj = UserPasswordAuth.create(HttpUrl("https://example.com/"), {})
assert obj == UserPasswordAuth(username="user", password="P@ssw0rd")

@pytest.mark.parametrize(
("username", "password", "err_message"),
[
("[email protected]", "", "Missing password for MOCK auth"),
("", "P@ssw0rd", "Missing username for MOCK auth"),
("[email protected]", None, "Missing password for MOCK auth"),
(None, "P@ssw0rd", "Missing username for MOCK auth"),
],
)
def test_load_fail(
self,
monkeypatch: pytest.MonkeyPatch,
username: str,
password: str,
err_message: str,
):
class MockAuth(UserPasswordAuth):
method = "MOCK"

monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.get_username",
lambda _1, _2: username,
)
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.get_password",
lambda _1, _2: password,
)

with pytest.raises(ValueError, match=re.escape(err_message)):
assert MockAuth.create(HttpUrl("https://example.com/"), {}) is None

def test_login_success(
self,
unittest_respx: respx.MockRouter,
unittest_client: httpx.Client,
login_success_response: httpx.Response,
):
unittest_respx.post("/v1/auth/mock/login/user%40example.com").mock(
return_value=login_success_response
)

class MockAuth(UserPasswordAuth):
method = "MOCK"
vault_name = "mock"

auth_obj = MockAuth(username="[email protected]", password="password")
assert auth_obj.login(unittest_client) == "client-token"

def test_login_fail(
self, unittest_respx: respx.MockRouter, unittest_client: httpx.Client
):
unittest_respx.post("/v1/auth/mock/login/user%40example.com").mock(
return_value=httpx.Response(400)
)

class MockAuth(UserPasswordAuth):
method = "MOCK"
vault_name = "mock"

auth_obj = MockAuth(username="[email protected]", password="password")

with pytest.raises(AuthenticationError):
assert auth_obj.login(unittest_client) is None


class TestGetUsername:

def test_config(self):
assert (
get_username(HttpUrl("https://example.com/"), {"username": "foo"}) == "foo"
)

def test_env_var(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("SECRETS_ENV_USERNAME", "foo")
assert get_username(HttpUrl("https://example.com/"), {}) == "foo"

def test_user_config(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.load_user_config",
lambda _: {"auth": {"username": "foo"}},
)
assert get_username(HttpUrl("https://example.com/"), {}) == "foo"

def test_prompt(self, monkeypatch: pytest.MonkeyPatch):
mock_prompt = Mock(return_value="foo")
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.prompt", mock_prompt
)

assert get_username(HttpUrl("https://example.com/"), {}) == "foo"
mock_prompt.assert_any_call("Username for example.com")

def test__load_username(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.load_user_config",
lambda _: {},
)

mock_prompt = Mock(return_value="foo")
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.prompt", mock_prompt
)

assert get_username(HttpUrl("https://example.com/"), {}) == "foo"
mock_prompt.assert_called_once_with("Username for example.com")


class TestGetPassword:

def test_env_var(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("SECRETS_ENV_PASSWORD", "bar")
assert get_password(HttpUrl("https://example.com/"), "foo") == "bar"

def test_keyring(self, monkeypatch: pytest.MonkeyPatch):
mock_read_keyring = Mock(return_value="bar")
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.read_keyring", mock_read_keyring
)

assert get_password(HttpUrl("https://example.com/"), "foo") == "bar"
mock_read_keyring.assert_any_call(
'{"host": "example.com", "type": "login", "user": "foo"}'
)

def test_prompt(self, monkeypatch: pytest.MonkeyPatch):
mock_prompt = Mock(return_value="bar")
monkeypatch.setattr(
"secrets_env.providers.vault.auth.userpass.prompt", mock_prompt
)

assert get_password(HttpUrl("https://example.com/"), "foo") == "bar"
mock_prompt.assert_called_once_with("Password for foo", hide_input=True)


@pytest.mark.parametrize(
("method_class", "login_path"),
[
(LdapAuth, "/v1/auth/ldap/login/user"),
(OktaAuth, "/v1/auth/okta/login/user"),
(RadiusAuth, "/v1/auth/radius/login/user"),
(UserPassAuth, "/v1/auth/userpass/login/user"),
],
)
def test_auth_methods(
monkeypatch: pytest.MonkeyPatch,
method_class: type[UserPasswordAuth],
unittest_respx: respx.MockRouter,
login_path: str,
login_success_response: httpx.Response,
unittest_client: httpx.Client,
):
# no exception is enough
assert isinstance(method_class.method, str)

# test creation
monkeypatch.setenv("SECRETS_ENV_USERNAME", "user")
monkeypatch.setenv("SECRETS_ENV_PASSWORD", "pass")

auth = method_class.create(HttpUrl("https://example.com/"), {})
assert auth

# test login
unittest_respx.post(login_path).mock(return_value=login_success_response)

assert auth.login(unittest_client) == "client-token"
Loading

0 comments on commit eda9792

Please sign in to comment.