diff --git a/ai21/__init__.py b/ai21/__init__.py index 799aeaaa..4d3ac0e9 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -1,11 +1,9 @@ +from ai21.clients.bedrock.ai21_bedrock_client import AI21BedrockClient +from .clients.bedrock.bedrock_model_id import BedrockModelID +from .clients.sagemaker.ai21_sagemaker_client import AI21SageMakerClient +from .clients.studio.ai21_client import AI21Client from .version import VERSION -from .clients import AI21Client, AI21BedrockClient, AI21SageMakerClient __version__ = VERSION __all__ = ["AI21Client", "AI21BedrockClient", "AI21SageMakerClient", "BedrockModelID"] - - -class BedrockModelID: - J2_MID_V1 = "ai21.j2-mid-v1" - J2_ULTRA_V1 = "ai21.j2-ultra-v1" diff --git a/ai21/clients/__init__.py b/ai21/clients/__init__.py index 4576472c..e69de29b 100644 --- a/ai21/clients/__init__.py +++ b/ai21/clients/__init__.py @@ -1,5 +0,0 @@ -from .sagemaker import AI21SageMakerClient -from .bedrock import AI21BedrockClient -from .studio import AI21Client - -__all__ = ["AI21Client", "AI21BedrockClient", "AI21SageMakerClient"] diff --git a/ai21/clients/bedrock/__init__.py b/ai21/clients/bedrock/__init__.py index cd06a82e..e69de29b 100644 --- a/ai21/clients/bedrock/__init__.py +++ b/ai21/clients/bedrock/__init__.py @@ -1,3 +0,0 @@ -from .ai21_bedrock_client import AI21BedrockClient - -__all__ = ["AI21BedrockClient"] diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py index 27e84d92..1fecb895 100644 --- a/ai21/clients/bedrock/ai21_bedrock_client.py +++ b/ai21/clients/bedrock/ai21_bedrock_client.py @@ -6,17 +6,11 @@ from botocore.exceptions import ClientError from ai21.ai21_env_config import AI21EnvConfig, _AI21EnvConfig -from ai21.clients.bedrock import resources +from ai21.clients.bedrock.resources.bedrock_completion import BedrockCompletion from ai21.errors import AccessDenied, NotFound, APITimeoutError from ai21.http_client import handle_non_success_response from ai21.utils import log_error -__all__ = [ - "resources", - "AI21BedrockClient", -] - - RUNTIME_NAME = "bedrock-runtime" _ERROR_MSG_TEMPLATE = ( r"Received client error \((.*?)\) from primary with message \"(.*?)\". " @@ -37,7 +31,7 @@ def __init__( self._session = ( session.client(RUNTIME_NAME) if session else boto3.client(RUNTIME_NAME, region_name=env_config.aws_region) ) - self.completion = resources.BedrockCompletion(self) + self.completion = BedrockCompletion(self) def invoke_model(self, model_id: str, input_json: str) -> Dict[str, Any]: try: diff --git a/ai21/clients/bedrock/bedrock_model_id.py b/ai21/clients/bedrock/bedrock_model_id.py new file mode 100644 index 00000000..c73b476e --- /dev/null +++ b/ai21/clients/bedrock/bedrock_model_id.py @@ -0,0 +1,3 @@ +class BedrockModelID: + J2_MID_V1 = "ai21.j2-mid-v1" + J2_ULTRA_V1 = "ai21.j2-ultra-v1" diff --git a/ai21/clients/bedrock/resources/__init__.py b/ai21/clients/bedrock/resources/__init__.py index 9bdad381..e69de29b 100644 --- a/ai21/clients/bedrock/resources/__init__.py +++ b/ai21/clients/bedrock/resources/__init__.py @@ -1,3 +0,0 @@ -from .bedrock_completion import BedrockCompletion - -__all__ = ["BedrockCompletion"] diff --git a/ai21/clients/sagemaker/__init__.py b/ai21/clients/sagemaker/__init__.py index 76323756..e69de29b 100644 --- a/ai21/clients/sagemaker/__init__.py +++ b/ai21/clients/sagemaker/__init__.py @@ -1,3 +0,0 @@ -from .ai21_sagemaker_client import AI21SageMakerClient - -__all__ = ["AI21SageMakerClient"] diff --git a/ai21/clients/sagemaker/ai21_sagemaker_client.py b/ai21/clients/sagemaker/ai21_sagemaker_client.py index d8fe483b..3efd168d 100644 --- a/ai21/clients/sagemaker/ai21_sagemaker_client.py +++ b/ai21/clients/sagemaker/ai21_sagemaker_client.py @@ -7,13 +7,15 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.ai21_studio_client import AI21StudioClient -from ai21.clients.sagemaker import resources +from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer +from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion +from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC +from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase +from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize from ai21.errors import BadRequest, ServiceUnavailable, ServerError, APIError from ai21.http_client import handle_non_success_response from ai21.utils import log_error -__all__ = ["resources", "AI21SageMakerClient"] - # Each one of the clients should be able to implement async/sync interface _ERROR_MSG_TEMPLATE = ( r"Received client error \((.*?)\) from primary with message \"(.*?)\". " @@ -56,11 +58,11 @@ def __init__( ) self._region = region or self._env_config.aws_region self._endpoint_name = endpoint_name - self.completion = resources.SageMakerCompletion(self) - self.paraphrase = resources.SageMakerParaphrase(self) - self.answer = resources.SageMakerAnswer(self) - self.gec = resources.SageMakerGEC(self) - self.summarize = resources.SageMakerSummarize(self) + self.completion = SageMakerCompletion(self) + self.paraphrase = SageMakerParaphrase(self) + self.answer = SageMakerAnswer(self) + self.gec = SageMakerGEC(self) + self.summarize = SageMakerSummarize(self) def invoke_endpoint( self, diff --git a/ai21/clients/sagemaker/constants.py b/ai21/clients/sagemaker/constants.py index d6991473..121e47b4 100644 --- a/ai21/clients/sagemaker/constants.py +++ b/ai21/clients/sagemaker/constants.py @@ -1,5 +1,3 @@ -SAGEMAKER_ENDPOINT_KEY = "sm_endpoint" - SAGEMAKER_MODEL_PACKAGE_NAMES = [ "j2-light", "j2-mid", diff --git a/ai21/clients/sagemaker/resources/__init__.py b/ai21/clients/sagemaker/resources/__init__.py index 94bcade8..e69de29b 100644 --- a/ai21/clients/sagemaker/resources/__init__.py +++ b/ai21/clients/sagemaker/resources/__init__.py @@ -1,7 +0,0 @@ -from .sagemaker_completion import SageMakerCompletion -from .sagemaker_gec import SageMakerGEC -from .sagemaker_paraphrase import SageMakerParaphrase -from .sagemaker_summarize import SageMakerSummarize -from .sagemaker_answer import SageMakerAnswer - -__all__ = ["SageMakerSummarize", "SageMakerParaphrase", "SageMakerGEC", "SageMakerCompletion", "SageMakerAnswer"] diff --git a/ai21/clients/studio/__init__.py b/ai21/clients/studio/__init__.py index 62968a0e..e69de29b 100644 --- a/ai21/clients/studio/__init__.py +++ b/ai21/clients/studio/__init__.py @@ -1,3 +0,0 @@ -from .ai21_client import AI21Client - -__all__ = ["AI21Client"] diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index c0b3bf12..51b7ad5c 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -1,12 +1,22 @@ from typing import Optional, Any, Dict from ai21.ai21_studio_client import AI21StudioClient -from ai21.clients.studio import resources +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.clients.studio.resources.studio_chat import StudioChat +from ai21.clients.studio.resources.studio_completion import StudioCompletion +from ai21.clients.studio.resources.studio_custom_model import StudioCustomModel +from ai21.clients.studio.resources.studio_dataset import StudioDataset +from ai21.clients.studio.resources.studio_embed import StudioEmbed +from ai21.clients.studio.resources.studio_gec import StudioGEC +from ai21.clients.studio.resources.studio_improvements import StudioImprovements +from ai21.clients.studio.resources.studio_library import StudioLibrary +from ai21.clients.studio.resources.studio_paraphrase import StudioParaphrase +from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation +from ai21.clients.studio.resources.studio_summarize import StudioSummarize +from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer from ai21.tokenizers.factory import get_tokenizer -__all__ = ["AI21Client", "resources"] - class AI21Client(AI21StudioClient): """ @@ -33,19 +43,19 @@ def __init__( timeout_sec=timeout_sec, num_retries=num_retries, ) - self.completion = resources.StudioCompletion(self) - self.chat = resources.StudioChat(self) - self.summarize = resources.StudioSummarize(self) - self.embed = resources.StudioEmbed(self) - self.gec = resources.StudioGEC(self) - self.improvements = resources.StudioImprovements(self) - self.paraphrase = resources.StudioParaphrase(self) - self.summarize_by_segment = resources.StudioSummarizeBySegment(self) - self.custom_model = resources.StudioCustomModel(self) - self.dataset = resources.StudioDataset(self) - self.answer = resources.StudioAnswer(self) - self.library = resources.StudioLibrary(self) - self.segmentation = resources.StudioSegmentation(self) + self.completion = StudioCompletion(self) + self.chat = StudioChat(self) + self.summarize = StudioSummarize(self) + self.embed = StudioEmbed(self) + self.gec = StudioGEC(self) + self.improvements = StudioImprovements(self) + self.paraphrase = StudioParaphrase(self) + self.summarize_by_segment = StudioSummarizeBySegment(self) + self.custom_model = StudioCustomModel(self) + self.dataset = StudioDataset(self) + self.answer = StudioAnswer(self) + self.library = StudioLibrary(self) + self.segmentation = StudioSegmentation(self) def count_token(self, text: str, model_id: str = "j2-instruct") -> int: # We might want to cache the tokenizer instance within the class diff --git a/ai21/clients/studio/resources/__init__.py b/ai21/clients/studio/resources/__init__.py index 0f570c42..e69de29b 100644 --- a/ai21/clients/studio/resources/__init__.py +++ b/ai21/clients/studio/resources/__init__.py @@ -1,29 +0,0 @@ -from .studio_answer import StudioAnswer -from .studio_chat import StudioChat -from .studio_completion import StudioCompletion -from .studio_custom_model import StudioCustomModel -from .studio_dataset import StudioDataset -from .studio_embed import StudioEmbed -from .studio_gec import StudioGEC -from .studio_improvements import StudioImprovements -from .studio_library import StudioLibrary -from .studio_paraphrase import StudioParaphrase -from .studio_segmentation import StudioSegmentation -from .studio_summarize import StudioSummarize -from .studio_summarize_by_segment import StudioSummarizeBySegment - -__all__ = [ - "StudioSummarizeBySegment", - "StudioSummarize", - "StudioSegmentation", - "StudioParaphrase", - "StudioLibrary", - "StudioImprovements", - "StudioGEC", - "StudioEmbed", - "StudioDataset", - "StudioCustomModel", - "StudioCompletion", - "StudioChat", - "StudioAnswer", -]