Skip to content

Commit

Permalink
bug fix: Fix the way we parse providers for during reading client-reg…
Browse files Browse the repository at this point in the history
…istry

Prior: We used deserialize on an descriminated union
Now: We use the FromStr capability so a user can use the strings we document
  • Loading branch information
hellovai committed Feb 7, 2025
1 parent 62fba91 commit 0aa1115
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
9 changes: 4 additions & 5 deletions engine/baml-lib/llm-client/src/clientspec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ use anyhow::Result;
use std::collections::HashSet;

use baml_types::{GetEnvVar, StringOr};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize)]
#[derive(Clone, Debug)]
pub enum ClientSpec {
Named(String),
/// Shorthand for "<provider>/<model>"
Expand All @@ -30,7 +29,7 @@ impl ClientSpec {
}

/// The provider for the client, e.g. baml-openai-chat
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug)]
pub enum ClientProvider {
/// The OpenAI client provider variant
OpenAI(OpenAIClientProviderVariant),
Expand All @@ -47,7 +46,7 @@ pub enum ClientProvider {
}

/// The OpenAI client provider variant
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug)]
pub enum OpenAIClientProviderVariant {
/// The base OpenAI client provider variant
Base,
Expand All @@ -60,7 +59,7 @@ pub enum OpenAIClientProviderVariant {
}

/// The strategy client provider variant
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug)]
pub enum StrategyClientProvider {
/// The round-robin strategy client provider variant
RoundRobin,
Expand Down
11 changes: 10 additions & 1 deletion engine/baml-runtime/src/client_registry/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use anyhow::{Context, Result};
pub use internal_llm_client::ClientProvider;
use internal_llm_client::{ClientSpec, PropertyHandler, UnresolvedClientProperty};
use std::collections::HashMap;
use std::{collections::HashMap, str::FromStr};
use std::sync::Arc;

use baml_types::{BamlMap, BamlValue};
Expand All @@ -21,6 +21,7 @@ pub enum PrimitiveClient {
#[derive(Clone, Deserialize, Debug)]
pub struct ClientProperty {
pub name: String,
#[serde(deserialize_with = "deserialize_client_provider")]
pub provider: ClientProvider,
pub retry_policy: Option<String>,
options: BamlMap<String, BamlValue>,
Expand Down Expand Up @@ -126,3 +127,11 @@ where
.map(|client: ClientProperty| (client.name.clone(), client))
.collect())
}

fn deserialize_client_provider<'de, D>(deserializer: D) -> Result<ClientProvider, D::Error>
where
D: Deserializer<'de>,
{
let s: &str = Deserialize::deserialize(deserializer)?;
ClientProvider::from_str(s).map_err(|e| serde::de::Error::custom(e.to_string()))
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl OrchestrationScope {
}
}

#[derive(Clone, Debug, Serialize)]
#[derive(Clone, Debug)]
pub enum ExecutionScope {
Direct(String),
// PolicyName, RetryCount, RetryDelayMs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{
use serde::Serialize;
use serde::Serializer;

#[derive(Debug, Serialize)]
#[derive(Debug)]
pub struct RoundRobinStrategy {
pub name: String,
pub(super) retry_policy: Option<String>,
Expand Down

0 comments on commit 0aa1115

Please sign in to comment.