Skip to content

Commit

Permalink
add toml support for ParamsDependency (#4258)
Browse files Browse the repository at this point in the history
* add toml support for ParamsDependency

* add toml support to dvc show

Co-authored-by: Daniele Trifirò <[email protected]>
  • Loading branch information
dtrifiro and Daniele Trifirò authored Jul 24, 2020
1 parent 39c0bdb commit e2b1fb1
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 5 deletions.
9 changes: 7 additions & 2 deletions dvc/dependency/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict

import dpath.util
import toml
import yaml
from voluptuous import Any

Expand All @@ -21,6 +22,8 @@ class ParamsDependency(LocalDependency):
PARAM_PARAMS = "params"
PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)}
DEFAULT_PARAMS_FILE = "params.yaml"
PARAMS_FILE_LOADERS = defaultdict(lambda: yaml.safe_load)
PARAMS_FILE_LOADERS.update({".toml": toml.load})

def __init__(self, stage, path, params):
info = {}
Expand Down Expand Up @@ -87,8 +90,10 @@ def read_params(self):

with self.repo.tree.open(self.path_info, "r") as fobj:
try:
config = yaml.safe_load(fobj)
except yaml.YAMLError as exc:
config = self.PARAMS_FILE_LOADERS[
self.path_info.suffix.lower()
](fobj)
except (yaml.YAMLError, toml.TomlDecodeError) as exc:
raise BadParamFileError(
f"Unable to read parameters from '{self}'"
) from exc
Expand Down
7 changes: 5 additions & 2 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import toml
import yaml

from dvc.dependency.param import ParamsDependency
Expand Down Expand Up @@ -34,8 +35,10 @@ def _read_params(repo, configs, rev):

with repo.tree.open(config, "r") as fobj:
try:
res[str(config)] = yaml.safe_load(fobj)
except yaml.YAMLError:
res[str(config)] = ParamsDependency.PARAMS_FILE_LOADERS[
config.suffix.lower()
](fobj)
except (yaml.YAMLError, toml.TomlDecodeError):
logger.debug(
"failed to read '%s' on '%s'", config, rev, exc_info=True
)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ count=true
[isort]
include_trailing_comma=true
known_first_party=dvc,tests
known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pygtrie,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,tqdm,voluptuous,yaml,zc
known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pygtrie,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,toml,tqdm,voluptuous,yaml,zc
line_length=79
force_grid_wrap=0
use_parentheses=True
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def run(self):
"appdirs>=1.4.3",
"PyYAML>=5.1.2,<5.4", # Compatibility with awscli
"ruamel.yaml>=0.16.1",
"toml>=0.10.1",
"funcy>=1.14",
"pathspec>=0.6.0",
"shortuuid>=0.5.0",
Expand Down
10 changes: 10 additions & 0 deletions tests/func/params/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ def test_show(tmp_dir, dvc):
assert dvc.params.show() == {"": {"params.yaml": {"foo": "bar"}}}


def test_show_toml(tmp_dir, dvc):
tmp_dir.gen("params.toml", "[foo]\nbar = 42\nbaz = [1, 2]\n")
dvc.run(
cmd="echo params.toml", params=["params.toml:foo"], single_stage=True
)
assert dvc.params.show() == {
"": {"params.toml": {"foo": {"bar": 42, "baz": [1, 2]}}}
}


def test_show_multiple(tmp_dir, dvc):
tmp_dir.gen("params.yaml", "foo: bar\nbaz: qux\n")
dvc.run(
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/dependency/test_params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import toml
import yaml

from dvc.dependency import ParamsDependency, loadd_from, loads_params
Expand Down Expand Up @@ -99,6 +100,37 @@ def test_read_params_nested(tmp_dir, dvc):
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}


def test_read_params_default_loader(tmp_dir, dvc):
parameters_file = "parameters.foo"
tmp_dir.gen(
parameters_file,
yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}),
)
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}


def test_read_params_wrong_suffix(tmp_dir, dvc):
parameters_file = "parameters.toml"
tmp_dir.gen(
parameters_file,
yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}),
)
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
with pytest.raises(BadParamFileError):
dep.read_params()


def test_read_params_toml(tmp_dir, dvc):
parameters_file = "parameters.toml"
tmp_dir.gen(
parameters_file,
toml.dumps({"some": {"path": {"foo": ["val1", "val2"]}}}),
)
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}


def test_save_info_missing_config(dvc):
dep = ParamsDependency(Stage(dvc), None, ["foo"])
with pytest.raises(MissingParamsError):
Expand Down

0 comments on commit e2b1fb1

Please sign in to comment.