Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
EugenHotaj committed Jan 11, 2025
1 parent 78b7d0d commit 2bab3ec
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/torchtune/config/test_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
},
"d": 4,
"f": 8,
"g": "foo",
"h": "${g}/bar",
}


Expand All @@ -50,7 +52,9 @@ def test_get_component_from_path(self):
):
_ = _get_component_from_path("torchtune.models.dummy")

@mock.patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG)
@mock.patch(
"torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
)
def test_merge_yaml_and_cli_args(self, mock_load):
parser = TuneRecipeArgumentParser("test parser")
yaml_args, cli_args = parser.parse_known_args(
Expand All @@ -63,6 +67,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
"d=6", # Test overriding a flat param
"e=7", # Test adding a new param
"~f", # Test removing a param
"g=bazz", # Test interpolation happens after override
]
)
conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
Expand All @@ -75,6 +80,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides."
assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides."
assert "f" not in conf, f"f == {conf.f}, not removed as set in overrides."
assert conf.h == "bazz/bar", f"h == {conf.h}, not bazz/bar as set in overrides."
mock_load.assert_called_once()

yaml_args, cli_args = parser.parse_known_args(
Expand Down Expand Up @@ -185,5 +191,5 @@ def test_remove_key_by_dotpath(self):

# Test removing non-existent param fails
cfg = copy.deepcopy(_CONFIG)
with pytest.raises(KeyError, match="'g'"):
_remove_key_by_dotpath(cfg, "g")
with pytest.raises(KeyError, match="'i'"):
_remove_key_by_dotpath(cfg, "i")

0 comments on commit 2bab3ec

Please sign in to comment.