Skip to content

Commit

Permalink
Typing tweaks (#696)
Browse files Browse the repository at this point in the history
Some type hint and alias tweaks for Tunables and a few others.

Split off from #690.
  • Loading branch information
bpkroth authored Mar 6, 2024
1 parent de9e35d commit 5d83f55
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 34 deletions.
2 changes: 1 addition & 1 deletion mlos_bench/mlos_bench/optimizers/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def target(self) -> str:
return self._opt_target

@property
def direction(self) -> str:
def direction(self) -> Literal['min', 'max']:
"""
The direction to optimize the target metric (e.g., min or max).
"""
Expand Down
5 changes: 3 additions & 2 deletions mlos_bench/mlos_bench/storage/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import pandas

from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable import TunableValue, TunableValueTypeTuple
from mlos_bench.util import try_parse_val


Expand All @@ -31,8 +31,9 @@ def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValu
assert dataframe.columns.tolist() == ['parameter', 'value']
data = {}
for _, row in dataframe.astype('O').iterrows():
if not isinstance(row['value'], TunableValueTypeTuple):
raise TypeError(f"Invalid column type: {type(row['value'])} value: {row['value']}")
assert isinstance(row['parameter'], str)
assert row['value'] is None or isinstance(row['value'], (str, int, float))
if row['parameter'] in data:
raise ValueError(f"Duplicate parameter '{row['parameter']}' in dataframe")
data[row['parameter']] = try_parse_val(row['value']) if isinstance(row['value'], str) else row['value']
Expand Down
36 changes: 18 additions & 18 deletions mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json5 as json
import pytest

from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable import Tunable, TunableValueTypeName


def test_tunable_name() -> None:
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_categorical_tunable_disallow_repeats() -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_disallow_null_default(tunable_type: str) -> None:
def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeName) -> None:
"""
Disallow null values as default for numerical tunables.
"""
Expand All @@ -148,7 +148,7 @@ def test_numerical_tunable_disallow_null_default(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_disallow_out_of_range(tunable_type: str) -> None:
def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeName) -> None:
"""
Disallow out of range values as default for numerical tunables.
"""
Expand All @@ -161,7 +161,7 @@ def test_numerical_tunable_disallow_out_of_range(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_wrong_params(tunable_type: str) -> None:
def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> None:
"""
Disallow values param for numerical tunables.
"""
Expand All @@ -175,7 +175,7 @@ def test_numerical_tunable_wrong_params(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_required_params(tunable_type: str) -> None:
def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) -> None:
"""
Disallow null values param for numerical tunables.
"""
Expand All @@ -192,7 +192,7 @@ def test_numerical_tunable_required_params(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_invalid_range(tunable_type: str) -> None:
def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> None:
"""
Disallow invalid range param for numerical tunables.
"""
Expand All @@ -209,7 +209,7 @@ def test_numerical_tunable_invalid_range(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_reversed_range(tunable_type: str) -> None:
def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> None:
"""
Disallow reverse range param for numerical tunables.
"""
Expand All @@ -226,7 +226,7 @@ def test_numerical_tunable_reversed_range(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights(tunable_type: str) -> None:
def test_numerical_weights(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with weighted special values.
"""
Expand All @@ -248,7 +248,7 @@ def test_numerical_weights(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_quantization(tunable_type: str) -> None:
def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with quantization.
"""
Expand All @@ -267,7 +267,7 @@ def test_numerical_quantization(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_log(tunable_type: str) -> None:
def test_numerical_log(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with log scale.
"""
Expand All @@ -285,7 +285,7 @@ def test_numerical_log(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_no_specials(tunable_type: str) -> None:
def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> None:
"""
Raise an error if special_weights are specified but no special values.
"""
Expand All @@ -303,7 +303,7 @@ def test_numerical_weights_no_specials(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_non_normalized(tunable_type: str) -> None:
def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with non-normalized weights
of the special values.
Expand All @@ -326,7 +326,7 @@ def test_numerical_weights_non_normalized(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_wrong_count(tunable_type: str) -> None:
def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with incorrect number of weights.
"""
Expand All @@ -346,7 +346,7 @@ def test_numerical_weights_wrong_count(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_no_range_weight(tunable_type: str) -> None:
def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with weights but no range_weight.
"""
Expand All @@ -365,7 +365,7 @@ def test_numerical_weights_no_range_weight(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_range_weight_no_weights(tunable_type: str) -> None:
def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with specials but no range_weight.
"""
Expand All @@ -384,7 +384,7 @@ def test_numerical_range_weight_no_weights(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_range_weight_no_specials(tunable_type: str) -> None:
def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with specials but no range_weight.
"""
Expand All @@ -402,7 +402,7 @@ def test_numerical_range_weight_no_specials(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_wrong_values(tunable_type: str) -> None:
def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with incorrect number of weights.
"""
Expand All @@ -422,7 +422,7 @@ def test_numerical_weights_wrong_values(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_quantization_wrong(tunable_type: str) -> None:
def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with invalid number of quantization points.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json5 as json
import pytest

from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable import Tunable, TunableValueTypeName


def test_categorical_distribution() -> None:
Expand All @@ -28,7 +28,7 @@ def test_categorical_distribution() -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_uniform(tunable_type: str) -> None:
def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> None:
"""
Create a numeric Tunable with explicit uniform distribution.
"""
Expand All @@ -46,7 +46,7 @@ def test_numerical_distribution_uniform(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_normal(tunable_type: str) -> None:
def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> None:
"""
Create a numeric Tunable with explicit Gaussian distribution specified.
"""
Expand All @@ -67,7 +67,7 @@ def test_numerical_distribution_normal(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_beta(tunable_type: str) -> None:
def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None:
"""
Create a numeric Tunable with explicit Beta distribution specified.
"""
Expand Down
8 changes: 4 additions & 4 deletions mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,19 @@ def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None:
"""
Checks that we can't use null/None in integer tunables.
"""
with pytest.raises(TypeError):
with pytest.raises((TypeError, AssertionError)):
tunable_int.value = None
with pytest.raises(TypeError):
with pytest.raises((TypeError, AssertionError)):
tunable_int.numerical_value = None # type: ignore[assignment]


def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None:
"""
Checks that we can't use null/None in float tunables.
"""
with pytest.raises(TypeError):
with pytest.raises((TypeError, AssertionError)):
tunable_float.value = None
with pytest.raises(TypeError):
with pytest.raises((TypeError, AssertionError)):
tunable_float.numerical_value = None # type: ignore[assignment]


Expand Down
27 changes: 22 additions & 5 deletions mlos_bench/mlos_bench/tunables/tunable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@
"""A tunable parameter value type alias."""
TunableValue = Union[int, float, Optional[str]]

"""Tunable value type."""
TunableValueType = Union[Type[int], Type[float], Type[str]]

"""
Tunable value type tuple.
For checking with isinstance()
"""
TunableValueTypeTuple = (int, float, str, type(None))

"""The string name of a tunable value type."""
TunableValueTypeName = Literal["int", "float", "categorical"]

"""Tunable values dictionary type"""
TunableValuesDict = Dict[str, TunableValue]

"""Tunable value distribution type"""
DistributionName = Literal["uniform", "normal", "beta"]


Expand All @@ -38,7 +54,7 @@ class TunableDict(TypedDict, total=False):
These are the types expected to be received from the json config.
"""

type: str
type: TunableValueTypeName
description: Optional[str]
default: TunableValue
values: Optional[List[Optional[str]]]
Expand All @@ -59,7 +75,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
"""

# Maps tunable types to their corresponding Python types by name.
_DTYPE: Dict[str, Type] = {
_DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
"int": int,
"float": float,
"categorical": str,
Expand All @@ -79,7 +95,7 @@ def __init__(self, name: str, config: TunableDict):
if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema
raise ValueError(f"Invalid name of the tunable: {name}")
self._name = name
self._type = config["type"] # required
self._type: TunableValueTypeName = config["type"] # required
if self._type not in self._DTYPE:
raise ValueError(f"Invalid parameter type: {self._type}")
self._description = config.get("description")
Expand Down Expand Up @@ -302,6 +318,7 @@ def value(self, value: TunableValue) -> TunableValue:
if self.is_categorical and value is None:
coerced_value = None
else:
assert value is not None
coerced_value = self.dtype(value)
except Exception:
_LOG.error("Impossible conversion: %s %s <- %s %s",
Expand Down Expand Up @@ -482,7 +499,7 @@ def range_weight(self) -> Optional[float]:
return self._range_weight

@property
def type(self) -> str:
def type(self) -> TunableValueTypeName:
"""
Get the data type of the tunable.
Expand All @@ -494,7 +511,7 @@ def type(self) -> str:
return self._type

@property
def dtype(self) -> Type:
def dtype(self) -> TunableValueType:
"""
Get the actual Python data type of the tunable.
Expand Down

0 comments on commit 5d83f55

Please sign in to comment.