Skip to content

Commit

Permalink
Clean up utility functions. (#597)
Browse files Browse the repository at this point in the history
* Clean up utility functions.

* Add missing test case for dotted keys function where values are empty dicts.
  • Loading branch information
bdice committed Aug 1, 2022
1 parent b688bab commit 18cae0b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 49 deletions.
3 changes: 1 addition & 2 deletions signac/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,8 +1696,7 @@ def main():

parser_update_cache = subparsers.add_parser(
"update-cache",
description="""Use this command to update the project's persistent state point cache.
This feature is still experimental and may be removed in future versions.""",
description="Use this command to update the project's persistent state point cache.",
)
parser_update_cache.set_defaults(func=main_update_cache)

Expand Down
4 changes: 2 additions & 2 deletions signac/contrib/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from .indexing import _SignacProjectCrawler
from .job import Job
from .schema import ProjectSchema, _collect_by_type
from .utility import _mkdir_p, _nested_dicts_to_dotted_keys, split_and_print_progress
from .utility import _mkdir_p, _nested_dicts_to_dotted_keys, _split_and_print_progress

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1801,7 +1801,7 @@ def _update_in_memory_cache(self):
def _add(_id):
self._sp_cache[_id] = self._get_statepoint_from_workspace(_id)

to_add_chunks = split_and_print_progress(
to_add_chunks = _split_and_print_progress(
iterable=list(to_add),
num_chunks=max(1, min(100, int(len(to_add) / 1000))),
write=logger.info,
Expand Down
41 changes: 2 additions & 39 deletions signac/contrib/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,14 @@
import logging
import os
import sys
import tarfile
import zipfile
from collections.abc import Mapping
from contextlib import contextmanager
from datetime import timedelta
from tempfile import TemporaryDirectory
from time import time

logger = logging.getLogger(__name__)


def query_yes_no(question, default="yes"):
def query_yes_no(question, default="yes"): # pragma: no cover
"""Ask a yes/no question via input() and return their answer.
"question" is a string that is presented to the user.
Expand Down Expand Up @@ -150,7 +146,7 @@ def _mkdir_p(path):
os.makedirs(path, exist_ok=True)


def split_and_print_progress(iterable, num_chunks=10, write=None, desc="Progress: "):
def _split_and_print_progress(iterable, num_chunks=10, write=None, desc="Progress: "):
"""Split the progress and prints it.
Parameters
Expand Down Expand Up @@ -201,39 +197,6 @@ def split_and_print_progress(iterable, num_chunks=10, write=None, desc="Progress
yield iterable


@contextmanager
def _extract(filename):
"""Extract zipfile and tarfile.
Parameters
----------
filename : str
Name of zipfile/tarfile to extract.
Yields
------
str
Path to the extracted directory.
Raises
------
RuntimeError
When the provided file is neither a zipfile nor a tarfile.
"""
with TemporaryDirectory() as tmpdir:
if zipfile.is_zipfile(filename):
with zipfile.ZipFile(filename) as file:
file.extractall(tmpdir)
yield tmpdir
elif tarfile.is_tarfile(filename):
with tarfile.open(filename) as file:
file.extractall(path=tmpdir)
yield tmpdir
else:
raise RuntimeError(f"Unknown file type: '{filename}'.")


def _dotted_dict_to_nested_dicts(dotted_dict, delimiter_nested="."):
"""Convert dotted keys in the state point dict to a nested dict.
Expand Down
22 changes: 16 additions & 6 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,8 +809,9 @@ def test_schema(self):
for i in range(10):
self.project.open_job(
{
"const": 0,
"const1": 0,
"const2": {"const3": 0},
"const4": {},
"a": i,
"b": {"b2": i},
"c": [i if i % 2 else None, 0, 0],
Expand All @@ -821,8 +822,18 @@ def test_schema(self):
).init()

s = self.project.detect_schema()
assert len(s) == 9
for k in "const", "const2.const3", "a", "b.b2", "c", "d", "e.e2", "f.f2":
assert len(s) == 10
for k in (
"const1",
"const2.const3",
"const4",
"a",
"b.b2",
"c",
"d",
"e.e2",
"f.f2",
):
assert k in s
if "." in k:
with pytest.warns(FutureWarning):
Expand All @@ -836,10 +847,9 @@ def test_schema(self):
assert s.format() == str(s)
s = self.project.detect_schema(exclude_const=True)
assert len(s) == 7
assert "const" not in s
with pytest.warns(FutureWarning):
assert ("const2", "const3") not in s
assert "const1" not in s
assert "const2.const3" not in s
assert "const4" not in s
assert type not in s["e"]

def test_schema_subset(self):
Expand Down

0 comments on commit 18cae0b

Please sign in to comment.