Skip to content

Commit

Permalink
fixes #568
Browse files Browse the repository at this point in the history
  • Loading branch information
jph00 committed Jun 2, 2024
1 parent 7f22dce commit 2723d2e
Show file tree
Hide file tree
Showing 3 changed files with 485 additions and 201 deletions.
1 change: 1 addition & 0 deletions fastcore/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@
'fastcore.xtras.modify_exception': ('xtras.html#modify_exception', 'fastcore/xtras.py'),
'fastcore.xtras.obj2dict': ('xtras.html#obj2dict', 'fastcore/xtras.py'),
'fastcore.xtras.open_file': ('xtras.html#open_file', 'fastcore/xtras.py'),
'fastcore.xtras.parse_env': ('xtras.html#parse_env', 'fastcore/xtras.py'),
'fastcore.xtras.partial_format': ('xtras.html#partial_format', 'fastcore/xtras.py'),
'fastcore.xtras.repo_details': ('xtras.html#repo_details', 'fastcore/xtras.py'),
'fastcore.xtras.repr_dict': ('xtras.html#repr_dict', 'fastcore/xtras.py'),
Expand Down
95 changes: 53 additions & 42 deletions fastcore/xtras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

# %% auto 0
__all__ = ['spark_chars', 'walk', 'globtastic', 'maybe_open', 'mkdir', 'image_size', 'bunzip', 'loads', 'loads_multi', 'dumps',
'untar_dir', 'repo_details', 'run', 'open_file', 'save_pickle', 'load_pickle', 'dict2obj', 'obj2dict',
'repr_dict', 'is_listy', 'mapped', 'IterLen', 'ReindexCollection', 'get_source_link', 'truncstr',
'untar_dir', 'repo_details', 'run', 'open_file', 'save_pickle', 'load_pickle', 'parse_env', 'dict2obj',
'obj2dict', 'repr_dict', 'is_listy', 'mapped', 'IterLen', 'ReindexCollection', 'get_source_link', 'truncstr',
'sparkline', 'modify_exception', 'round_multiple', 'set_num_threads', 'join_path_file', 'autostart',
'EventTimer', 'stringfmt_names', 'PartialFormatter', 'partial_format', 'utc2local', 'local2utc', 'trace',
'modified_env', 'ContextManagers', 'shufflish', 'console_help', 'hl_md', 'type2str', 'dataclass_src']
Expand Down Expand Up @@ -231,69 +231,80 @@ def load_pickle(fn):
import pickle
with open_file(fn, 'rb') as f: return pickle.load(f)

# %% ../nbs/03_xtras.ipynb 60
# %% ../nbs/03_xtras.ipynb 59
def parse_env(s:str=None, fn:Union[str,Path]=None) -> dict:
"Parse a shell-style environment string or file"
assert bool(s)^bool(fn), "Must pass exactly one of `s` or `fn`"
if fn: s = Path(fn).read_text()
def _f(line):
m = re.match(r'^\s*(?:export\s+)?(\w+)\s*=\s*(["\']?)(.*?)(\2)\s*(?:#.*)?$', line).groups()
return m[0], m[2]

return dict(_f(o.strip()) for o in s.splitlines() if o.strip() and not re.match(r'\s*#', o))

# %% ../nbs/03_xtras.ipynb 62
def dict2obj(d, list_func=L, dict_func=AttrDict):
"Convert (possibly nested) dicts (or lists of dicts) to `AttrDict`"
if isinstance(d, (L,list)): return list_func(d).map(dict2obj)
if not isinstance(d, dict): return d
return dict_func(**{k:dict2obj(v) for k,v in d.items()})

# %% ../nbs/03_xtras.ipynb 65
# %% ../nbs/03_xtras.ipynb 67
def obj2dict(d):
"Convert (possibly nested) AttrDicts (or lists of AttrDicts) to `dict`"
if isinstance(d, (L,list)): return list(L(d).map(obj2dict))
if not isinstance(d, dict): return d
return dict(**{k:obj2dict(v) for k,v in d.items()})

# %% ../nbs/03_xtras.ipynb 68
# %% ../nbs/03_xtras.ipynb 70
def _repr_dict(d, lvl):
if isinstance(d,dict):
its = [f"{k}: {_repr_dict(v,lvl+1)}" for k,v in d.items()]
elif isinstance(d,(list,L)): its = [_repr_dict(o,lvl+1) for o in d]
else: return str(d)
return '\n' + '\n'.join([" "*(lvl*2) + "- " + o for o in its])

# %% ../nbs/03_xtras.ipynb 69
# %% ../nbs/03_xtras.ipynb 71
def repr_dict(d):
"Print nested dicts and lists, such as returned by `dict2obj`"
return _repr_dict(d,0).strip()

# %% ../nbs/03_xtras.ipynb 71
# %% ../nbs/03_xtras.ipynb 73
def is_listy(x):
"`isinstance(x, (tuple,list,L,slice,Generator))`"
return isinstance(x, (tuple,list,L,slice,Generator))

# %% ../nbs/03_xtras.ipynb 73
# %% ../nbs/03_xtras.ipynb 75
def mapped(f, it):
"map `f` over `it`, unless it's not listy, in which case return `f(it)`"
return L(it).map(f) if is_listy(it) else f(it)

# %% ../nbs/03_xtras.ipynb 77
# %% ../nbs/03_xtras.ipynb 79
@patch
def readlines(self:Path, hint=-1, encoding='utf8'):
"Read the content of `self`"
with self.open(encoding=encoding) as f: return f.readlines(hint)

# %% ../nbs/03_xtras.ipynb 78
# %% ../nbs/03_xtras.ipynb 80
@patch
def read_json(self:Path, encoding=None, errors=None):
"Same as `read_text` followed by `loads`"
return loads(self.read_text(encoding=encoding, errors=errors))

# %% ../nbs/03_xtras.ipynb 79
# %% ../nbs/03_xtras.ipynb 81
@patch
def mk_write(self:Path, data, encoding=None, errors=None, mode=511):
"Make all parent dirs of `self`, and write `data`"
self.parent.mkdir(exist_ok=True, parents=True, mode=mode)
self.write_text(data, encoding=encoding, errors=errors)

# %% ../nbs/03_xtras.ipynb 80
# %% ../nbs/03_xtras.ipynb 82
@patch
def relpath(self:Path, start=None):
"Same as `os.path.relpath`, but returns a `Path`, and resolves symlinks"
return Path(os.path.relpath(self.resolve(), Path(start).resolve()))

# %% ../nbs/03_xtras.ipynb 83
# %% ../nbs/03_xtras.ipynb 85
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
"Contents of path as a list"
Expand All @@ -305,7 +316,7 @@ def ls(self:Path, n_max=None, file_type=None, file_exts=None):
if n_max is not None: res = itertools.islice(res, n_max)
return L(res)

# %% ../nbs/03_xtras.ipynb 89
# %% ../nbs/03_xtras.ipynb 91
@patch
def __repr__(self:Path):
b = getattr(Path, 'BASE_PATH', None)
Expand All @@ -314,7 +325,7 @@ def __repr__(self:Path):
except: pass
return f"Path({self.as_posix()!r})"

# %% ../nbs/03_xtras.ipynb 92
# %% ../nbs/03_xtras.ipynb 94
@patch
def delete(self:Path):
"Delete a file, symlink, or directory tree"
Expand All @@ -324,12 +335,12 @@ def delete(self:Path):
shutil.rmtree(self)
else: self.unlink()

# %% ../nbs/03_xtras.ipynb 94
# %% ../nbs/03_xtras.ipynb 96
class IterLen:
"Base class to add iteration to anything supporting `__len__` and `__getitem__`"
def __iter__(self): return (self[i] for i in range_of(self))

# %% ../nbs/03_xtras.ipynb 95
# %% ../nbs/03_xtras.ipynb 97
@docs
class ReindexCollection(GetAttr, IterLen):
"Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
Expand All @@ -354,7 +365,7 @@ def __setstate__(self, s): self.coll,self.idxs,self.cache,self.tfm = s['coll'],s
shuffle="Randomly shuffle indices",
cache_clear="Clear LRU cache")

# %% ../nbs/03_xtras.ipynb 114
# %% ../nbs/03_xtras.ipynb 116
def _is_type_dispatch(x): return type(x).__name__ == "TypeDispatch"
def _unwrapped_type_dispatch_func(x): return x.first() if _is_type_dispatch(x) else x

Expand All @@ -381,15 +392,15 @@ def get_source_link(func):
return f"{nbdev_mod.git_url}{module}#L{line}"
except: return f"{module}#L{line}"

# %% ../nbs/03_xtras.ipynb 118
# %% ../nbs/03_xtras.ipynb 120
def truncstr(s:str, maxlen:int, suf:str='…', space='')->str:
"Truncate `s` to length `maxlen`, adding suffix `suf` if truncated"
return s[:maxlen-len(suf)]+suf if len(s)+len(space)>maxlen else s+space

# %% ../nbs/03_xtras.ipynb 120
# %% ../nbs/03_xtras.ipynb 122
spark_chars = '▁▂▃▅▆▇'

# %% ../nbs/03_xtras.ipynb 121
# %% ../nbs/03_xtras.ipynb 123
def _ceil(x, lim=None): return x if (not lim or x <= lim) else lim

def _sparkchar(x, mn, mx, incr, empty_zero):
Expand All @@ -398,7 +409,7 @@ def _sparkchar(x, mn, mx, incr, empty_zero):
res = int((_ceil(x,mx)-mn)/incr-0.5)
return spark_chars[res]

# %% ../nbs/03_xtras.ipynb 122
# %% ../nbs/03_xtras.ipynb 124
def sparkline(data, mn=None, mx=None, empty_zero=False):
"Sparkline for `data`, with `None`s (and zero, if `empty_zero`) shown as empty column"
valid = [o for o in data if o is not None]
Expand All @@ -407,7 +418,7 @@ def sparkline(data, mn=None, mx=None, empty_zero=False):
res = [_sparkchar(x=o, mn=mn, mx=mx, incr=(mx-mn)/n, empty_zero=empty_zero) for o in data]
return ''.join(res)

# %% ../nbs/03_xtras.ipynb 126
# %% ../nbs/03_xtras.ipynb 128
def modify_exception(
e:Exception, # An exception
msg:str=None, # A custom message
Expand All @@ -417,14 +428,14 @@ def modify_exception(
e.args = [f'{e.args[0]} {msg}'] if not replace and len(e.args) > 0 else [msg]
return e

# %% ../nbs/03_xtras.ipynb 128
# %% ../nbs/03_xtras.ipynb 130
def round_multiple(x, mult, round_down=False):
"Round `x` to nearest multiple of `mult`"
def _f(x_): return (int if round_down else round)(x_/mult)*mult
res = L(x).map(_f)
return res if is_listy(x) else res[0]

# %% ../nbs/03_xtras.ipynb 130
# %% ../nbs/03_xtras.ipynb 132
def set_num_threads(nt):
"Get numpy (and others) to use `nt` threads"
try: import mkl; mkl.set_num_threads(nt)
Expand All @@ -435,14 +446,14 @@ def set_num_threads(nt):
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
os.environ[o] = str(nt)

# %% ../nbs/03_xtras.ipynb 132
# %% ../nbs/03_xtras.ipynb 134
def join_path_file(file, path, ext=''):
"Return `path/file` if file is a string or a `Path`, file otherwise"
if not isinstance(file, (str, Path)): return file
path.mkdir(parents=True, exist_ok=True)
return path/f'{file}{ext}'

# %% ../nbs/03_xtras.ipynb 134
# %% ../nbs/03_xtras.ipynb 136
def autostart(g):
"Decorator that automatically starts a generator"
@functools.wraps(g)
Expand All @@ -452,7 +463,7 @@ def f():
return r
return f

# %% ../nbs/03_xtras.ipynb 135
# %% ../nbs/03_xtras.ipynb 137
class EventTimer:
"An event timer with history of `store` items of time `span`"

Expand All @@ -476,15 +487,15 @@ def duration(self): return time.perf_counter()-self.start
@property
def freq(self): return self.events/self.duration

# %% ../nbs/03_xtras.ipynb 139
# %% ../nbs/03_xtras.ipynb 141
_fmt = string.Formatter()

# %% ../nbs/03_xtras.ipynb 140
# %% ../nbs/03_xtras.ipynb 142
def stringfmt_names(s:str)->list:
"Unique brace-delimited names in `s`"
return uniqueify(o[1] for o in _fmt.parse(s) if o[1])

# %% ../nbs/03_xtras.ipynb 142
# %% ../nbs/03_xtras.ipynb 144
class PartialFormatter(string.Formatter):
"A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args"
def __init__(self):
Expand All @@ -500,24 +511,24 @@ def get_field(self, nm, args, kwargs):
def check_unused_args(self, used, args, kwargs):
self.xtra = filter_keys(kwargs, lambda o: o not in used)

# %% ../nbs/03_xtras.ipynb 144
# %% ../nbs/03_xtras.ipynb 146
def partial_format(s:str, **kwargs):
"string format `s`, ignoring missing field errors, returning missing and extra fields"
fmt = PartialFormatter()
res = fmt.format(s, **kwargs)
return res,list(fmt.missing),fmt.xtra

# %% ../nbs/03_xtras.ipynb 147
# %% ../nbs/03_xtras.ipynb 149
def utc2local(dt:datetime)->datetime:
"Convert `dt` from UTC to local time"
return dt.replace(tzinfo=timezone.utc).astimezone(tz=None)

# %% ../nbs/03_xtras.ipynb 149
# %% ../nbs/03_xtras.ipynb 151
def local2utc(dt:datetime)->datetime:
"Convert `dt` from local to UTC time"
return dt.replace(tzinfo=None).astimezone(tz=timezone.utc)

# %% ../nbs/03_xtras.ipynb 151
# %% ../nbs/03_xtras.ipynb 153
def trace(f):
"Add `set_trace` to an existing function `f`"
from pdb import set_trace
Expand All @@ -528,7 +539,7 @@ def _inner(*args,**kwargs):
_inner._traced = True
return _inner

# %% ../nbs/03_xtras.ipynb 153
# %% ../nbs/03_xtras.ipynb 155
@contextmanager
def modified_env(*delete, **replace):
"Context manager temporarily modifying `os.environ` by deleting `delete` and replacing `replace`"
Expand All @@ -541,21 +552,21 @@ def modified_env(*delete, **replace):
os.environ.clear()
os.environ.update(prev)

# %% ../nbs/03_xtras.ipynb 155
# %% ../nbs/03_xtras.ipynb 157
class ContextManagers(GetAttr):
"Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
def __init__(self, mgrs): self.default,self.stack = L(mgrs),ExitStack()
def __enter__(self): self.default.map(self.stack.enter_context)
def __exit__(self, *args, **kwargs): self.stack.__exit__(*args, **kwargs)

# %% ../nbs/03_xtras.ipynb 157
# %% ../nbs/03_xtras.ipynb 159
def shufflish(x, pct=0.04):
"Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location"
n = len(x)
import random
return L(x[i] for i in sorted(range_of(x), key=lambda o: o+n*(1+random.random()*pct)))

# %% ../nbs/03_xtras.ipynb 158
# %% ../nbs/03_xtras.ipynb 160
def console_help(
libname:str): # name of library for console script listing
"Show help for all console scripts from `libname`"
Expand All @@ -567,7 +578,7 @@ def console_help(
print(f'{nm:45}{e.load().__doc__}')


# %% ../nbs/03_xtras.ipynb 159
# %% ../nbs/03_xtras.ipynb 161
def hl_md(s, lang='xml', show=True):
"Syntax highlight `s` using `lang`."
md = f'```{lang}\n{s}\n```'
Expand All @@ -577,7 +588,7 @@ def hl_md(s, lang='xml', show=True):
return display.Markdown(md)
except ImportError: print(s)

# %% ../nbs/03_xtras.ipynb 162
# %% ../nbs/03_xtras.ipynb 164
def type2str(typ:type)->str:
"Stringify `typ`"
if typ is None or typ is NoneType: return 'None'
Expand All @@ -588,7 +599,7 @@ def type2str(typ:type)->str:
elif isinstance(typ, type): return typ.__name__
return str(typ)

# %% ../nbs/03_xtras.ipynb 164
# %% ../nbs/03_xtras.ipynb 166
def dataclass_src(cls):
import dataclasses
src = f"@dataclass\nclass {cls.__name__}:\n"
Expand Down
Loading

0 comments on commit 2723d2e

Please sign in to comment.