Skip to content

Commit

Permalink
Hermitian stuff
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuele Ballarin <[email protected]>
  • Loading branch information
emaballarin committed Dec 21, 2024
1 parent b377673 commit 4f97362
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
2 changes: 2 additions & 0 deletions ebtorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
del fromcache
del fxfx2module
del gather_model_repr
del hermitize
del matched_apply
del model_reqgrad
del model_reqgrad_
Expand All @@ -144,6 +145,7 @@
del petroff_2021_color
del petroff_2021_cycler
del plot_out
del randhermn
del repr_fx_flat_adapter
del repr_sizes_flat_adapter
del seed_everything
Expand Down
2 changes: 2 additions & 0 deletions ebtorch/nn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from .onlyutils import download_gdrive
from .onlyutils import emplace_kv
from .onlyutils import fxfx2module
from .onlyutils import hermitize
from .onlyutils import no_op
from .onlyutils import randhermn
from .onlyutils import stablediv
from .onlyutils import subset_state_dict
from .onlyutils import suppress_std
Expand Down
17 changes: 17 additions & 0 deletions ebtorch/nn/utils/onlyutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@
from typing import Union

import requests
import torch as th
from httpx import Client
from safe_assert import safe_assert as sassert
from torch import dtype as _dtype
from torch import nn
from torch import Tensor

from ...typing import actvt
from ...typing import numlike
from ...typing import strdev

__all__ = [
"argser_f",
Expand All @@ -56,6 +59,8 @@
"suppress_std",
"TelegramBotEcho",
"stablediv",
"hermitize",
"randhermn",
]


Expand Down Expand Up @@ -199,6 +204,18 @@ def suppress_std(which: str = "all") -> None:
sys.stderr = old_stderr


def hermitize(x: Tensor) -> Tensor:
return (x + x.conj().t()) / 2


def randhermn(
n: int,
dtype: Optional[_dtype] = th.cdouble,
device: Optional[strdev] = None,
):
return 2 * hermitize(th.randn(n, n, dtype=dtype, device=device))


# Classes
class _FxToFxobj: # NOSONAR
__slots__ = ("fx",)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read(fname):

setup(
name=PACKAGENAME,
version="0.28.7",
version="0.28.8",
author="Emanuele Ballarin",
author_email="[email protected]",
url="https://github.com/emaballarin/ebtorch",
Expand Down

0 comments on commit 4f97362

Please sign in to comment.