From e2b1fb1f64d30e593a45883ef7cfa92b74a26b88 Mon Sep 17 00:00:00 2001 From: dtrifiro <36171005+dtrifiro@users.noreply.github.com> Date: Fri, 24 Jul 2020 18:09:04 +0200 Subject: [PATCH] add toml support for ParamsDependency (#4258) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add toml support for ParamsDependency * add toml support to dvc show Co-authored-by: Daniele Trifirò --- dvc/dependency/param.py | 9 ++++++-- dvc/repo/params/show.py | 7 ++++-- setup.cfg | 2 +- setup.py | 1 + tests/func/params/test_show.py | 10 +++++++++ tests/unit/dependency/test_params.py | 32 ++++++++++++++++++++++++++++ 6 files changed, 56 insertions(+), 5 deletions(-) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 12860e3475..35fe44f033 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -2,6 +2,7 @@ from collections import defaultdict import dpath.util +import toml import yaml from voluptuous import Any @@ -21,6 +22,8 @@ class ParamsDependency(LocalDependency): PARAM_PARAMS = "params" PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)} DEFAULT_PARAMS_FILE = "params.yaml" + PARAMS_FILE_LOADERS = defaultdict(lambda: yaml.safe_load) + PARAMS_FILE_LOADERS.update({".toml": toml.load}) def __init__(self, stage, path, params): info = {} @@ -87,8 +90,10 @@ def read_params(self): with self.repo.tree.open(self.path_info, "r") as fobj: try: - config = yaml.safe_load(fobj) - except yaml.YAMLError as exc: + config = self.PARAMS_FILE_LOADERS[ + self.path_info.suffix.lower() + ](fobj) + except (yaml.YAMLError, toml.TomlDecodeError) as exc: raise BadParamFileError( f"Unable to read parameters from '{self}'" ) from exc diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 058125ce86..e60aebaedf 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -1,5 +1,6 @@ import logging +import toml import yaml from dvc.dependency.param import ParamsDependency @@ -34,8 +35,10 @@ def _read_params(repo, configs, rev): with repo.tree.open(config, "r") as fobj: try: - res[str(config)] = yaml.safe_load(fobj) - except yaml.YAMLError: + res[str(config)] = ParamsDependency.PARAMS_FILE_LOADERS[ + config.suffix.lower() + ](fobj) + except (yaml.YAMLError, toml.TomlDecodeError): logger.debug( "failed to read '%s' on '%s'", config, rev, exc_info=True ) diff --git a/setup.cfg b/setup.cfg index 826c5f64fe..c794a46e17 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ count=true [isort] include_trailing_comma=true known_first_party=dvc,tests -known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pygtrie,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,tqdm,voluptuous,yaml,zc +known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pygtrie,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,toml,tqdm,voluptuous,yaml,zc line_length=79 force_grid_wrap=0 use_parentheses=True diff --git a/setup.py b/setup.py index 6db9a1c0c9..875ab04d1d 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,7 @@ def run(self): "appdirs>=1.4.3", "PyYAML>=5.1.2,<5.4", # Compatibility with awscli "ruamel.yaml>=0.16.1", + "toml>=0.10.1", "funcy>=1.14", "pathspec>=0.6.0", "shortuuid>=0.5.0", diff --git a/tests/func/params/test_show.py b/tests/func/params/test_show.py index 93ccdee96a..fd4621fb77 100644 --- a/tests/func/params/test_show.py +++ b/tests/func/params/test_show.py @@ -14,6 +14,16 @@ def test_show(tmp_dir, dvc): assert dvc.params.show() == {"": {"params.yaml": {"foo": "bar"}}} +def test_show_toml(tmp_dir, dvc): + tmp_dir.gen("params.toml", "[foo]\nbar = 42\nbaz = [1, 2]\n") + dvc.run( + cmd="echo params.toml", params=["params.toml:foo"], single_stage=True + ) + assert dvc.params.show() == { + "": {"params.toml": {"foo": {"bar": 42, "baz": [1, 2]}}} + } + + def test_show_multiple(tmp_dir, dvc): tmp_dir.gen("params.yaml", "foo: bar\nbaz: qux\n") dvc.run( diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py index 2667feabb6..d8992d747e 100644 --- a/tests/unit/dependency/test_params.py +++ b/tests/unit/dependency/test_params.py @@ -1,4 +1,5 @@ import pytest +import toml import yaml from dvc.dependency import ParamsDependency, loadd_from, loads_params @@ -99,6 +100,37 @@ def test_read_params_nested(tmp_dir, dvc): assert dep.read_params() == {"some.path.foo": ["val1", "val2"]} +def test_read_params_default_loader(tmp_dir, dvc): + parameters_file = "parameters.foo" + tmp_dir.gen( + parameters_file, + yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}), + ) + dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"]) + assert dep.read_params() == {"some.path.foo": ["val1", "val2"]} + + +def test_read_params_wrong_suffix(tmp_dir, dvc): + parameters_file = "parameters.toml" + tmp_dir.gen( + parameters_file, + yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}), + ) + dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"]) + with pytest.raises(BadParamFileError): + dep.read_params() + + +def test_read_params_toml(tmp_dir, dvc): + parameters_file = "parameters.toml" + tmp_dir.gen( + parameters_file, + toml.dumps({"some": {"path": {"foo": ["val1", "val2"]}}}), + ) + dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"]) + assert dep.read_params() == {"some.path.foo": ["val1", "val2"]} + + def test_save_info_missing_config(dvc): dep = ParamsDependency(Stage(dvc), None, ["foo"]) with pytest.raises(MissingParamsError):