Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Make --overrides more flexible #5399

Merged
merged 7 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed

- The behavior of `--overrides` has changed. Previously the final configuration params were simply taken as the union over the original params and the `--overrides` params.
But now you can use `--overrides` to completely replace any part of the original config. For example, passing `--overrides '{"model":{"type":"foo"}}'` will completely
replace the "model" part of the original config. However, when you just want to change a single field in the JSON structure without removing / replacing adjacent fields,
you can still use the "dot" syntax. For example, `--overrides '{"model.num_layers":3}'` will only change the `num_layers` parameter within the "model" part of the config, leaving
everything else unchanged.

### Fixed

- Fixed the implementation of `PairedPCABiasDirection` in `allennlp.fairness.bias_direction`, where the difference vectors should not be centered when performing the PCA.
Expand Down
131 changes: 56 additions & 75 deletions allennlp/common/params.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import copy
from itertools import chain
import json
import logging
import os
import zlib
from collections import OrderedDict
from collections.abc import MutableMapping
from os import PathLike
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union, Optional, TypeVar, Iterable, Set

from overrides import overrides

Expand Down Expand Up @@ -93,87 +94,61 @@ def _environment_variables() -> Dict[str, str]:
return {key: value for key, value in os.environ.items() if _is_encodable(value)}


def unflatten(flat_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Given a "flattened" dict with compound keys, e.g.
        {"a.b": 0}
    unflatten it:
        {"a": {"b": 0}}

    Each key is split on "."; every part except the last names a nested dict
    (created on demand) and the last part names the entry that receives the value.
    Keys without a "." are copied to the top level, so an already-nested dict
    comes back equal to the input.

    The input is never modified: dict values are deep-copied before being stored,
    because a later compound key may descend into them and add entries.

    # Parameters

    flat_dict : `Dict[str, Any]`
        Mapping whose keys may be dotted paths.

    # Returns

    `Dict[str, Any]`
        A new nested dictionary.

    # Raises

    ConfigurationError
        If a path has to descend through a value that is not a dict, or if the
        final part of a path is already present at its destination.
    """
    unflat: Dict[str, Any] = {}

    for compound_key, value in flat_dict.items():
        *parents, leaf = compound_key.split(".")
        curr_dict = unflat
        for key in parents:
            # Reuse the intermediate dict if an earlier key created it,
            # otherwise create it now.
            child = curr_dict.setdefault(key, {})
            if not isinstance(child, dict):
                raise ConfigurationError("flattened dictionary is invalid")
            curr_dict = child
        if leaf in curr_dict:
            raise ConfigurationError("flattened dictionary is invalid")
        # Storing the caller's dict by reference would let a later compound key
        # (e.g. "a.c" after "a") write into the caller's own data.
        curr_dict[leaf] = copy.deepcopy(value) if isinstance(value, dict) else value

    return unflat
T = TypeVar("T", dict, list)


def with_fallback(preferred: Dict[str, Any], fallback: Dict[str, Any]) -> Dict[str, Any]:
"""
Deep merge two dicts, preferring values from `preferred`.
"""

def merge(preferred_value: Any, fallback_value: Any) -> Any:
if isinstance(preferred_value, dict) and isinstance(fallback_value, dict):
return with_fallback(preferred_value, fallback_value)
elif isinstance(preferred_value, dict) and isinstance(fallback_value, list):
# treat preferred_value as a sparse list, where each key is an index to be overridden
merged_list = fallback_value
for elem_key, preferred_element in preferred_value.items():
try:
index = int(elem_key)
merged_list[index] = merge(preferred_element, fallback_value[index])
except ValueError:
raise ConfigurationError(
"could not merge dicts - the preferred dict contains "
f"invalid keys (key {elem_key} is not a valid list index)"
)
except IndexError:
raise ConfigurationError(
"could not merge dicts - the preferred dict contains "
f"invalid keys (key {index} is out of bounds)"
)
return merged_list
def with_overrides(original: T, overrides_dict: Dict[str, Any], prefix: str = "") -> T:
merged: T
keys: Union[Iterable[str], Iterable[int]]
if isinstance(original, list):
merged = [None] * len(original)
keys = range(len(original))
elif isinstance(original, dict):
merged = {}
keys = chain(
original.keys(), (k for k in overrides_dict if "." not in k and k not in original)
)
else:
if prefix:
raise ValueError(
f"overrides for '{prefix[:-1]}.*' expected list or dict in original, "
f"found {type(original)} instead"
)
else:
return copy.deepcopy(preferred_value)

preferred_keys = set(preferred.keys())
fallback_keys = set(fallback.keys())
common_keys = preferred_keys & fallback_keys

merged: Dict[str, Any] = {}
raise ValueError(f"expected list or dict, found {type(original)} instead")

for key in preferred_keys - fallback_keys:
merged[key] = copy.deepcopy(preferred[key])
for key in fallback_keys - preferred_keys:
merged[key] = copy.deepcopy(fallback[key])
used_override_keys: Set[str] = set()
for key in keys:
if str(key) in overrides_dict:
merged[key] = copy.deepcopy(overrides_dict[str(key)])
used_override_keys.add(str(key))
else:
overrides_subdict = {}
for o_key in overrides_dict:
if o_key.startswith(f"{key}."):
overrides_subdict[o_key[len(f"{key}.") :]] = overrides_dict[o_key]
used_override_keys.add(o_key)
if overrides_subdict:
merged[key] = with_overrides(
original[key], overrides_subdict, prefix=prefix + f"{key}."
)
else:
merged[key] = copy.deepcopy(original[key])

for key in common_keys:
preferred_value = preferred[key]
fallback_value = fallback[key]
unused_override_keys = [prefix + key for key in set(overrides_dict.keys()) - used_override_keys]
if unused_override_keys:
raise ValueError(f"overrides dict contains unused keys: {unused_override_keys}")

merged[key] = merge(preferred_value, fallback_value)
return merged


def parse_overrides(serialized_overrides: str) -> Dict[str, Any]:
def parse_overrides(
serialized_overrides: str, ext_vars: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
if serialized_overrides:
ext_vars = _environment_variables()
ext_vars = {**_environment_variables(), **(ext_vars or {})}

return unflatten(json.loads(evaluate_snippet("", serialized_overrides, ext_vars=ext_vars)))
return json.loads(evaluate_snippet("", serialized_overrides, ext_vars=ext_vars))
else:
return {}

Expand Down Expand Up @@ -427,7 +402,7 @@ def __getitem__(self, key):
if key in self.params:
return self._check_is_dict(key, self.params[key])
else:
raise KeyError
raise KeyError(str(key))

def __setitem__(self, key, value):
self.params[key] = value
Expand Down Expand Up @@ -468,7 +443,9 @@ def from_file(
params_overrides: `Union[str, Dict[str, Any]]`, optional (default = `""`)

A dict of overrides that can be applied to final object.
e.g. {"model.embedding_dim": 10}
e.g. `{"model.embedding_dim": 10}` will change the value of "embedding_dim"
within the "model" object of the config to 10. If you wanted to override the entire
"model" object of the config, you could do `{"model": {"type": "other_type", ...}}`.

ext_vars: `dict`, optional

Expand All @@ -489,8 +466,12 @@ def from_file(

if isinstance(params_overrides, dict):
params_overrides = json.dumps(params_overrides)
overrides_dict = parse_overrides(params_overrides)
param_dict = with_fallback(preferred=overrides_dict, fallback=file_dict)
overrides_dict = parse_overrides(params_overrides, ext_vars=ext_vars)

if overrides_dict:
param_dict = with_overrides(file_dict, overrides_dict)
else:
param_dict = file_dict

return cls(param_dict)

Expand Down
86 changes: 30 additions & 56 deletions tests/common/params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from allennlp.common.params import (
infer_and_cast,
Params,
parse_overrides,
unflatten,
with_fallback,
remove_keys_from_params,
with_overrides,
)
from allennlp.common.testing import AllenNlpTestCase

Expand Down Expand Up @@ -40,13 +38,39 @@ def test_bad_unicode_environment_variables(self):
Params.from_file(filename)
del os.environ["BAD_ENVIRONMENT_VARIABLE"]

def test_with_overrides(self):
    # One case of each kind: replacing a nested dict wholesale ("foo.bar"),
    # replacing a list element by index ("bar.0"), replacing a scalar and a
    # list inside a dict ("baz.bar", "baz.x"), and adding a new top-level key ("z").
    base_config = {
        "foo": {"bar": {"baz": 3}, "x": 0},
        "bar": ["a", "b", "c"],
        "baz": {"bar": 2, "y": 3, "x": [0, 1, 2]},
    }
    patch = {
        "foo.bar": {"z": 2},
        "bar.0": "d",
        "baz.bar": 1,
        "baz.x": [0, 0],
        "z": 2,
    }
    expected = {
        "foo": {"bar": {"z": 2}, "x": 0},
        "bar": ["d", "b", "c"],
        "baz": {"bar": 1, "y": 3, "x": [0, 0]},
        "z": 2,
    }

    merged = with_overrides(base_config, patch)

    assert merged == expected

def test_bad_overrides(self):
    # Each entry: (original, overrides, fragment expected in the error message).
    failing_cases = [
        # "foo.3" addresses an index that a 3-element list does not have.
        ({"foo": [0, 1, 2]}, {"foo.3": 4}, "contains unused keys"),
        # "foo.x" tries to descend into a value that is neither a list nor a dict.
        ({"foo": 3}, {"foo.x": 2}, "expected list or dict"),
    ]
    for original, bad_overrides, message in failing_cases:
        with pytest.raises(ValueError, match=message):
            with_overrides(original, bad_overrides)

@pytest.mark.parametrize("input_type", [dict, str])
def test_overrides(self, input_type):
filename = self.FIXTURES_ROOT / "simple_tagger" / "experiment.json"
overrides = {
"train_data_path": "FOO",
"model": {"type": "BAR"},
"model.text_field_embedder.tokens.type": "BAZ",
"model.type": "BAR",
"model.text_field_embedder.token_embedders.tokens.type": "BAZ",
"data_loader.batch_sampler.sorting_keys.0": "question",
}
params = Params.from_file(
Expand All @@ -60,57 +84,7 @@ def test_overrides(self, input_type):

model_params = params.pop("model")
assert model_params.pop("type") == "BAR"
assert model_params["text_field_embedder"]["tokens"]["type"] == "BAZ"

def test_unflatten(self):
    compound = {"a.b.c": 1, "a.b.d": 0, "a.e.f.g.h": 2, "b": 3}
    expected = {"a": {"b": {"c": 1, "d": 0}, "e": {"f": {"g": {"h": 2}}}}, "b": 3}

    nested = unflatten(compound)
    assert nested == expected

    # An already-nested dictionary must come back unchanged.
    assert unflatten(nested) == nested

def test_with_fallback(self):
    # Each entry: (preferred, fallback, expected merge result).
    cases = [
        # preferred wins on a shared key, fallback fills in the rest
        ({"a": 1}, {"a": 0, "b": 2}, {"a": 1, "b": 2}),
        # incompatible value types are fine: preferred simply replaces fallback
        ({"a": {"c": 3}}, {"a": 0, "b": 2}, {"a": {"c": 3}, "b": 2}),
        # merging recurses into nested dicts
        ({"deep": {"a": 1}}, {"deep": {"a": 0, "b": 2}}, {"deep": {"a": 1, "b": 2}}),
    ]
    for preferred, fallback, expected in cases:
        merged = with_fallback(preferred=preferred, fallback=fallback)
        assert merged == expected

def test_parse_overrides(self):
    # Both an empty string and an empty JSON object mean "no overrides".
    for blank in ("", "{}"):
        assert parse_overrides(blank) == {}

    # Dotted keys are expanded into nested dictionaries while parsing.
    parsed = parse_overrides('{"train_data": "/train", "trainer.num_epochs": 10}')
    assert parsed == {"train_data": "/train", "trainer": {"num_epochs": 10}}

    base_config = {
        "train_data": "/test",
        "model": "simple_tagger",
        "trainer": {"num_epochs": 100, "optimizer": "sgd"},
    }
    merged = with_fallback(preferred=parsed, fallback=base_config)

    # Overridden entries change; untouched entries survive the merge.
    assert merged == {
        "train_data": "/train",
        "model": "simple_tagger",
        "trainer": {"num_epochs": 10, "optimizer": "sgd"},
    }
assert model_params["text_field_embedder"]["token_embedders"]["tokens"]["type"] == "BAZ"

def test_as_flat_dict(self):
params = Params({"a": 10, "b": {"c": 20, "d": "stuff"}}).as_flat_dict()
Expand Down