Skip to content

Commit

Permalink
Add some minor-impact utilities, rework types in a backward-compatibl…
Browse files Browse the repository at this point in the history
…e way.

Signed-off-by: Emanuele Ballarin <[email protected]>
  • Loading branch information
emaballarin committed Dec 11, 2024
1 parent 1dc09c7 commit 9c5fcd4
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 3 deletions.
4 changes: 4 additions & 0 deletions ebtorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,12 @@
del tableau10_cycler
del petroff_2021_cmap
del tableau10_cmap
del variadic_attrs
del set_petroff_2021_colors
del set_tableau10_colors
del custom_plot_setup
del plot_out
del fromcache
del repr_sizes_flat_adapter
del repr_fx_flat_adapter
del act_auto_broadcast
Expand Down
7 changes: 7 additions & 0 deletions ebtorch/nn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from .actab import broadcast_in_dim
from .adverutils import AdverApply
from .adverutils import TA2ATAdapter
from .attrsutils import variadic_attrs
from .autoclip import AutoClipper
from .cacher import fromcache
from .csttyping import actvt
from .csttyping import numlike
from .csttyping import realnum
Expand Down Expand Up @@ -49,6 +51,8 @@
from .patches import patchify_2d
from .patches import patchify_batch
from .patches import patchify_dataset
from .plotting import custom_plot_setup
from .plotting import plot_out
from .reprutils import gather_model_repr
from .reprutils import model_reqgrad
from .reprutils import model_reqgrad_
Expand All @@ -69,5 +73,8 @@
del reprutils
del filtermanip
del palettes
del cacher
del plotting
del attrsutils
del mapply
del csttyping
57 changes: 57 additions & 0 deletions ebtorch/nn/utils/attrsutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ~~ Imports ~~ ────────────────────────────────────────────────────────────────
from collections.abc import Iterable
from functools import partial
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

# ~~ Exports ~~ ────────────────────────────────────────────────────────────────
__all__: List[str] = ["variadic_attrs"]


# ~~ Utilities ~~ ──────────────────────────────────────────────────────────────
def _str_to_bool(s: str, onesym: bool = False) -> bool:
osl: List[str] = ["t", "y", "1"]
if onesym:
return s.lower() in osl
return s.lower() in (osl + ["true", "yes"])


def _any_to_bool(x, onesym: bool = False) -> bool:
if isinstance(x, str):
return _str_to_bool(x, onesym)
return bool(x)


def _str_to_booltuple(s: str, sep: Optional[str] = None) -> Tuple[bool, ...]:
if sep is not None:
return tuple(map(_str_to_bool, s.split(sep)))
return tuple(map(partial(_str_to_bool, onesym=True), [*s]))


def _any_to_booltuple(
x: Union[str, Iterable[Union[str, bool]]], sep: Optional[str] = None
) -> Tuple[bool, ...]:
if isinstance(x, str):
return _str_to_booltuple(x, sep)
return tuple(map(_any_to_bool, x))


def variadic_attrs(
selfobj,
varsel: Optional[Iterable[Union[str, bool]]] = None,
insep: Optional[str] = None,
outsep: str = "_",
):
odict: dict = selfobj.__getstate__()
odkeys: Tuple[str, ...] = tuple(odict.keys())
lodk: int = len(odkeys)
varsel: Iterable = varsel if varsel is not None else ([True] * lodk)
bvsel: Tuple[bool, ...] = _any_to_booltuple(varsel, insep)
strtuple: Tuple[str, ...] = tuple(
str(odict[odkeys[i]]) if bvsel[i] else "" for i in range(lodk)
)
return (outsep.join(strtuple)).strip().strip(outsep)
57 changes: 57 additions & 0 deletions ebtorch/nn/utils/cacher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ~~ Imports ~~ ────────────────────────────────────────────────────────────────
from collections.abc import Callable
from typing import List
from typing import Optional
from typing import Tuple

# ~~ Exports ~~ ────────────────────────────────────────────────────────────────
__all__: List[str] = ["fromcache"]


# ~~ Utilities ~~ ──────────────────────────────────────────────────────────────
def _normargs(args, kwargs, kwpos: Tuple[str, ...], kwdef: tuple) -> tuple:
largs: int = len(args)
n_args: list = list(args) + [None] * (len(kwpos) - largs)
for i, argname in enumerate(kwpos[largs:], start=largs):
if argname in kwargs:
n_args[i] = kwargs[argname]
else:
n_args[i] = kwdef[i]
return tuple(n_args)


def _args2keyer(kwpos: Tuple[str, ...], kwdef: tuple) -> Callable:
def _args2key(*args, **kwargs) -> tuple:
return _normargs(args, kwargs, kwpos, kwdef)

return _args2key


def _retlookup(key: tuple, dictionary: dict) -> Optional:
return dictionary[key] if key in dictionary else None


# ~~ Cache retrieval function ~~ ───────────────────────────────────────────────
def fromcache(
func: Callable,
*,
kwpos: Tuple[str, ...],
kwdef: tuple,
cache: dict,
updateable: bool = True,
) -> Callable:
def _cached_func(*args, **kwargs):
key: tuple = _args2keyer(kwpos, kwdef)(*args, **kwargs)
if (rlk := _retlookup(key, cache)) is not None:
return rlk
else:
if updateable:
funcret = func(*args, **kwargs)
cache[key] = funcret
return funcret
else:
raise ValueError("Cache miss, but cache is not updateable.")

return _cached_func
34 changes: 34 additions & 0 deletions ebtorch/nn/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ──────────────────────────────────────────────────────────────────────────────
from typing import List
from typing import Optional

import matplotlib.pyplot as plt

# ──────────────────────────────────────────────────────────────────────────────
__all__: List[str] = ["custom_plot_setup", "plot_out"]
# ──────────────────────────────────────────────────────────────────────────────


def custom_plot_setup() -> None:
plt.rcParams["text.usetex"] = True
plt.style.use("ggplot")
plt.rcParams["axes.facecolor"] = "white"
plt.rcParams["axes.edgecolor"] = "black"
plt.rcParams["axes.spines.top"] = False
plt.rcParams["axes.spines.right"] = False
plt.rcParams["xtick.color"] = "black"
plt.rcParams["ytick.color"] = "black"
plt.rcParams["axes.labelcolor"] = "black"
plt.rcParams["grid.color"] = "gainsboro"


# ──────────────────────────────────────────────────────────────────────────────


def plot_out(savepath: Optional[str] = None) -> None:
if savepath:
plt.savefig(savepath, dpi=400, bbox_inches="tight")
else:
plt.show()
6 changes: 4 additions & 2 deletions ebtorch/typing/customtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from typing import List
from typing import Union

import numpy as np
import torch
from torch import Tensor

# ──────────────────────────────────────────────────────────────────────────────
__all__: List[str] = ["realnum", "strdev", "numlike"]
__all__: List[str] = ["realnum", "strdev", "numlike", "tensorlike"]
# ──────────────────────────────────────────────────────────────────────────────
realnum = Union[int, float]
strdev = Union[str, torch.device]
numlike = Union[realnum, Tensor]
tensorlike = Union[Tensor, np.ndarray]
numlike = Union[realnum, tensorlike]
# ──────────────────────────────────────────────────────────────────────────────
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read(fname):

setup(
name=PACKAGENAME,
version="0.28.2",
version="0.28.3",
author="Emanuele Ballarin",
author_email="[email protected]",
url="https://github.com/emaballarin/ebtorch",
Expand Down

0 comments on commit 9c5fcd4

Please sign in to comment.