[2/x] Support non-OAI providers as LLM judges for GenAI metrics. (mlflow#13717)

Signed-off-by: B-Step62 <[email protected]>
Signed-off-by: k99kurella <[email protected]>
B-Step62 authored and karthikkurella committed Jan 30, 2025
1 parent 29ab0af commit cc6b3bf
Showing 9 changed files with 493 additions and 70 deletions.
4 changes: 4 additions & 0 deletions mlflow/gateway/provider_registry.py
@@ -62,6 +62,10 @@ def _register_plugin_providers(registry: ProviderRegistry):
registry.register(p.name, cls)


def is_supported_provider(name: str) -> bool:
return name in provider_registry.keys()


provider_registry = ProviderRegistry()
_register_default_providers(provider_registry)
_register_plugin_providers(provider_registry)
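
For illustration (not part of the diff): a minimal sketch of how a caller might use the new is_supported_provider helper to validate a judge's provider before building a request. The validate_judge_provider function is hypothetical.

# Hypothetical caller; only is_supported_provider comes from this commit.
from mlflow.gateway.provider_registry import is_supported_provider


def validate_judge_provider(provider_name: str) -> None:
    # Reject providers the gateway registry does not know about before
    # any LLM-judge request is built.
    if not is_supported_provider(provider_name):
        raise ValueError(f"Unsupported LLM judge provider: {provider_name}")


validate_judge_provider("anthropic")  # passes for a default-registered provider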
40 changes: 26 additions & 14 deletions mlflow/gateway/providers/anthropic.py
@@ -17,6 +17,7 @@ class AnthropicAdapter(ProviderAdapter):
@classmethod
def chat_to_model(cls, payload, config):
key_mapping = {"stop": "stop_sequences"}
payload["model"] = config.model.name
payload = rename_payload_keys(payload, key_mapping)

if "top_p" in payload and "temperature" in payload:
@@ -155,6 +156,8 @@ def model_to_completions(cls, resp, config):
def completions_to_model(cls, payload, config):
key_mapping = {"max_tokens": "max_tokens_to_sample", "stop": "stop_sequences"}

payload["model"] = config.model.name

if "top_p" in payload:
raise AIGatewayException(
status_code=422,
@@ -220,11 +223,29 @@ def __init__(self, config: RouteConfig) -> None:
if config.model.config is None or not isinstance(config.model.config, AnthropicConfig):
raise TypeError(f"Invalid config type {config.model.config}")
self.anthropic_config: AnthropicConfig = config.model.config
-        self.headers = {
+
+    @property
+    def headers(self) -> dict[str, str]:
+        return {
            "x-api-key": self.anthropic_config.anthropic_api_key,
            "anthropic-version": self.anthropic_config.anthropic_version,
        }
-        self.base_url = "https://api.anthropic.com/v1/"
+
+    @property
+    def base_url(self) -> str:
+        return "https://api.anthropic.com/v1"
+
+    @property
+    def adapter_class(self) -> type[ProviderAdapter]:
+        return AnthropicAdapter
+
+    def get_endpoint_url(self, route_type: str) -> str:
+        if route_type == "llm/v1/chat":
+            return f"{self.base_url}/messages"
+        elif route_type == "llm/v1/completions":
+            return f"{self.base_url}/complete"
+        else:
+            raise ValueError(f"Invalid route type {route_type}")

async def chat_stream(
self, payload: chat.RequestPayload
@@ -237,10 +258,7 @@ async def chat_stream(
headers=self.headers,
base_url=self.base_url,
path="messages",
-            payload={
-                "model": self.config.model.name,
-                **AnthropicAdapter.chat_streaming_to_model(payload, self.config),
-            },
+            payload=AnthropicAdapter.chat_streaming_to_model(payload, self.config),
)

indices = []
@@ -294,10 +312,7 @@ async def chat(self, payload: chat.RequestPayload) -> chat.ResponsePayload:
headers=self.headers,
base_url=self.base_url,
path="messages",
-            payload={
-                "model": self.config.model.name,
-                **AnthropicAdapter.chat_to_model(payload, self.config),
-            },
+            payload=AnthropicAdapter.chat_to_model(payload, self.config),
)
return AnthropicAdapter.model_to_chat(resp, self.config)

@@ -311,10 +326,7 @@ async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:
headers=self.headers,
base_url=self.base_url,
path="complete",
-            payload={
-                "model": self.config.model.name,
-                **AnthropicAdapter.completions_to_model(payload, self.config),
-            },
+            payload=AnthropicAdapter.completions_to_model(payload, self.config),
)

# Example response:
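
For illustration (not part of the diff): with the "model" injection moved into AnthropicAdapter, call sites pass the adapter output straight through. A rough sketch, assuming config is a valid Anthropic RouteConfig:

# Sketch only; `config` is assumed to be a valid Anthropic RouteConfig.
payload = {"messages": [{"role": "user", "content": "hi"}], "stop": ["\n"]}
out = AnthropicAdapter.chat_to_model(payload, config)

# The adapter now stamps the configured model and renames "stop", so
# chat()/chat_stream() no longer splice {"model": ...} into the payload.
assert out["model"] == config.model.name
assert "stop_sequences" in out and "stop" not in out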
20 changes: 15 additions & 5 deletions mlflow/gateway/providers/bedrock.py
@@ -17,6 +17,13 @@


class AmazonBedrockAnthropicAdapter(AnthropicAdapter):
@classmethod
def chat_to_model(cls, payload, config):
payload = super().chat_to_model(payload, config)
# "model" keys are not supported in Bedrock"
payload.pop("model", None)
return payload

@classmethod
def completions_to_model(cls, payload, config):
payload = super().completions_to_model(payload, config)
@@ -28,6 +35,9 @@ def completions_to_model(cls, payload, config):
payload.get("max_tokens_to_sample", MLFLOW_AI_GATEWAY_ANTHROPIC_DEFAULT_MAX_TOKENS),
AWS_BEDROCK_ANTHROPIC_MAXIMUM_MAX_TOKENS,
)

# "model" keys are not supported in Bedrock"
payload.pop("model", None)
return payload

@classmethod
@@ -137,7 +147,7 @@ class AmazonBedrockModelProvider(Enum):
ANTHROPIC = "anthropic"

@property
-    def adapter(self):
+    def adapter_class(self) -> type[ProviderAdapter]:
return AWS_MODEL_PROVIDER_TO_ADAPTER.get(self)

@classmethod
@@ -243,14 +253,14 @@ def _underlying_provider(self):
return AmazonBedrockModelProvider.of_str(provider)

@property
-    def underlying_provider_adapter(self) -> ProviderAdapter:
+    def adapter_class(self) -> type[ProviderAdapter]:
provider = self._underlying_provider
if not provider:
raise AIGatewayException(
status_code=422,
detail=f"Unknown Amazon Bedrock model type {self._underlying_provider}",
)
-        adapter = provider.adapter
+        adapter = provider.adapter_class
if not adapter:
raise AIGatewayException(
status_code=422,
@@ -284,6 +294,6 @@ async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:

self.check_for_model_field(payload)
payload = jsonable_encoder(payload, exclude_none=True, exclude_defaults=True)
-        payload = self.underlying_provider_adapter.completions_to_model(payload, self.config)
+        payload = self.adapter_class.completions_to_model(payload, self.config)
response = self._request(payload)
-        return self.underlying_provider_adapter.model_to_completions(response, self.config)
+        return self.adapter_class.model_to_completions(response, self.config)
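
For illustration (not part of the diff): Bedrock addresses the model through the invocation URL rather than the request body, so the Bedrock subclass strips the "model" key that the base AnthropicAdapter now injects. A sketch, assuming config is a valid Anthropic-on-Bedrock RouteConfig:

# Sketch only; `config` is assumed to be an Anthropic-on-Bedrock RouteConfig.
payload = AmazonBedrockAnthropicAdapter.chat_to_model(
    {"messages": [{"role": "user", "content": "hi"}]}, config
)
# The base adapter adds "model"; the subclass pops it back out because,
# per the comment in the diff, Bedrock does not support that key.
assert "model" not in payload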
20 changes: 17 additions & 3 deletions mlflow/gateway/providers/cohere.py
@@ -336,24 +336,38 @@ def __init__(self, config: RouteConfig) -> None:
self.cohere_config: CohereConfig = config.model.config

@property
-    def auth_headers(self) -> dict[str, str]:
+    def headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self.cohere_config.cohere_api_key}"}

@property
def base_url(self) -> str:
return "https://api.cohere.ai/v1"

@property
def adapter_class(self) -> type[ProviderAdapter]:
return CohereAdapter

def get_endpoint_url(self, route_type: str) -> str:
if route_type == "llm/v1/chat":
return f"{self.base_url}/chat"
elif route_type == "llm/v1/completions":
return f"{self.base_url}/generate"
elif route_type == "llm/v1/embeddings":
return f"{self.base_url}/embed"
else:
raise ValueError(f"Invalid route type {route_type}")

async def _request(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
return await send_request(
-            headers=self.auth_headers,
+            headers=self.headers,
base_url=self.base_url,
path=path,
payload=payload,
)

def _stream_request(self, path: str, payload: dict[str, Any]) -> AsyncGenerator[bytes, None]:
return send_stream_request(
-            headers=self.auth_headers,
+            headers=self.headers,
base_url=self.base_url,
path=path,
payload=payload,
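
For illustration (not part of the diff): the renamed headers property and the new get_endpoint_url give Cohere the same provider surface as Anthropic above. A sketch of the route-type mapping, assuming provider is a constructed CohereProvider:

# Sketch only; `provider` is assumed to be a constructed CohereProvider.
provider.get_endpoint_url("llm/v1/chat")         # "https://api.cohere.ai/v1/chat"
provider.get_endpoint_url("llm/v1/completions")  # "https://api.cohere.ai/v1/generate"
provider.get_endpoint_url("llm/v1/embeddings")   # "https://api.cohere.ai/v1/embed"
provider.get_endpoint_url("llm/v1/other")        # raises ValueError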
65 changes: 52 additions & 13 deletions mlflow/gateway/providers/mistral.py
@@ -4,7 +4,7 @@
from mlflow.gateway.config import MistralConfig, RouteConfig
from mlflow.gateway.providers.base import BaseProvider, ProviderAdapter
from mlflow.gateway.providers.utils import send_request
-from mlflow.gateway.schemas import completions, embeddings
+from mlflow.gateway.schemas import chat, completions, embeddings


class MistralAdapter(ProviderAdapter):
@@ -54,6 +54,36 @@ def model_to_completions(cls, resp, config):
),
)

@classmethod
def model_to_chat(cls, resp, config):
# Response example (https://docs.mistral.ai/api/#operation/createChatCompletion)
return chat.ResponsePayload(
id=resp["id"],
object=resp["object"],
created=resp["created"],
model=resp["model"],
choices=[
chat.Choice(
index=idx,
message=chat.ResponseMessage(
role=c["message"]["role"],
content=c["message"].get("content"),
tool_calls=(
(calls := c["message"].get("tool_calls"))
and [chat.ToolCall(**c) for c in calls]
),
),
finish_reason=c.get("finish_reason"),
)
for idx, c in enumerate(resp["choices"])
],
usage=chat.ChatUsage(
prompt_tokens=resp["usage"]["prompt_tokens"],
completion_tokens=resp["usage"]["completion_tokens"],
total_tokens=resp["usage"]["total_tokens"],
),
)

@classmethod
def model_to_embeddings(cls, resp, config):
# Response example (https://docs.mistral.ai/api/#operation/createEmbedding):
@@ -97,6 +127,7 @@ def model_to_embeddings(cls, resp, config):

@classmethod
def completions_to_model(cls, payload, config):
payload["model"] = config.model.name
payload.pop("stop", None)
payload.pop("n", None)
payload["messages"] = [{"role": "user", "content": payload.pop("prompt")}]
@@ -107,9 +138,13 @@ def completions_to_model(cls, payload, config):

return payload

@classmethod
def chat_to_model(cls, payload, config):
return {"model": config.model.name, **payload}

@classmethod
def embeddings_to_model(cls, payload, config):
-        return payload
+        return {"model": config.model.name, **payload}


class MistralProvider(BaseProvider):
@@ -123,16 +158,26 @@ def __init__(self, config: RouteConfig) -> None:
self.mistral_config: MistralConfig = config.model.config

@property
-    def auth_headers(self) -> dict[str, str]:
+    def headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self.mistral_config.mistral_api_key}"}

@property
def base_url(self) -> str:
return "https://api.mistral.ai/v1/"
return "https://api.mistral.ai/v1"

@property
def adapter_class(self) -> type[ProviderAdapter]:
return MistralAdapter

def get_endpoint_url(self, route_type: str) -> str:
if route_type == "llm/v1/chat":
return f"{self.base_url}/chat/completions"
else:
raise ValueError(f"Invalid route type {route_type}")

async def _request(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
return await send_request(
-            headers=self.auth_headers,
+            headers=self.headers,
base_url=self.base_url,
path=path,
payload=payload,
@@ -145,10 +190,7 @@ async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:
self.check_for_model_field(payload)
resp = await self._request(
"chat/completions",
-            {
-                "model": self.config.model.name,
-                **MistralAdapter.completions_to_model(payload, self.config),
-            },
+            MistralAdapter.completions_to_model(payload, self.config),
)
return MistralAdapter.model_to_completions(resp, self.config)

@@ -159,9 +201,6 @@ async def embeddings(self, payload: embeddings.RequestPayload) -> embeddings.ResponsePayload:
self.check_for_model_field(payload)
resp = await self._request(
"embeddings",
-            {
-                "model": self.config.model.name,
-                **MistralAdapter.embeddings_to_model(payload, self.config),
-            },
+            MistralAdapter.embeddings_to_model(payload, self.config),
)
return MistralAdapter.model_to_embeddings(resp, self.config)
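
For illustration (not part of the diff): Mistral serves completions through its chat endpoint, and completions_to_model now both rewrites the prompt and stamps the model name. A sketch, assuming config.model.name is "mistral-small":

# Sketch only; `config` is assumed to be a RouteConfig with
# config.model.name == "mistral-small".
payload = MistralAdapter.completions_to_model(
    {"prompt": "Hello", "n": 1, "stop": ["###"]}, config
)
# Unsupported keys ("n", "stop") are dropped, the prompt becomes a chat
# message, and the configured model is injected:
assert payload == {
    "model": "mistral-small",
    "messages": [{"role": "user", "content": "Hello"}],
}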