Skip to content

Commit

Permalink
Add TeLU activation function
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuele Ballarin <[email protected]>
  • Loading branch information
emaballarin committed Jan 6, 2025
1 parent e6e4f37 commit 55dd803
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 10 deletions.
1 change: 1 addition & 0 deletions ebtorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
del SolvePoissonTensor
del StatefulTupleSelect
del SwiGLU
del TeLU
del TupleDecouple
del TupleSelect
del ViTStem
Expand Down
3 changes: 3 additions & 0 deletions ebtorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
from .serlu import SERLU
from .sinlu import SinLU
from .smelu import SmeLU
from .telu import TeLU
from .utils import *

# Deletions (from .)
Expand All @@ -109,10 +110,12 @@
del serlu
del sinlu
del smelu
del telu

# Deletions (from .functional)
# del mish # (already done by chance!)
# del serf # (already done by chance!)
# del telu # (already done by chance!)

# Deletions (from .utils)
del AdverApply
Expand Down
3 changes: 2 additions & 1 deletion ebtorch/nn/architectures_resnets_dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ def __init__(
nn.AdaptiveAvgPool2d((1, 1)) if autopool else nn.AvgPool2d(4)
)

def _make_layer( # Do not make static.
# noinspection PyMethodMayBeStatic
def _make_layer(
self,
in_planes: int,
out_planes: int,
Expand Down
1 change: 1 addition & 0 deletions ebtorch/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .inner_functional import serlu
from .inner_functional import silhouette_score
from .inner_functional import smelu
from .inner_functional import telu
from .inner_functional import tensor_replicate

# Deletions (from .)
Expand Down
19 changes: 13 additions & 6 deletions ebtorch/nn/functional/inner_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,18 @@
from torch import Tensor

__all__ = [
"bisided_thresholding",
"cummatmul",
"field_transform",
"logit_to_prob",
"mish",
"serlu",
"smelu",
"serf",
"oldtranspose",
"serf",
"serlu",
"silhouette_score",
"cummatmul",
"smelu",
"telu",
"tensor_replicate",
"logit_to_prob",
"bisided_thresholding",
]


Expand Down Expand Up @@ -122,6 +123,12 @@ def serf(x: Tensor) -> Tensor:
return torch.erf(x / math.sqrt(2.0)) # type: ignore


@torch.jit.script
def telu(x: Tensor) -> Tensor:
    """TeLU activation, element-wise: x * tanh(exp(x))."""
    gate: Tensor = torch.tanh(torch.exp(x))
    return x * gate


def oldtranspose(x: Tensor) -> Tensor:
"""
Transpose a tensor along all dimensions, emulating x.T.
Expand Down
3 changes: 2 additions & 1 deletion ebtorch/nn/reshapelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ class FlatChannelize2DLayer(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x): # Do not make static!
# noinspection PyMethodMayBeStatic
def forward(self, x):
return x.reshape(*x.shape, 1, 1)
3 changes: 2 additions & 1 deletion ebtorch/nn/serf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ class ScaledERF(torch.nn.Module):
def __init__(self) -> None:
super(ScaledERF, self).__init__()

def forward(self, x: Tensor) -> Tensor: # Do not make static!
# noinspection PyMethodMayBeStatic
def forward(self, x: Tensor) -> Tensor:
return fserf(x)
21 changes: 21 additions & 0 deletions ebtorch/nn/telu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List

from torch import nn
from torch import Tensor

from .functional import telu as ftelu

__all__: List[str] = ["TeLU"]


class TeLU(nn.Module):
    """Module wrapper around the TeLU activation, ``x * tanh(exp(x))``.

    Stateless: holds no parameters or buffers; simply delegates to the
    functional implementation element-wise.
    """

    def __init__(self) -> None:
        super().__init__()

    # noinspection PyMethodMayBeStatic
    def forward(self, x: Tensor) -> Tensor:
        """Apply TeLU element-wise to ``x`` and return the result."""
        activated: Tensor = ftelu(x)
        return activated
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read(fname):

setup(
name=PACKAGENAME,
version="0.28.9",
version="0.28.10",
author="Emanuele Ballarin",
author_email="[email protected]",
url="https://github.com/emaballarin/ebtorch",
Expand Down

0 comments on commit 55dd803

Please sign in to comment.