fix(api): update embeddings signature so inputs and outputs list align #1161

Merged · 2 commits · Feb 21, 2025
20 changes: 15 additions & 5 deletions docs/_static/llama-stack-spec.html
@@ -4929,11 +4929,21 @@
"description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
},
"contents": {
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContent"
},
"description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
"oneOf": [
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
}
],
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
}
},
"additionalProperties": false,
16 changes: 10 additions & 6 deletions docs/_static/llama-stack-spec.yaml
@@ -3224,13 +3224,17 @@ components:
The identifier of the model to use. The model must be an embedding model
registered with Llama Stack and available via the /models endpoint.
contents:
type: array
items:
$ref: '#/components/schemas/InterleavedContent'
oneOf:
- type: array
items:
type: string
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
description: >-
List of contents to generate embeddings for. Note that content can be
multimodal. The behavior depends on the model and provider. Some models
may only support text.
List of contents to generate embeddings for. Each content can be a string
or an InterleavedContentItem (and hence can be multimodal). The behavior
depends on the model and provider. Some models may only support text.
additionalProperties: false
required:
- model_id
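To make the updated schema concrete, here is a minimal sketch of two request bodies that should satisfy the new oneOf for contents. The model identifier and the text-item field names are illustrative assumptions based on the spec excerpts above, not values taken from this PR.

# Sketch only: both payloads are assumed valid under the revised `contents` schema.

# Form 1: a plain list of strings.
payload_strings = {
    "model_id": "all-MiniLM-L6-v2",  # hypothetical embedding model id
    "contents": ["first document", "second document"],
}

# Form 2: a list of InterleavedContentItem objects (text items shown here).
payload_items = {
    "model_id": "all-MiniLM-L6-v2",  # hypothetical embedding model id
    "contents": [
        {"type": "text", "text": "first document"},
        {"type": "text", "text": "second document"},
    ],
}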
6 changes: 3 additions & 3 deletions llama_stack/apis/inference/inference.py
@@ -20,7 +20,7 @@
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated

from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
@@ -481,12 +481,12 @@ async def chat_completion(
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.

:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
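As a usage sketch of the revised protocol method, a caller can now pass either plain strings or content items. The model id below is a placeholder, and the TextContentItem construction reflects the common field layout rather than anything this diff guarantees.

from llama_stack.apis.common.content_types import TextContentItem

async def embed_both_ways(inference) -> None:
    # `inference` is assumed to be any object implementing the Inference protocol above.

    # Plain strings -- the simple, text-only case.
    resp = await inference.embeddings(
        model_id="all-MiniLM-L6-v2",  # placeholder embedding model id
        contents=["hello world", "goodbye world"],
    )

    # InterleavedContentItem objects -- the form that can carry multimodal content
    # where the model and provider support it.
    resp = await inference.embeddings(
        model_id="all-MiniLM-L6-v2",
        contents=[TextContentItem(text="hello world")],
    )

    # Per the docstring, resp.embeddings holds one list of floats per content.
    print(len(resp.embeddings))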
4 changes: 2 additions & 2 deletions llama_stack/distribution/routers/routers.py
@@ -6,7 +6,7 @@

from typing import Any, AsyncGenerator, Dict, List, Optional

from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
BenchmarkConfig,
@@ -214,7 +214,7 @@ async def completion(
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
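The router body is truncated above. A hedged sketch of the forwarding pattern it plausibly follows; the provider lookup call and the error message are assumptions, not lines from this PR.

async def embeddings_routing_sketch(self, model_id, contents):
    model = await self.routing_table.get_model(model_id)
    if model is None:
        raise ValueError(f"Model '{model_id}' not found")
    # The new union type is passed through unchanged; the provider backing the
    # model decides whether it supports strings, content items, or both.
    provider = self.routing_table.get_provider_impl(model_id)
    return await provider.embeddings(model_id=model_id, contents=contents)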
3 changes: 2 additions & 1 deletion llama_stack/providers/inline/inference/vllm/vllm.py
@@ -23,6 +23,7 @@
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -230,5 +231,5 @@ async def _generate_and_convert_to_openai_compat():
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk

async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
raise NotImplementedError()
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -9,7 +9,7 @@

from botocore.client import BaseClient

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -162,7 +162,7 @@ async def _get_params_for_chat_completion(self, request: ChatCompletionRequest)
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []
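Because provider implementations now receive either strings or content items, a common first step is to normalize the list before serializing it for the backend. Below is a minimal sketch of such a step, using an illustrative helper rather than the one this provider actually calls.

from typing import List, Union

from llama_stack.apis.common.content_types import InterleavedContentItem, TextContentItem


def to_embedding_strings(contents: Union[List[str], List[InterleavedContentItem]]) -> List[str]:
    # Illustrative helper: collapse the union into plain strings for text-only backends.
    out: List[str] = []
    for item in contents:
        if isinstance(item, str):
            out.append(item)
        elif isinstance(item, TextContentItem):
            out.append(item.text)
        else:
            raise ValueError("This sketch only handles text content")
    return out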
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -8,7 +8,7 @@

from cerebras.cloud.sdk import AsyncCerebras

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
@@ -172,6 +172,6 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
@@ -8,7 +8,7 @@

from openai import OpenAI

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -130,7 +130,7 @@ def _get_params(self, request: ChatCompletionRequest) -> dict:

async def embeddings(
self,
model: str,
contents: List[InterleavedContent],
model_id: str,
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -8,7 +8,7 @@

from fireworks.client import Fireworks

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -232,7 +232,7 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

3 changes: 2 additions & 1 deletion llama_stack/providers/remote/inference/groq/groq.py
@@ -19,6 +19,7 @@
EmbeddingsResponse,
Inference,
InterleavedContent,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -140,7 +141,7 @@ async def chat_completion(
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

3 changes: 2 additions & 1 deletion llama_stack/providers/remote/inference/ollama/ollama.py
@@ -13,6 +13,7 @@
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
@@ -258,7 +259,7 @@ async def _generate_and_convert_to_openai_compat():
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

5 changes: 3 additions & 2 deletions llama_stack/providers/remote/inference/runpod/runpod.py
@@ -69,9 +69,10 @@ async def chat_completion(
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
@@ -119,6 +120,6 @@ def _get_params(self, request: ChatCompletionRequest) -> dict:
async def embeddings(
self,
model: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
26 changes: 23 additions & 3 deletions llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -5,16 +5,36 @@
# the root directory of this source tree.

import json
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional

from openai import OpenAI

from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionMessage,
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
SystemMessage,
ToolCall,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
@@ -119,7 +139,7 @@ async def _to_async_generator():
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/tgi/tgi.py
@@ -10,7 +10,7 @@

from huggingface_hub import AsyncInferenceClient, HfApi

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -268,7 +268,7 @@ async def _get_params(self, request: ChatCompletionRequest) -> dict:
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/together/together.py
@@ -8,7 +8,7 @@

from together import Together

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -219,7 +219,7 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
10 changes: 8 additions & 2 deletions llama_stack/providers/remote/inference/vllm/vllm.py
@@ -10,7 +10,13 @@
from llama_models.datatypes import StopReason, ToolCall
from openai import OpenAI

from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -376,7 +382,7 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

4 changes: 2 additions & 2 deletions llama_stack/providers/utils/inference/embedding_mixin.py
@@ -9,7 +9,7 @@

from llama_stack.apis.inference import (
EmbeddingsResponse,
InterleavedContent,
InterleavedContentItem,
ModelStore,
)

@@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
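The mixin's method body is cut off in the diff. Based on the imports shown, it plausibly encodes the contents with a SentenceTransformer model; here is a hedged sketch of that flow, with the string normalization written inline as an assumption rather than copied from the PR.

from sentence_transformers import SentenceTransformer

from llama_stack.apis.inference import EmbeddingsResponse


async def embeddings_mixin_sketch(model: SentenceTransformer, contents) -> EmbeddingsResponse:
    # Accept plain strings or content items that expose a `.text` field.
    texts = [c if isinstance(c, str) else getattr(c, "text", str(c)) for c in contents]
    vectors = model.encode(texts)
    return EmbeddingsResponse(embeddings=[v.tolist() for v in vectors])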