Skip to content

Commit

Permalink
test_avertex_batch_prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaan-jaff committed Feb 15, 2025
1 parent 8de6e7c commit 0ffd99a
Showing 1 changed file with 124 additions and 29 deletions.
153 changes: 124 additions & 29 deletions tests/batches_tests/test_openai_batches_and_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload
import random
from unittest.mock import patch, MagicMock


def load_vertex_ai_credentials():
Expand Down Expand Up @@ -367,37 +368,131 @@ async def test_async_create_batch(provider):
cleanup_azure_ft_models()


mock_file_response = {
"kind": "storage#object",
"id": "litellm-local/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb/1739598666670574",
"selfLink": "https://www.googleapis.com/storage/v1/b/litellm-local/o/litellm-vertex-files%2Fpublishers%2Fgoogle%2Fmodels%2Fgemini-1.5-flash-001%2F5f7b99ad-9203-4430-98bf-3b45451af4cb",
"mediaLink": "https://storage.googleapis.com/download/storage/v1/b/litellm-local/o/litellm-vertex-files%2Fpublishers%2Fgoogle%2Fmodels%2Fgemini-1.5-flash-001%2F5f7b99ad-9203-4430-98bf-3b45451af4cb?generation=1739598666670574&alt=media",
"name": "litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb",
"bucket": "litellm-local",
"generation": "1739598666670574",
"metageneration": "1",
"contentType": "application/json",
"storageClass": "STANDARD",
"size": "416",
"md5Hash": "hbBNj7C8KJ7oVH+JmyRM6A==",
"crc32c": "oDmiUA==",
"etag": "CO7D0IT+xIsDEAE=",
"timeCreated": "2025-02-15T05:51:06.741Z",
"updated": "2025-02-15T05:51:06.741Z",
"timeStorageClassUpdated": "2025-02-15T05:51:06.741Z",
"timeFinalized": "2025-02-15T05:51:06.741Z",
}

mock_vertex_batch_response = {
"name": "projects/123456789/locations/us-central1/batchPredictionJobs/test-batch-id-456",
"displayName": "litellm_batch_job",
"model": "projects/123456789/locations/us-central1/models/gemini-1.5-flash-001",
"modelVersionId": "v1",
"inputConfig": {
"gcsSource": {
"uris": [
"gs://litellm-local/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb"
]
}
},
"outputConfig": {
"gcsDestination": {"outputUriPrefix": "gs://litellm-local/batch-outputs/"}
},
"dedicatedResources": {
"machineSpec": {
"machineType": "n1-standard-4",
"acceleratorType": "NVIDIA_TESLA_T4",
"acceleratorCount": 1,
},
"startingReplicaCount": 1,
"maxReplicaCount": 1,
},
"state": "JOB_STATE_RUNNING",
"createTime": "2025-02-15T05:51:06.741Z",
"startTime": "2025-02-15T05:51:07.741Z",
"updateTime": "2025-02-15T05:51:08.741Z",
"labels": {"key1": "value1", "key2": "value2"},
"completionStats": {"successfulCount": 0, "failedCount": 0, "remainingCount": 100},
}


@pytest.mark.asyncio
async def test_avertex_batch_prediction():
load_vertex_ai_credentials()
litellm.set_verbose = True
file_name = "vertex_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
file_obj = await litellm.acreate_file(
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider="vertex_ai",
)
print("Response from creating file=", file_obj)
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
) as mock_post:
# Configure mock responses
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None

# Set up different responses for different API calls
async def mock_side_effect(*args, **kwargs):
url = kwargs.get("url", "")
if "files" in url:
mock_response.json.return_value = mock_file_response
elif "batch" in url:
mock_response.json.return_value = mock_vertex_batch_response
mock_response.status_code = 200
return mock_response

mock_post.side_effect = mock_side_effect

# load_vertex_ai_credentials()
litellm.set_verbose = True
litellm._turn_on_debug()
file_name = "vertex_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)

# Create file
file_obj = await litellm.acreate_file(
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider="vertex_ai",
)
print("Response from creating file=", file_obj)

batch_input_file_id = file_obj.id
assert (
batch_input_file_id is not None
), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"
assert (
file_obj.id
== "gs://litellm-local/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb"
)

create_batch_response = await litellm.acreate_batch(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
custom_llm_provider="vertex_ai",
metadata={"key1": "value1", "key2": "value2"},
)
print("create_batch_response=", create_batch_response)
# Create batch
create_batch_response = await litellm.acreate_batch(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=file_obj.id,
custom_llm_provider="vertex_ai",
metadata={"key1": "value1", "key2": "value2"},
)
print("create_batch_response=", create_batch_response)

retrieved_batch = await litellm.aretrieve_batch(
batch_id=create_batch_response.id,
custom_llm_provider="vertex_ai",
)
print("retrieved_batch=", retrieved_batch)
pass
assert create_batch_response.id == "test-batch-id-456"
assert (
create_batch_response.input_file_id
== "gs://litellm-local/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb"
)

# Mock the retrieve batch response
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.get"
) as mock_get:
mock_get_response = MagicMock()
mock_get_response.json.return_value = mock_vertex_batch_response
mock_get_response.status_code = 200
mock_get_response.raise_for_status.return_value = None
mock_get.return_value = mock_get_response

retrieved_batch = await litellm.aretrieve_batch(
batch_id=create_batch_response.id,
custom_llm_provider="vertex_ai",
)
print("retrieved_batch=", retrieved_batch)

assert retrieved_batch.id == "test-batch-id-456"

0 comments on commit 0ffd99a

Please sign in to comment.