diff --git a/.mock/definition/__package__.yml b/.mock/definition/__package__.yml index 9f24749a8..1256f5c0b 100644 --- a/.mock/definition/__package__.yml +++ b/.mock/definition/__package__.yml @@ -2347,20 +2347,34 @@ types: organization: optional source: openapi: openapi/openapi.yaml + RefinedPromptResponseRefinementStatus: + enum: + - Pending + - InProgress + - Completed + - Failed + docs: Status of the refinement job + default: Pending + source: + openapi: openapi/openapi.yaml RefinedPromptResponse: properties: title: - type: string + type: optional docs: Title of the refined prompt reasoning: - type: string + type: optional docs: Reasoning behind the refinement prompt: type: string docs: The refined prompt text refinement_job_id: - type: string + type: optional docs: Unique identifier for the refinement job + refinement_status: + type: optional + docs: Status of the refinement job + default: Pending source: openapi: openapi/openapi.yaml InferenceRunOrganization: diff --git a/.mock/definition/prompts/versions.yml b/.mock/definition/prompts/versions.yml index 2decea25d..c21e9acda 100644 --- a/.mock/definition/prompts/versions.yml +++ b/.mock/definition/prompts/versions.yml @@ -162,6 +162,46 @@ service: organization: 1 audiences: - public + get_refined_prompt: + path: /api/prompts/{prompt_id}/versions/{version_id}/refine + method: GET + auth: true + docs: | + Get the refined prompt based on the `refinement_job_id`. + path-parameters: + prompt_id: + type: integer + docs: Prompt ID + version_id: + type: integer + docs: Prompt Version ID + display-name: Get refined prompt + request: + name: VersionsGetRefinedPromptRequest + query-parameters: + refinement_job_id: + type: string + docs: >- + Refinement Job ID acquired from the `POST + /api/prompts/{prompt_id}/versions/{version_id}/refine` endpoint + response: + docs: '' + type: root.RefinedPromptResponse + examples: + - path-parameters: + prompt_id: 1 + version_id: 1 + query-parameters: + refinement_job_id: refinement_job_id + response: + body: + title: title + reasoning: reasoning + prompt: prompt + refinement_job_id: refinement_job_id + refinement_status: Pending + audiences: + - public refine_prompt: path: /api/prompts/{prompt_id}/versions/{version_id}/refine method: POST @@ -204,6 +244,7 @@ service: reasoning: reasoning prompt: prompt refinement_job_id: refinement_job_id + refinement_status: Pending audiences: - public source: diff --git a/reference.md b/reference.md index 456ca2b4d..73010d0bc 100644 --- a/reference.md +++ b/reference.md @@ -15161,6 +15161,94 @@ client.prompts.versions.update( + + + + +
client.prompts.versions.get_refined_prompt(...) +
+
+ +#### 📝 Description + +
+
+ +
+
+ +Get the refined prompt based on the `refinement_job_id`. +
+
+
+
+ +#### 🔌 Usage + +
+
+ +
+
+ +```python +from label_studio_sdk.client import LabelStudio + +client = LabelStudio( + api_key="YOUR_API_KEY", +) +client.prompts.versions.get_refined_prompt( + prompt_id=1, + version_id=1, + refinement_job_id="refinement_job_id", +) + +``` +
+
+
+
+ +#### ⚙️ Parameters + +
+
+ +
+
+ +**prompt_id:** `int` — Prompt ID + +
+
+ +
+
+ +**version_id:** `int` — Prompt Version ID + +
+
+ +
+
+ +**refinement_job_id:** `str` — Refinement Job ID acquired from the `POST /api/prompts/{prompt_id}/versions/{version_id}/refine` endpoint + +
+
+ +
+
+ +**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration. + +
+
+
+
+ +
diff --git a/src/label_studio_sdk/__init__.py b/src/label_studio_sdk/__init__.py index 2e46761f7..c320e3733 100644 --- a/src/label_studio_sdk/__init__.py +++ b/src/label_studio_sdk/__init__.py @@ -75,6 +75,7 @@ RedisImportStorage, RedisImportStorageStatus, RefinedPromptResponse, + RefinedPromptResponseRefinementStatus, S3ExportStorage, S3ExportStorageStatus, S3ImportStorage, @@ -277,6 +278,7 @@ "RedisImportStorage", "RedisImportStorageStatus", "RefinedPromptResponse", + "RefinedPromptResponseRefinementStatus", "S3ExportStorage", "S3ExportStorageStatus", "S3ImportStorage", diff --git a/src/label_studio_sdk/prompts/versions/client.py b/src/label_studio_sdk/prompts/versions/client.py index 061752028..0dc9f31b7 100644 --- a/src/label_studio_sdk/prompts/versions/client.py +++ b/src/label_studio_sdk/prompts/versions/client.py @@ -336,6 +336,63 @@ def update( raise ApiError(status_code=_response.status_code, body=_response.text) raise ApiError(status_code=_response.status_code, body=_response_json) + def get_refined_prompt( + self, + prompt_id: int, + version_id: int, + *, + refinement_job_id: str, + request_options: typing.Optional[RequestOptions] = None, + ) -> RefinedPromptResponse: + """ + Get the refined prompt based on the `refinement_job_id`. + + Parameters + ---------- + prompt_id : int + Prompt ID + + version_id : int + Prompt Version ID + + refinement_job_id : str + Refinement Job ID acquired from the `POST /api/prompts/{prompt_id}/versions/{version_id}/refine` endpoint + + request_options : typing.Optional[RequestOptions] + Request-specific configuration. + + Returns + ------- + RefinedPromptResponse + + + Examples + -------- + from label_studio_sdk.client import LabelStudio + + client = LabelStudio( + api_key="YOUR_API_KEY", + ) + client.prompts.versions.get_refined_prompt( + prompt_id=1, + version_id=1, + refinement_job_id="refinement_job_id", + ) + """ + _response = self._client_wrapper.httpx_client.request( + f"api/prompts/{jsonable_encoder(prompt_id)}/versions/{jsonable_encoder(version_id)}/refine", + method="GET", + params={"refinement_job_id": refinement_job_id}, + request_options=request_options, + ) + try: + if 200 <= _response.status_code < 300: + return pydantic_v1.parse_obj_as(RefinedPromptResponse, _response.json()) # type: ignore + _response_json = _response.json() + except JSONDecodeError: + raise ApiError(status_code=_response.status_code, body=_response.text) + raise ApiError(status_code=_response.status_code, body=_response_json) + def refine_prompt( self, prompt_id: int, @@ -727,6 +784,63 @@ async def update( raise ApiError(status_code=_response.status_code, body=_response.text) raise ApiError(status_code=_response.status_code, body=_response_json) + async def get_refined_prompt( + self, + prompt_id: int, + version_id: int, + *, + refinement_job_id: str, + request_options: typing.Optional[RequestOptions] = None, + ) -> RefinedPromptResponse: + """ + Get the refined prompt based on the `refinement_job_id`. + + Parameters + ---------- + prompt_id : int + Prompt ID + + version_id : int + Prompt Version ID + + refinement_job_id : str + Refinement Job ID acquired from the `POST /api/prompts/{prompt_id}/versions/{version_id}/refine` endpoint + + request_options : typing.Optional[RequestOptions] + Request-specific configuration. + + Returns + ------- + RefinedPromptResponse + + + Examples + -------- + from label_studio_sdk.client import AsyncLabelStudio + + client = AsyncLabelStudio( + api_key="YOUR_API_KEY", + ) + await client.prompts.versions.get_refined_prompt( + prompt_id=1, + version_id=1, + refinement_job_id="refinement_job_id", + ) + """ + _response = await self._client_wrapper.httpx_client.request( + f"api/prompts/{jsonable_encoder(prompt_id)}/versions/{jsonable_encoder(version_id)}/refine", + method="GET", + params={"refinement_job_id": refinement_job_id}, + request_options=request_options, + ) + try: + if 200 <= _response.status_code < 300: + return pydantic_v1.parse_obj_as(RefinedPromptResponse, _response.json()) # type: ignore + _response_json = _response.json() + except JSONDecodeError: + raise ApiError(status_code=_response.status_code, body=_response.text) + raise ApiError(status_code=_response.status_code, body=_response_json) + async def refine_prompt( self, prompt_id: int, diff --git a/src/label_studio_sdk/types/__init__.py b/src/label_studio_sdk/types/__init__.py index 3b5316d3c..a3b2f2833 100644 --- a/src/label_studio_sdk/types/__init__.py +++ b/src/label_studio_sdk/types/__init__.py @@ -74,6 +74,7 @@ from .redis_import_storage import RedisImportStorage from .redis_import_storage_status import RedisImportStorageStatus from .refined_prompt_response import RefinedPromptResponse +from .refined_prompt_response_refinement_status import RefinedPromptResponseRefinementStatus from .s3export_storage import S3ExportStorage from .s3export_storage_status import S3ExportStorageStatus from .s3import_storage import S3ImportStorage @@ -169,6 +170,7 @@ "RedisImportStorage", "RedisImportStorageStatus", "RefinedPromptResponse", + "RefinedPromptResponseRefinementStatus", "S3ExportStorage", "S3ExportStorageStatus", "S3ImportStorage", diff --git a/src/label_studio_sdk/types/refined_prompt_response.py b/src/label_studio_sdk/types/refined_prompt_response.py index 1b37313f1..756cecc4a 100644 --- a/src/label_studio_sdk/types/refined_prompt_response.py +++ b/src/label_studio_sdk/types/refined_prompt_response.py @@ -5,15 +5,16 @@ from ..core.datetime_utils import serialize_datetime from ..core.pydantic_utilities import deep_union_pydantic_dicts, pydantic_v1 +from .refined_prompt_response_refinement_status import RefinedPromptResponseRefinementStatus class RefinedPromptResponse(pydantic_v1.BaseModel): - title: str = pydantic_v1.Field() + title: typing.Optional[str] = pydantic_v1.Field(default=None) """ Title of the refined prompt """ - reasoning: str = pydantic_v1.Field() + reasoning: typing.Optional[str] = pydantic_v1.Field(default=None) """ Reasoning behind the refinement """ @@ -23,11 +24,16 @@ class RefinedPromptResponse(pydantic_v1.BaseModel): The refined prompt text """ - refinement_job_id: str = pydantic_v1.Field() + refinement_job_id: typing.Optional[str] = pydantic_v1.Field(default=None) """ Unique identifier for the refinement job """ + refinement_status: typing.Optional[RefinedPromptResponseRefinementStatus] = pydantic_v1.Field(default=None) + """ + Status of the refinement job + """ + def json(self, **kwargs: typing.Any) -> str: kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs} return super().json(**kwargs_with_defaults) diff --git a/src/label_studio_sdk/types/refined_prompt_response_refinement_status.py b/src/label_studio_sdk/types/refined_prompt_response_refinement_status.py new file mode 100644 index 000000000..8488e13ca --- /dev/null +++ b/src/label_studio_sdk/types/refined_prompt_response_refinement_status.py @@ -0,0 +1,7 @@ +# This file was auto-generated by Fern from our API Definition. + +import typing + +RefinedPromptResponseRefinementStatus = typing.Union[ + typing.Literal["Pending", "InProgress", "Completed", "Failed"], typing.Any +] diff --git a/tests/prompts/test_versions.py b/tests/prompts/test_versions.py index 1226b91a0..aa7c4069d 100644 --- a/tests/prompts/test_versions.py +++ b/tests/prompts/test_versions.py @@ -149,14 +149,47 @@ async def test_update(client: LabelStudio, async_client: AsyncLabelStudio) -> No validate_response(async_response, expected_response, expected_types) +async def test_get_refined_prompt(client: LabelStudio, async_client: AsyncLabelStudio) -> None: + expected_response: typing.Any = { + "title": "title", + "reasoning": "reasoning", + "prompt": "prompt", + "refinement_job_id": "refinement_job_id", + "refinement_status": "Pending", + } + expected_types: typing.Any = { + "title": None, + "reasoning": None, + "prompt": None, + "refinement_job_id": None, + "refinement_status": None, + } + response = client.prompts.versions.get_refined_prompt( + prompt_id=1, version_id=1, refinement_job_id="refinement_job_id" + ) + validate_response(response, expected_response, expected_types) + + async_response = await async_client.prompts.versions.get_refined_prompt( + prompt_id=1, version_id=1, refinement_job_id="refinement_job_id" + ) + validate_response(async_response, expected_response, expected_types) + + async def test_refine_prompt(client: LabelStudio, async_client: AsyncLabelStudio) -> None: expected_response: typing.Any = { "title": "title", "reasoning": "reasoning", "prompt": "prompt", "refinement_job_id": "refinement_job_id", + "refinement_status": "Pending", + } + expected_types: typing.Any = { + "title": None, + "reasoning": None, + "prompt": None, + "refinement_job_id": None, + "refinement_status": None, } - expected_types: typing.Any = {"title": None, "reasoning": None, "prompt": None, "refinement_job_id": None} response = client.prompts.versions.refine_prompt(prompt_id=1, version_id=1) validate_response(response, expected_response, expected_types)