Skip to content

Commit

Permalink
Add support for multiuser client in python (#44)
Browse files Browse the repository at this point in the history
* Add support for multiuser client in python

* Review comments

* Remove options

* Update sdks/python/tests/unit_tests/multi_user_client_tests.py

Co-authored-by: Matthew Timms <[email protected]>
Signed-off-by: IsisChameleon <[email protected]>

---------

Signed-off-by: IsisChameleon <[email protected]>
Co-authored-by: Your Name <[email protected]>
Co-authored-by: Matthew Timms <[email protected]>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent ad96a8c commit 0d1019e
Show file tree
Hide file tree
Showing 5 changed files with 417 additions and 13 deletions.
19 changes: 9 additions & 10 deletions sdks/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,12 @@ dependencies = [
"betterproto[compiler]==2.0.0b6",
"grpclib[protobuf]",
"httpx",
"pyjwt",
]

[project.optional-dependencies]
tests = [
"pytest",
"pytest-asyncio",
"pytest-httpx",
]
reranking = [
"rerankers",
"rerankers[transformers]"
]
tests = ["pytest", "pytest-asyncio", "pytest-httpx"]
reranking = ["rerankers", "rerankers[transformers]"]

[project.urls]
Homepage = "https://github.com/redactive-ai/redactive"
Expand All @@ -54,7 +48,12 @@ python_classes = "*Tests"
python_functions = "test_*"

[tool.hatch.envs.types]
extra-dependencies = ["redactive>=1.0.0", "pyright", "rerankers", "rerankers[transformers]"]
extra-dependencies = [
"redactive>=1.0.0",
"pyright",
"rerankers",
"rerankers[transformers]",
]

[tool.hatch.envs.types.scripts]
check = "pyright"
Expand Down
34 changes: 33 additions & 1 deletion sdks/python/src/redactive/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from redactive._connection_mode import get_default_http_endpoint as _get_default_http_endpoint


class ListConnectionsResponse(BaseModel):
user_id: str
connections: list[str]


class ExchangeTokenResponse(BaseModel):
idToken: str # noqa: N815
refreshToken: str # noqa: N815
Expand All @@ -32,7 +37,12 @@ def __init__(self, api_key: str, base_url: str | None = None):
self._client = httpx.AsyncClient(base_url=f"{base_url}", auth=BearerAuth(api_key))

async def begin_connection(
self, provider: str, redirect_uri: str, endpoint: str | None = None, code_param_alias: str | None = None
self,
provider: str,
redirect_uri: str,
endpoint: str | None = None,
code_param_alias: str | None = None,
state: str | None = None,
) -> BeginConnectionResponse:
"""
Initiates a connection process with a specified provider.
Expand All @@ -45,6 +55,8 @@ async def begin_connection(
:type endpoint: str, optional
:param code_param_alias: The alias for the code parameter. This is the name of the query parameter that will need to be passed to the `/auth/token` endpoint as `code`. Defaults to None and will be `code` on the return.
:type code_param_alias: str, optional
:param state: An optional parameter that is stored as app_callback_state for building callback url. Defaults to None.
:type state: str, optional
:raises httpx.RequestError: If an error occurs while making the HTTP request.
:return: The URL to redirect the user to for beginning the connection.
:rtype: BeginConnectionResponse
Expand All @@ -54,6 +66,8 @@ async def begin_connection(
params["endpoint"] = endpoint
if code_param_alias:
params["code_param_alias"] = code_param_alias
if state:
params["state"] = state
response = await self._client.post(url=f"/api/auth/connect/{provider}/url", params=params)
if response.status_code != http.HTTPStatus.OK:
raise httpx.RequestError(response.text)
Expand Down Expand Up @@ -85,6 +99,24 @@ async def exchange_tokens(self, code: str | None = None, refresh_token: str | No

return ExchangeTokenResponse(**response.json())

async def list_connections(self, access_token: str) -> ListConnectionsResponse:
"""
Retrieve the list of user connections.
:param access_token: The access token for authentication.
:type access_token: str
:raises httpx.RequestError: If an error occurs while making the HTTP request.
:return: An object containing the user ID and current connections.
:rtype: UserConnections
"""
headers = {"Authorization": f"Bearer {access_token}"}
response = await self._client.get("/api/auth/connections", headers=headers)

if response.status_code != http.HTTPStatus.OK:
raise httpx.RequestError(response.text)

return ListConnectionsResponse(**response.json())


class BearerAuth(httpx.Auth):
def __init__(self, token):
Expand Down
136 changes: 136 additions & 0 deletions sdks/python/src/redactive/multi_user_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Annotated

import jwt

from redactive.auth_client import AuthClient
from redactive.grpc.v1 import Chunk, RelevantChunk
from redactive.search_client import SearchClient


@dataclass
class UserData:
refresh_token: str | None = None
id_token: str | None = None
id_token_expiry: datetime | None = None
connections: list[str] = field(default_factory=list)
sign_in_state: str | None = None


class InvalidRedactiveSessionError(Exception):
def __init__(self, user_id: str) -> None:
super().__init__(f"No valid Redactive session for user '{user_id}'")


class MultiUserClient:
def __init__(
self,
api_key: str,
callback_uri: str,
read_user_data: Callable[[Annotated[str, "user_id"]], Awaitable[UserData]],
write_user_data: Callable[[Annotated[str, "user_id"], UserData | None], Awaitable[None]],
*,
auth_base_url: str | None = None,
grpc_host: str | None = None,
grpc_port: int | None = None,
) -> None:
"""Redactive client handling multiple users authentication and access to the Redactive Search service.
:param api_key: Redactive API key.
:type api_key: str
:param callback_uri: The URI to redirect to after initiating the connection.
:type callback_uri: str
:param read_user_data: Function to read user data from storage.
:type read_user_data: Callable[[Annotated[str, user_id]], Awaitable[UserData]]
:param write_user_data: Function to write user data to storage.
:type write_user_data: Callable[[[Annotated[str, user_id], UserData | None], Awaitable[None]]
:param auth_base_url: Base URL for the authentication service. Optional.
:type auth_base_url: str | None
:param grpc_host: Host for the gRPC service. Optional.
:type grpc_host: str | None
:param grpc_port: Port for the gRPC service. Optional.
:type grpc_port: int | None
"""

self.auth_client = AuthClient(api_key, base_url=auth_base_url)
self.search_client = SearchClient(host=grpc_host, port=grpc_port)
self.callback_uri = callback_uri
self.read_user_data = read_user_data
self.write_user_data = write_user_data

async def get_begin_connection_url(self, user_id: str, provider: str) -> str:
state = str(uuid.uuid4())
response = await self.auth_client.begin_connection(provider, self.callback_uri, state=state)
user_data = await self.read_user_data(user_id)
user_data.sign_in_state = state
await self.write_user_data(user_id, user_data)
return response.url

async def _refresh_user_data(
self, user_id: str, refresh_token: str | None = None, sign_in_code: str | None = None
) -> UserData:
tokens = await self.auth_client.exchange_tokens(sign_in_code, refresh_token)
connections = await self.auth_client.list_connections(tokens.idToken)
user_data = UserData(
refresh_token=tokens.refreshToken,
id_token=tokens.idToken,
id_token_expiry=datetime.now(UTC) + timedelta(seconds=tokens.expiresIn - 10),
connections=connections.connections,
)
await self.write_user_data(user_id, user_data)
return user_data

async def get_users_redactive_email(self, user_id: str) -> str | None:
user_data = await self.read_user_data(user_id)
if not user_data or not user_data.id_token:
return None
token_body = jwt.decode(user_data.id_token, options={"verify_signature": False})
return token_body.get("email")

async def handle_connection_callback(self, user_id: str, sign_in_code: str, state: str) -> bool:
user_data = await self.read_user_data(user_id)
if not user_data or user_data.sign_in_state != state:
return False
await self._refresh_user_data(user_id, sign_in_code=sign_in_code)
return True

async def get_user_connections(self, user_id: str) -> list[str]:
user_data = await self.read_user_data(user_id)
if user_data and user_data.id_token_expiry and user_data.id_token_expiry > datetime.now(UTC):
return user_data.connections
if user_data and user_data.refresh_token:
user_data = await self._refresh_user_data(user_id, refresh_token=user_data.refresh_token)
return user_data.connections
return []

async def clear_user_data(self, user_id: str) -> None:
await self.write_user_data(user_id, None)

async def _get_id_token(self, user_id: str) -> str:
user_data = await self.read_user_data(user_id)
if not user_data or not user_data.refresh_token:
raise InvalidRedactiveSessionError(user_id)
if user_data.id_token_expiry and user_data.id_token_expiry < datetime.now(UTC):
user_data = await self._refresh_user_data(user_id, refresh_token=user_data.refresh_token)
if not user_data.id_token:
raise InvalidRedactiveSessionError(user_id)
return user_data.id_token

async def query_chunks(
self, user_id: str, semantic_query: str, count: int = 10, filters: dict | None = None
) -> list[RelevantChunk]:
id_token = await self._get_id_token(user_id)
return await self.search_client.query_chunks(id_token, semantic_query, count, filters)

async def query_chunks_by_document_name(
self, user_id: str, document_name: str, filters: dict | None = None
) -> list[Chunk]:
id_token = await self._get_id_token(user_id)
return await self.search_client.query_chunks_by_document_name(id_token, document_name, filters)

async def get_chunks_by_url(self, user_id: str, url: str) -> list[Chunk]:
id_token = await self._get_id_token(user_id)
return await self.search_client.get_chunks_by_url(id_token, url)
18 changes: 16 additions & 2 deletions sdks/python/tests/unit_tests/auth_client_tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any
from urllib.parse import urlencode

import pytest

from redactive.auth_client import (
Expand All @@ -7,17 +10,28 @@
)


def build_uri_query(data: dict[str, Any]) -> str:
"""Take dict and convert to query string and ignore None value"""
return urlencode([(k, v) for k, v in data.items() if v is not None], doseq=True)


@pytest.fixture
def mock_client():
return AuthClient(api_key="test_api_key", base_url="https://mock.api")


@pytest.mark.asyncio
async def test_begin_connection(mock_client, httpx_mock):
provider = "provider"
redirect_uri = "https://redirect.uri"
state = "state123"
expected_url = f"https://mock.api/api/auth/connect/provider/url?redirect_uri={redirect_uri}"
httpx_mock.add_response(json={"url": expected_url})
response = await mock_client.begin_connection("provider", redirect_uri)
httpx_mock.add_response(
method="POST",
url=f"https://mock.api/api/auth/connect/{provider}/url?{build_uri_query({"redirect_uri": redirect_uri, "state": state})}",
json={"url": expected_url},
)
response = await mock_client.begin_connection("provider", redirect_uri, state="state123")
assert response == BeginConnectionResponse(url=expected_url)


Expand Down
Loading

0 comments on commit 0d1019e

Please sign in to comment.