Skip to content

Commit

Permalink
Update!
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuele Ballarin <[email protected]>
  • Loading branch information
emaballarin committed Jan 16, 2024
1 parent 96f773a commit d7f1e86
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 1 deletion.
7 changes: 7 additions & 0 deletions ebtorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
# ==============================================================================
#
# SPDX-License-Identifier: Apache-2.0
#
# Versioning
__version__ = "0.20.0"

# Imports (wildcard)
from .data import *
from .distributed import *
Expand Down Expand Up @@ -89,6 +93,7 @@
del SolvePoissonTensor
del SwiGLU
del TupleDecouple
del SilhouetteScore
del WideResNet
del beta_reco_bce
del build_repeated_sequential
Expand All @@ -97,6 +102,8 @@
del pixelwise_bce_mean
del pixelwise_bce_sum
del oldtranspose
del silhouette_score
del cummatmul

# Deletions (from .optim)
del AdaBound
Expand Down
3 changes: 3 additions & 0 deletions ebtorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .architectures import RBLinear
from .architectures import ResBlock
from .architectures import SGRUHCell
from .architectures import SilhouetteScore
from .architectures import SirenSine
from .architectures import SwiGLU
from .architectures import TupleDecouple
Expand All @@ -41,9 +42,11 @@
from .coordconv import CoordConv3d
from .debuglayers import ProbePrintLayer
from .fieldtransform import FieldTransform
from .functional import cummatmul
from .functional import field_transform
from .functional import mish
from .functional import oldtranspose
from .functional import silhouette_score
from .kwta import BrokenReLU
from .kwta import KWTA1d
from .kwta import KWTA2d
Expand Down
12 changes: 12 additions & 0 deletions ebtorch/nn/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from torch.nn import functional as F
from torch.utils.hooks import RemovableHandle

from .functional import silhouette_score
from .penalties import beta_gaussian_kldiv

__all__ = [
Expand All @@ -56,6 +57,7 @@
"Clamp",
"SwiGLU",
"TupleDecouple",
"SilhouetteScore",
]

# CUSTOM TYPES
Expand Down Expand Up @@ -865,3 +867,13 @@ def forward(self, xtuple: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
*xtuple[self.idx + 1 :],
)
)


class SilhouetteScore(nn.Module):
"""
Layerized computation of the Silhouette Score.
"""

@staticmethod
def forward(features: Tensor, labels: Tensor) -> Tensor:
return silhouette_score(features, labels)
2 changes: 2 additions & 0 deletions ebtorch/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
#
# ------------------------------------------------------------------------------
# Imports (specific)
from .inner_functional import cummatmul
from .inner_functional import field_transform
from .inner_functional import mish
from .inner_functional import oldtranspose
from .inner_functional import serf
from .inner_functional import serlu
from .inner_functional import silhouette_score
from .inner_functional import smelu

# Deletions (from .)
Expand Down
74 changes: 74 additions & 0 deletions ebtorch/nn/functional/inner_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
# SPDX-License-Identifier: Apache-2.0
# IMPORTS
import math
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn.functional as F
Expand All @@ -42,6 +46,8 @@
"smelu",
"serf",
"oldtranspose",
"silhouette_score",
"cummatmul",
]


Expand Down Expand Up @@ -124,3 +130,71 @@ def oldtranspose(x: Tensor) -> Tensor:
Transposed of x.
"""
return x.permute(*torch.arange(x.ndim - 1, -1, -1))


def silhouette_score(feats: Tensor, labels: Tensor) -> Union[float, Tensor]: # NOSONAR
if feats.shape[0] != labels.shape[0]:
raise ValueError(
f"`feats` (shape {feats.shape}) and `labels` (shape {labels.shape}) must have same length"
)
device, dtype = feats.device, feats.dtype
unique_labels: Union[Tensor, Tuple[Tensor, ...]] = torch.unique(labels)
num_samples: int = feats.shape[0]
if not (1 < len(unique_labels) < num_samples):
raise ValueError("The number of unique `labels` must be ∈ (1, `num_samples`)")
scores: List[Tensor] = []
for l_label in unique_labels:
curr_cluster: Tensor = feats[labels == l_label]
num_elements: int = len(curr_cluster)
if num_elements > 1:
intra_cluster_dists: Tensor = torch.cdist(curr_cluster, curr_cluster)
mean_intra_dists: Tensor = torch.sum(intra_cluster_dists, dim=1) / (
num_elements - 1
)
dists_to_other_clusters: List[Tensor] = []
for other_l in unique_labels:
if other_l != l_label:
other_cluster: Tensor = feats[labels == other_l]
inter_cluster_dists: Tensor = torch.cdist(
curr_cluster, other_cluster
)
mean_inter_dists: Tensor = torch.sum(inter_cluster_dists, dim=1) / (
len(other_cluster)
)
dists_to_other_clusters.append(mean_inter_dists)
dists_to_other_clusters_t: Tensor = torch.stack(
dists_to_other_clusters, dim=1
)
min_dists: Tensor = torch.min(dists_to_other_clusters_t, dim=1)[0]
curr_scores: Tensor = (min_dists - mean_intra_dists) / (
torch.maximum(min_dists, mean_intra_dists)
)
else:
curr_scores: Tensor = torch.tensor([0], device=device, dtype=dtype)

scores.append(curr_scores)

scores_t: Tensor = torch.cat(scores, dim=0)
if len(scores_t) != num_samples:
raise ValueError(
f"`scores_t` (shape {scores_t.shape}) should have same length as `feats` (shape {feats.shape})"
)
return torch.mean(scores_t)


def cummatmul(
input_list: Union[List[Tensor], Tensor], tensorize: Optional[bool] = None
) -> Union[List[Tensor], Tensor]:
if tensorize is None:
if isinstance(input_list, Tensor):
tensorize = True
else:
tensorize = False
cmm_list: List[Tensor] = [input_list[0]]
mat: Tensor
for mat in input_list[1:]:
cmm_list.append(torch.matmul(cmm_list[-1], mat))
if tensorize:
return torch.stack(cmm_list)
else:
return cmm_list
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def check_dependencies(dependencies: list[str]):

setup(
name=PACKAGENAME,
version="0.19.3",
version="0.20.0",
author="Emanuele Ballarin",
author_email="[email protected]",
url="https://github.com/emaballarin/ebtorch",
Expand Down

0 comments on commit d7f1e86

Please sign in to comment.