Skip to content

Commit

Permalink
feat: add vertex support (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
miri-bar authored Aug 22, 2024
1 parent e526c53 commit 0c7a32d
Show file tree
Hide file tree
Showing 22 changed files with 864 additions and 9 deletions.
76 changes: 76 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
- [Bedrock](#Bedrock)
- [SageMaker](#SageMaker)
- [Azure](#Azure)
- [Vertex](#Vertex)

## Examples (tl;dr)

Expand Down Expand Up @@ -736,4 +737,79 @@ response = client.chat.completions.create(
)
```

#### Async

```python
import asyncio
from ai21 import AsyncAI21AzureClient
from ai21.models.chat import ChatMessage

client = AsyncAI21AzureClient(
base_url="https://<YOUR-ENDPOINT>.inference.ai.azure.com/v1/chat/completions",
api_key="<your Azure api key>",
)

messages = [
ChatMessage(content="You are a helpful assistant", role="system"),
ChatMessage(content="What is the meaning of life?", role="user")
]

async def main():
response = await client.chat.completions.create(
model="jamba-instruct",
messages=messages,
)

asyncio.run(main())
```

### Vertex

To interact with a Vertex AI endpoint on GCP, use the `AI21VertexClient`
(sync) or `AsyncAI21VertexClient` (async) client. Both need the optional Vertex dependencies: `pip install "ai21[Vertex]"`.

The following models are supported on Vertex:

- `jamba-1.5-mini`
- `jamba-1.5-large`

```python
from ai21 import AI21VertexClient

from ai21.models.chat import ChatMessage

# You can also set the project_id, region, access_token and Google credentials in the constructor
client = AI21VertexClient()

messages = ChatMessage(content="What is the meaning of life?", role="user")

response = client.chat.completions.create(
model="jamba-1.5-mini",
messages=[messages],
)
```

#### Async

```python
import asyncio

from ai21 import AsyncAI21VertexClient
from ai21.models.chat import ChatMessage

# You can also set the project_id, region, access_token and Google credentials in the constructor
client = AsyncAI21VertexClient()


async def main():
messages = ChatMessage(content="What is the meaning of life?", role="user")

response = await client.chat.completions.create(
model="jamba-1.5-mini",
messages=[messages],
)

asyncio.run(main())
```

Happy prompting! 🚀
22 changes: 21 additions & 1 deletion ai21/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ def _import_async_sagemaker_client():
return AsyncAI21SageMakerClient


def _import_vertex_client():
    """Load and return ``AI21VertexClient`` on first use.

    The import is deferred to call time so that importing the package does
    not import the Vertex client module.
    """
    import ai21.clients.vertex.ai21_vertex_client as vertex_module

    return vertex_module.AI21VertexClient


def _import_async_vertex_client():
    """Load and return ``AsyncAI21VertexClient`` on first use.

    The import is deferred to call time so that importing the package does
    not import the Vertex client module.
    """
    import ai21.clients.vertex.ai21_vertex_client as vertex_module

    return vertex_module.AsyncAI21VertexClient


def __getattr__(name: str) -> Any:
try:
if name == "AI21BedrockClient":
Expand All @@ -67,8 +79,14 @@ def __getattr__(name: str) -> Any:

if name == "AsyncAI21SageMakerClient":
return _import_async_sagemaker_client()

if name == "AI21VertexClient":
return _import_vertex_client()

if name == "AsyncAI21VertexClient":
return _import_async_vertex_client()
except ImportError as e:
raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e
raise ImportError('Please install "ai21[AWS]" for SageMaker or Bedrock, or "ai21[Vertex]" for Vertex') from e


__all__ = [
Expand All @@ -89,4 +107,6 @@ def __getattr__(name: str) -> Any:
"AsyncAI21AzureClient",
"AsyncAI21BedrockClient",
"AsyncAI21SageMakerClient",
"AI21VertexClient",
"AsyncAI21VertexClient",
]
Empty file added ai21/clients/vertex/__init__.py
Empty file.
209 changes: 209 additions & 0 deletions ai21/clients/vertex/ai21_vertex_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from __future__ import annotations

from typing import Optional, Dict, Any

import httpx
from google.auth.credentials import Credentials as GCPCredentials

from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat
from ai21.clients.vertex.gcp_authorization import GCPAuthorization
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
from ai21.models.request_options import RequestOptions

# Region used when the caller does not pass ``region`` to the client.
_DEFAULT_GCP_REGION = "us-central1"
# Base URL template; ``{region}`` is filled in by ``_get_base_url``.
_VERTEX_BASE_URL_FORMAT = "https://{region}-aiplatform.googleapis.com/v1"
# Request path template filled in by ``_build_path``. ``{endpoint}`` is
# "rawPredict" or "streamRawPredict", chosen in ``_prepare_options``.
_VERTEX_PATH_FORMAT = "/projects/{project_id}/locations/{region}/publishers/ai21/models/{model}:{endpoint}"


class BaseAI21VertexClient:
    """Shared Vertex AI plumbing for the sync and async AI21 clients.

    Holds the GCP region, project id and credential settings, and turns them
    into a base URL, a model request path and an ``Authorization`` header.
    """

    def __init__(
        self,
        region: Optional[str] = None,
        project_id: Optional[str] = None,
        access_token: Optional[str] = None,
        credentials: Optional[GCPCredentials] = None,
    ):
        """Store the connection settings; credentials are resolved lazily.

        Args:
            region: GCP region; falls back to ``_DEFAULT_GCP_REGION``.
            project_id: GCP project id. Required together with ``access_token``.
            access_token: Pre-issued bearer token, returned as-is.
            credentials: Credentials object used to obtain a token.

        Raises:
            ValueError: If ``access_token`` is given without ``project_id``.
        """
        if access_token is not None and project_id is None:
            raise ValueError("Field project_id is required when setting access_token")

        self._region = region or _DEFAULT_GCP_REGION
        self._access_token = access_token
        self._project_id = project_id
        self._credentials = credentials
        self._gcp_auth = GCPAuthorization()

    def _get_base_url(self) -> str:
        """Return the regional Vertex AI base URL."""
        return _VERTEX_BASE_URL_FORMAT.format(region=self._region)

    def _get_access_token(self) -> str:
        """Return a bearer token, resolving credentials and project id if needed.

        An explicit ``access_token`` always wins. Otherwise credentials are
        taken from the constructor or looked up through ``GCPAuthorization``,
        and refreshed only when they do not currently hold a valid token.

        Raises:
            ValueError: If no credentials or no project id can be determined.
            RuntimeError: If the credentials hold no token after refreshing.
        """
        if self._access_token is not None:
            return self._access_token

        if self._credentials is None:
            self._credentials, self._project_id = self._gcp_auth.get_gcp_credentials(
                project_id=self._project_id,
            )

        if self._credentials is None:
            raise ValueError("Could not get credentials for GCP project")

        if self._project_id is None:
            # Caller-supplied credentials skip the lookup above, so the project
            # id was never resolved. Without this the request path would be
            # built as "/projects/None/...".
            self._project_id = getattr(self._credentials, "project_id", None)
            if self._project_id is None:
                raise ValueError("Could not get project_id for GCP project")

        # Refresh only when needed: every refresh is a network round trip.
        # Objects without a ``valid`` attribute are refreshed on every call.
        if not getattr(self._credentials, "valid", False):
            self._gcp_auth.refresh_auth(self._credentials)

        if self._credentials.token is None:
            raise RuntimeError(f"Could not get access token for GCP project {self._project_id}")

        return self._credentials.token

    def _build_path(
        self,
        project_id: str,
        region: str,
        model: str,
        endpoint: str,
    ) -> str:
        """Return the request path for ``model`` and ``endpoint``."""
        return _VERTEX_PATH_FORMAT.format(
            project_id=project_id,
            region=region,
            model=model,
            endpoint=endpoint,
        )

    def _get_authorization_header(self) -> Dict[str, Any]:
        """Return the ``Authorization: Bearer <token>`` header mapping."""
        access_token = self._get_access_token()
        return {"Authorization": f"Bearer {access_token}"}


class AI21VertexClient(BaseAI21VertexClient, AI21HTTPClient):
    """Synchronous AI21 client that talks to models served on Vertex AI."""

    def __init__(
        self,
        region: Optional[str] = None,
        project_id: Optional[str] = None,
        base_url: Optional[str] = None,
        access_token: Optional[str] = None,
        credentials: Optional[GCPCredentials] = None,
        headers: Dict[str, str] | None = None,
        timeout_sec: Optional[float] = None,
        num_retries: Optional[int] = None,
        http_client: Optional[httpx.Client] = None,
    ):
        """Configure GCP settings and the underlying HTTP client.

        Args:
            region: GCP region; a default is used when omitted.
            project_id: GCP project id. Required together with ``access_token``.
            base_url: Overrides the regional Vertex AI base URL.
            access_token: Pre-issued bearer token.
            credentials: Credentials object used to obtain a token.
            headers: Extra headers handed to the HTTP client.
            timeout_sec: Request timeout handed to the HTTP client.
            num_retries: Retry count handed to the HTTP client.
            http_client: Pre-built ``httpx.Client`` to use.
        """
        BaseAI21VertexClient.__init__(
            self,
            region=region,
            project_id=project_id,
            access_token=access_token,
            credentials=credentials,
        )

        if base_url is None:
            base_url = self._get_base_url()

        AI21HTTPClient.__init__(
            self,
            base_url=base_url,
            timeout_sec=timeout_sec,
            num_retries=num_retries,
            headers=headers,
            client=http_client,
            requires_api_key=False,
        )

        self.chat = StudioChat(self)
        # Override the chat.create method to match the completions endpoint,
        # so it wouldn't get to the old J2 completion endpoint
        self.chat.create = self.chat.completions.create

    def _build_request(self, options: RequestOptions) -> httpx.Request:
        """Rewrite ``options`` for Vertex, then delegate to the parent builder."""
        options = self._prepare_options(options)

        return super()._build_request(options)

    def _prepare_options(self, options: RequestOptions) -> RequestOptions:
        """Return a copy of ``options`` targeting the Vertex prediction endpoint.

        ``model`` and ``stream`` are removed from the body: the model goes into
        the URL path and ``stream`` selects the endpoint.
        """
        # Work on a copy so the caller's body dict is left untouched; popping
        # in place would make a second pass over the same options raise
        # KeyError on "model".
        body = dict(options.body)

        model = body.pop("model")
        stream = body.pop("stream", False)
        endpoint = "streamRawPredict" if stream else "rawPredict"
        # Headers are resolved before the path: fetching the token may also
        # resolve ``self._project_id``, which the path needs.
        headers = self._prepare_headers()
        path = self._build_path(
            project_id=self._project_id,
            region=self._region,
            model=model,
            endpoint=endpoint,
        )

        return options.replace(
            body=body,
            path=path,
            headers=headers,
        )

    def _prepare_headers(self) -> Dict[str, Any]:
        """Return the authorization header for the current credentials."""
        return self._get_authorization_header()


class AsyncAI21VertexClient(BaseAI21VertexClient, AsyncAI21HTTPClient):
    """Asynchronous AI21 client that talks to models served on Vertex AI."""

    def __init__(
        self,
        region: Optional[str] = None,
        project_id: Optional[str] = None,
        base_url: Optional[str] = None,
        access_token: Optional[str] = None,
        credentials: Optional[GCPCredentials] = None,
        headers: Dict[str, str] | None = None,
        timeout_sec: Optional[float] = None,
        num_retries: Optional[int] = None,
        http_client: Optional[httpx.AsyncClient] = None,
    ):
        """Configure GCP settings and the underlying async HTTP client.

        Args:
            region: GCP region; a default is used when omitted.
            project_id: GCP project id. Required together with ``access_token``.
            base_url: Overrides the regional Vertex AI base URL.
            access_token: Pre-issued bearer token.
            credentials: Credentials object used to obtain a token.
            headers: Extra headers handed to the HTTP client.
            timeout_sec: Request timeout handed to the HTTP client.
            num_retries: Retry count handed to the HTTP client.
            http_client: Pre-built ``httpx.AsyncClient`` to use.
        """
        BaseAI21VertexClient.__init__(
            self,
            region=region,
            project_id=project_id,
            access_token=access_token,
            credentials=credentials,
        )

        if base_url is None:
            base_url = self._get_base_url()

        AsyncAI21HTTPClient.__init__(
            self,
            base_url=base_url,
            timeout_sec=timeout_sec,
            num_retries=num_retries,
            headers=headers,
            client=http_client,
            requires_api_key=False,
        )

        self.chat = AsyncStudioChat(self)
        # Override the chat.create method to match the completions endpoint,
        # so it wouldn't get to the old J2 completion endpoint
        self.chat.create = self.chat.completions.create

    def _build_request(self, options: RequestOptions) -> httpx.Request:
        """Rewrite ``options`` for Vertex, then delegate to the parent builder.

        NOTE(review): this method is synchronous, so any credential lookup or
        refresh triggered by ``_prepare_headers`` runs inline - confirm that
        is acceptable on the event loop.
        """
        options = self._prepare_options(options)

        return super()._build_request(options)

    def _prepare_options(self, options: RequestOptions) -> RequestOptions:
        """Return a copy of ``options`` targeting the Vertex prediction endpoint.

        ``model`` and ``stream`` are removed from the body: the model goes into
        the URL path and ``stream`` selects the endpoint.
        """
        # Work on a copy so the caller's body dict is left untouched; popping
        # in place would make a second pass over the same options raise
        # KeyError on "model".
        body = dict(options.body)

        model = body.pop("model")
        stream = body.pop("stream", False)
        endpoint = "streamRawPredict" if stream else "rawPredict"
        # Headers are resolved before the path: fetching the token may also
        # resolve ``self._project_id``, which the path needs.
        headers = self._prepare_headers()
        path = self._build_path(
            project_id=self._project_id,
            region=self._region,
            model=model,
            endpoint=endpoint,
        )

        return options.replace(
            body=body,
            path=path,
            headers=headers,
        )

    def _prepare_headers(self) -> Dict[str, Any]:
        """Return the authorization header for the current credentials."""
        return self._get_authorization_header()
43 changes: 43 additions & 0 deletions ai21/clients/vertex/gcp_authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import Optional, Tuple

import google.auth
from google.auth.credentials import Credentials
from google.auth.transport.requests import Request
from google.auth.exceptions import DefaultCredentialsError

from ai21.errors import CredentialsError


class GCPAuthorization:
    """Helper that loads default GCP credentials and refreshes them."""

    def get_gcp_credentials(
        self,
        project_id: Optional[str] = None,
    ) -> Tuple[Credentials, str]:
        """Load the default GCP credentials and settle on a project id.

        Args:
            project_id: Project id requested by the caller, if any.

        Returns:
            The loaded credentials and the project id to use.

        Raises:
            CredentialsError: If no default credentials can be found.
            ValueError: If the requested project id conflicts with the one
                tied to the credentials, or no project id is available, or
                it is not a string.
        """
        try:
            credentials, loaded_project_id = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        except DefaultCredentialsError as e:
            # Keep the original exception attached as the cause.
            raise CredentialsError(provider_name="GCP", error_message=str(e)) from e

        # Only a real conflict is a mismatch. The default credentials may carry
        # no project at all, in which case the caller's explicit id is used.
        if project_id is not None and loaded_project_id is not None and project_id != loaded_project_id:
            raise ValueError("Mismatch between credentials project id and 'project_id'")

        project_id = project_id or loaded_project_id

        if project_id is None:
            raise ValueError("Could not get project_id for GCP project")

        if not isinstance(project_id, str):
            raise ValueError(f"Variable project_id must be a string, got {type(project_id)} instead")

        return credentials, project_id

    def _get_gcp_request(self) -> Request:
        """Return a new transport request object used for token refreshes."""
        return Request()

    def refresh_auth(self, credentials: Credentials) -> None:
        """Refresh ``credentials`` in place using a new transport request."""
        request = self._get_gcp_request()
        credentials.refresh(request)
6 changes: 6 additions & 0 deletions ai21/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def __init__(self, key: str):
super().__init__(message)


class CredentialsError(AI21Error):
    """Raised when a provider's default credentials cannot be obtained."""

    def __init__(self, provider_name: str, error_message: str):
        """Build the message from the provider name and the underlying error text."""
        super().__init__(f"Could not get default {provider_name} credentials: {error_message}")


class StreamingDecodeError(AI21Error):
def __init__(self, chunk: str, error_message: Optional[str] = None):
message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format."
Expand Down
Empty file added examples/vertex/__init__.py
Empty file.
Loading

0 comments on commit 0c7a32d

Please sign in to comment.