Skip to content

Commit

Permalink
fix: error renames
Browse files Browse the repository at this point in the history
  • Loading branch information
asafgardin committed Dec 26, 2023
1 parent 68bc456 commit 6118385
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 15 deletions.
8 changes: 4 additions & 4 deletions ai21/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from ai21.errors import (
AI21APIError,
APITimeoutError,
MissingApiKeyException,
ModelPackageDoesntExistException,
MissingApiKeyError,
ModelPackageDoesntExistError,
AI21Error,
TooManyRequestsError,
)
Expand Down Expand Up @@ -69,8 +69,8 @@ def __getattr__(name: str) -> Any:
"AI21APIError",
"APITimeoutError",
"AI21Error",
"MissingApiKeyException",
"ModelPackageDoesntExistException",
"MissingApiKeyError",
"ModelPackageDoesntExistError",
"TooManyRequestsError",
"AI21BedrockClient",
"AI21SageMakerClient",
Expand Down
4 changes: 2 additions & 2 deletions ai21/ai21_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.errors import MissingApiKeyException
from ai21.errors import MissingApiKeyError
from ai21.http_client import HttpClient
from ai21.version import VERSION

Expand All @@ -28,7 +28,7 @@ def __init__(
self._api_key = api_key or self._env_config.api_key

if self._api_key is None:
raise MissingApiKeyException()
raise MissingApiKeyError()

self._api_host = api_host or self._env_config.api_host
self._api_version = api_version or self._env_config.api_version
Expand Down
4 changes: 2 additions & 2 deletions ai21/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ def __str__(self) -> str:
return f"{type(self).__name__} {self.message}"


class MissingApiKeyException(AI21Error):
class MissingApiKeyError(AI21Error):
def __init__(self):
message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args"
super().__init__(message)
self.message = message


class ModelPackageDoesntExistException(AI21Error):
class ModelPackageDoesntExistError(AI21Error):
def __init__(self, model_name: str, region: str, version: Optional[str] = None):
message = f"model_name: {model_name} doesn't exist in region: {region}"

Expand Down
6 changes: 3 additions & 3 deletions ai21/services/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ai21.clients.sagemaker.constants import (
SAGEMAKER_MODEL_PACKAGE_NAMES,
)
from ai21.errors import ModelPackageDoesntExistException
from ai21.errors import ModelPackageDoesntExistError

_JUMPSTART_ENDPOINT = "jumpstart"
_LIST_VERSIONS_ENDPOINT = f"{_JUMPSTART_ENDPOINT}/list_versions"
Expand Down Expand Up @@ -33,7 +33,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE
arn = response["arn"]

if not arn:
raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version)
raise ModelPackageDoesntExistError(model_name=model_name, region=region, version=version)

return arn

Expand Down Expand Up @@ -61,4 +61,4 @@ def _create_ai21_http_client(cls) -> AI21HTTPClient:

def _assert_model_package_exists(model_name, region):
if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES:
raise ModelPackageDoesntExistException(model_name=model_name, region=region)
raise ModelPackageDoesntExistError(model_name=model_name, region=region)
8 changes: 4 additions & 4 deletions tests/unittests/services/test_sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from ai21.errors import ModelPackageDoesntExistException
from ai21.errors import ModelPackageDoesntExistError
from tests.unittests.services.sagemaker_stub import SageMakerStub

_DUMMY_ARN = "some-model-package-id1"
Expand All @@ -22,7 +22,7 @@ def test__get_model_package_arn__should_return_model_package_arn(self):
def test__get_model_package_arn__when_no_arn__should_raise_error(self):
SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []}

with pytest.raises(ModelPackageDoesntExistException):
with pytest.raises(ModelPackageDoesntExistError):
SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1")

def test__list_model_package_versions__should_return_model_package_arn(self):
Expand All @@ -36,9 +36,9 @@ def test__list_model_package_versions__should_return_model_package_arn(self):
assert actual_model_package_arn == _DUMMY_VERSIONS

def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self):
with pytest.raises(ModelPackageDoesntExistException):
with pytest.raises(ModelPackageDoesntExistError):
SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1")

def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self):
with pytest.raises(ModelPackageDoesntExistException):
with pytest.raises(ModelPackageDoesntExistError):
SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1")

0 comments on commit 6118385

Please sign in to comment.