
Commit
update
szha committed Feb 10, 2018
1 parent f114aa5 commit 33f0941
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions python/mxnet/gluon/nn/activations.py
@@ -125,8 +125,8 @@ class PReLU(HybridBlock):
     Outputs:
         - **out**: output tensor with the same shape as `data`.
     """
-    def __init__(self, alpha_initializer=initializer.Constant(0.25), *args):
-        super(PReLU, self).__init__(*args)
+    def __init__(self, alpha_initializer=initializer.Constant(0.25), **kwargs):
+        super(PReLU, self).__init__(**kwargs)
         with self.name_scope():
             self.alpha = self.params.get('alpha', shape=(1,), init=alpha_initializer)
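Note: the `*args` to `**kwargs` change lets PReLU forward Gluon's standard block keyword arguments (e.g. `prefix`, `params`) to `HybridBlock.__init__`; under the old signature they raised a TypeError. A minimal sketch of the now-working call, assuming the Gluon 1.x block API (illustrative, not part of the commit):

    from mxnet.gluon import nn

    # Before this commit, PReLU forwarded only positional args, so standard
    # Block keyword arguments could not reach HybridBlock.__init__:
    #     nn.PReLU(prefix='prelu0_')  # TypeError under the old *args signature

    # After this commit, keyword arguments pass through as in other Gluon blocks.
    prelu = nn.PReLU(prefix='prelu0_')
    prelu.initialize()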

@@ -158,7 +158,7 @@ def __init__(self, alpha=1.0, **kwargs):
         self._alpha = alpha

     def hybrid_forward(self, F, x):
-        return - self._alpha * F.relu(1.0 - F.exp(x)) + F.relu(x)
+        return F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))


 class SELU(HybridBlock):
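Note: the two ELU formulations are equivalent. For x > 0, relu(x) = x and relu(1 - exp(x)) = 0, so the old expression gives x; for x <= 0 it gives -alpha * (1 - exp(x)) = alpha * (exp(x) - 1), which is exactly the new `F.where` branch. A quick NumPy check of this equivalence (illustrative, not part of the commit):

    import numpy as np

    def elu_old(x, alpha=1.0):
        # - alpha * relu(1 - exp(x)) + relu(x), the pre-commit formulation
        relu = lambda v: np.maximum(v, 0.0)
        return -alpha * relu(1.0 - np.exp(x)) + relu(x)

    def elu_new(x, alpha=1.0):
        # where(x > 0, x, alpha * (exp(x) - 1)), the post-commit formulation
        return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

    x = np.linspace(-5.0, 5.0, 101)
    assert np.allclose(elu_old(x), elu_new(x))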
@@ -178,11 +178,9 @@ def __init__(self, **kwargs):
         super(SELU, self).__init__(**kwargs)
         self._scale = 1.0507009873554804934193349852946
         self._alpha = 1.6732632423543772848170429916717
-        with self.name_scope():
-            self.elu = ELU()

     def hybrid_forward(self, F, x):
-        return self._scale * F.where(x >= 0, x, self._alpha * self.elu(x))
+        return self._scale * F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))


 class Swish(HybridBlock):
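Note: inlining the ELU math removes the child `ELU` block (and its name scope) without changing the function: `ELU()` defaults to alpha=1.0, so the old negative branch `self._alpha * self.elu(x)` was already alpha * (exp(x) - 1). Switching the condition from `x >= 0` to `x > 0` is also harmless, since both branches evaluate to 0 at x = 0. The result is the SELU of Klambauer et al. (2017); a NumPy reference of the computed function (illustrative, not part of the commit):

    import numpy as np

    SCALE = 1.0507009873554804934193349852946
    ALPHA = 1.6732632423543772848170429916717

    def selu(x):
        # scale * where(x > 0, x, alpha * (exp(x) - 1)), matching hybrid_forward
        return SCALE * np.where(x > 0, x, ALPHA * (np.exp(x) - 1.0))

    x = np.array([-2.0, 0.0, 2.0])
    print(selu(x))  # negative inputs saturate toward -SCALE * ALPHA ~= -1.7581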

