Skip to content

Commit

Permalink
fix bug, bias not needed
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 5, 2021
1 parent 3588b04 commit 04228e8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
10 changes: 3 additions & 7 deletions egnn_pytorch/egnn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,16 @@ def forward(self, x):
# https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95

class CoorsNorm(nn.Module):
def __init__(self, eps = 1e-8, scale_init = 1., bias_init = 0.):
def __init__(self, eps = 1e-8, scale_init = 1.):
super().__init__()
self.eps = eps
scale = torch.zeros(1).fill_(scale_init)
bias = torch.zeros(1).fill_(bias_init)

self.scale = nn.Parameter(scale)
self.bias = nn.Parameter(bias)

def forward(self, coors):
norm = coors.norm(dim = -1, keepdim = True)
normed_coors = coors / norm.clamp(min = self.eps)
return normed_coors * self.scale + self.bias
return normed_coors * self.scale

# global linear attention

Expand Down Expand Up @@ -181,7 +178,6 @@ def __init__(
norm_feats = False,
norm_coors = False,
norm_coors_scale_init = 1e-2,
norm_coors_bias_init = 0.,
update_feats = True,
update_coors = True,
only_sparse_neighbors = False,
Expand Down Expand Up @@ -213,7 +209,7 @@ def __init__(
) if soft_edges else None

self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity()
self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init, bias_init = norm_coors_bias_init) if norm_coors else nn.Identity()
self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()

self.m_pool_method = m_pool_method

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'egnn-pytorch',
packages = find_packages(),
version = '0.2.4',
version = '0.2.5',
license='MIT',
description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down

0 comments on commit 04228e8

Please sign in to comment.