Skip to content

Commit

Permalink
fix!: Remove max_tokens and temperature as top level model config keys
Browse files Browse the repository at this point in the history
  • Loading branch information
keelerm84 committed Nov 22, 2024
1 parent 319f64d commit 55f34fe
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 53 deletions.
40 changes: 7 additions & 33 deletions ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,13 @@ class ModelConfig:
Configuration related to the model.
"""

def __init__(self, id: str, parameters: Optional[dict] = None):
    """
    Initialize the model configuration.

    :param id: The ID of the model.
    :param parameters: Additional model-specific parameters. Defaults to an
        empty dict when omitted.
    """
    self._id = id
    # Use a None sentinel instead of a mutable default argument ({}):
    # a dict default would be shared across every ModelConfig instance,
    # so a mutation through one instance would leak into all others.
    self._parameters = {} if parameters is None else parameters

@property
def id(self) -> str:
    """The ID of the model."""
    return self._id

def get_parameter(self, key: str) -> Any:
    """
    Retrieve a model-specific parameter.

    Accessing a named, typed attribute (e.g. ``id``) is delegated to the
    corresponding property; any other key is looked up in the parameters
    dict, returning ``None`` when absent.
    """
    # 'id' is the only named attribute with its own property; everything
    # else lives in the free-form parameters mapping.
    return self.id if key == 'id' else self._parameters.get(key)


class ProviderConfig:
Expand Down Expand Up @@ -150,9 +126,7 @@ def config(
if 'model' in variation:
model = ModelConfig(
id=variation['model']['modelId'],
temperature=variation['model'].get('temperature'),
max_tokens=variation['model'].get('maxTokens'),
attributes=variation['model'],
parameters=variation['model'],
)

enabled = variation.get('_ldMeta', {}).get('enabled', False)
Expand Down
35 changes: 15 additions & 20 deletions ldai/testing/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,12 @@ def ldai_client(client: LDClient) -> LDAIClient:


def test_model_config_delegates_to_properties():
    # A parameter the model has, one it lacks, and the delegated 'id' key.
    model = ModelConfig('fakeModel', parameters={'extra-attribute': 'value'})

    assert model.id == 'fakeModel'
    # 'id' delegates to the typed property rather than the parameters dict.
    assert model.get_parameter('id') == model.id
    assert model.get_parameter('extra-attribute') == 'value'
    assert model.get_parameter('non-existent') is None


def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
Expand All @@ -141,8 +136,8 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker):

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.temperature == 0.5
assert config.model.max_tokens == 4096
assert config.model.get_parameter('temperature') == 0.5
assert config.model.get_parameter('maxTokens') == 4096


def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
Expand All @@ -158,8 +153,8 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.temperature == 0.5
assert config.model.max_tokens == 4096
assert config.model.get_parameter('temperature') == 0.5
assert config.model.get_parameter('maxTokens') == 4096


def test_provider_config_handling(ldai_client: LDAIClient, tracker):
Expand Down Expand Up @@ -189,9 +184,9 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker):

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.temperature is None
assert config.model.max_tokens is None
assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to'
assert config.model.get_parameter('temperature') is None
assert config.model.get_parameter('maxTokens') is None
assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to'


def test_model_config_multiple(ldai_client: LDAIClient, tracker):
Expand All @@ -211,8 +206,8 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker):

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.temperature == 0.7
assert config.model.max_tokens == 8192
assert config.model.get_parameter('temperature') == 0.7
assert config.model.get_parameter('maxTokens') == 8192


def test_model_config_disabled(ldai_client: LDAIClient, tracker):
Expand All @@ -224,8 +219,8 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker):
assert config.model is not None
assert config.enabled is False
assert config.model.id == 'fakeModel'
assert config.model.temperature == 0.1
assert config.model.max_tokens is None
assert config.model.get_parameter('temperature') == 0.1
assert config.model.get_parameter('maxTokens') is None


def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker):
Expand Down

0 comments on commit 55f34fe

Please sign in to comment.