diff --git a/examples/studio/chat/chat_function_calling.py b/examples/studio/chat/chat_function_calling.py index 4b5d08e2..e55feaf2 100644 --- a/examples/studio/chat/chat_function_calling.py +++ b/examples/studio/chat/chat_function_calling.py @@ -2,7 +2,14 @@ from ai21 import AI21Client from ai21.logger import set_verbose -from ai21.models.chat import ChatMessage, ToolMessage, FunctionToolDefinition, ToolDefinition, ToolParameters +from ai21.models.chat import ( + ChatMessage, + FunctionToolDefinition, + ToolDefinition, + ToolMessage, + ToolParameters, +) + set_verbose(True) diff --git a/tests/integration_tests/clients/studio/test_chat.py b/tests/integration_tests/clients/studio/test_chat.py deleted file mode 100644 index ba27f8c4..00000000 --- a/tests/integration_tests/clients/studio/test_chat.py +++ /dev/null @@ -1,177 +0,0 @@ -import pytest - -from ai21 import AI21Client, AsyncAI21Client -from ai21.models import ChatMessage, RoleType, Penalty, FinishReason - -_MODEL = "j2-ultra" -_MESSAGES = [ - ChatMessage( - text="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", - role=RoleType.USER, - ), -] -_SYSTEM = "You are a teacher in a public school" - - -def test_chat(): - num_results = 5 - messages = _MESSAGES - - client = AI21Client() - response = client.chat.create( - system=_SYSTEM, - messages=messages, - num_results=num_results, - max_tokens=64, - temperature=0.7, - min_tokens=1, - stop_sequences=["\n"], - top_p=0.3, - top_k_return=0, - model=_MODEL, - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - frequency_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - presence_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - ) - - assert response.outputs[0].role == RoleType.ASSISTANT - assert isinstance(response.outputs[0].text, str) - assert response.outputs[0].finish_reason == FinishReason(reason="stop", sequence="\n") - - assert len(response.outputs) == num_results - - -@pytest.mark.parametrize( - ids=[ - "finish_reason_length", - "finish_reason_endoftext", - "finish_reason_stop_sequence", - ], - argnames=["max_tokens", "stop_sequences", "reason"], - argvalues=[ - (2, "##", "length"), - (1000, "##", "endoftext"), - (20, ".", "stop"), - ], -) -def test_chat_when_finish_reason_defined__should_halt_on_expected_reason( - max_tokens: int, stop_sequences: str, reason: str -): - client = AI21Client() - response = client.chat.create( - messages=_MESSAGES, - system=_SYSTEM, - max_tokens=max_tokens, - model="j2-ultra", - temperature=1, - top_p=0, - num_results=1, - stop_sequences=[stop_sequences], - top_k_return=0, - ) - - assert response.outputs[0].finish_reason.reason == reason - - -@pytest.mark.asyncio -async def test_async_chat(): - num_results = 5 - messages = _MESSAGES - - client = AsyncAI21Client() - response = await client.chat.create( - system=_SYSTEM, - messages=messages, - num_results=num_results, - max_tokens=64, - temperature=0.7, - min_tokens=1, - stop_sequences=["\n"], - top_p=0.3, - top_k_return=0, - model=_MODEL, - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - frequency_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - presence_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - ) - - assert response.outputs[0].role == RoleType.ASSISTANT - assert isinstance(response.outputs[0].text, str) - assert response.outputs[0].finish_reason == FinishReason(reason="stop", sequence="\n") - - assert len(response.outputs) == num_results - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "finish_reason_length", - "finish_reason_endoftext", - "finish_reason_stop_sequence", - ], - argnames=["max_tokens", "stop_sequences", "reason"], - argvalues=[ - (2, "##", "length"), - (1000, "##", "endoftext"), - (20, ".", "stop"), - ], -) -async def test_async_chat_when_finish_reason_defined__should_halt_on_expected_reason( - max_tokens: int, stop_sequences: str, reason: str -): - client = AsyncAI21Client() - response = await client.chat.create( - messages=_MESSAGES, - system=_SYSTEM, - max_tokens=max_tokens, - model="j2-ultra", - temperature=1, - top_p=0, - num_results=1, - stop_sequences=[stop_sequences], - top_k_return=0, - ) - - assert response.outputs[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 5cd0d4dd..1e10e10e 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -2,14 +2,16 @@ Run this script after setting the environment variable called AI21_API_KEY """ +import subprocess + from pathlib import Path from time import sleep import pytest -import subprocess from tests.integration_tests.skip_helpers import should_skip_studio_integration_tests + STUDIO_PATH = Path(__file__).parent.parent.parent.parent / "examples" / "studio" @@ -50,14 +52,12 @@ def test_studio(test_file_name: str): @pytest.mark.parametrize( argnames=["test_file_name"], argvalues=[ - ("async_chat.py",), ("chat/async_chat_completions.py",), ("chat/async_stream_chat_completions.py",), ("conversational_rag/conversational_rag.py",), ("conversational_rag/async_conversational_rag.py",), ], ids=[ - "when_chat__should_return_ok", "when_chat_completions__should_return_ok", "when_stream_chat_completions__should_return_ok", "when_conversational_rag__should_return_ok",