Skip to content

Commit

Permalink
add ability to customize initial values for scale and bias for coordi…
Browse files Browse the repository at this point in the history
…nate norm
  • Loading branch information
lucidrains committed Jun 5, 2021
1 parent 44c3687 commit ce278a5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
13 changes: 9 additions & 4 deletions egnn_pytorch/egnn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,14 @@ def forward(self, x):
# https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95

class CoorsNorm(nn.Module):
def __init__(self, eps = 1e-8):
def __init__(self, eps = 1e-8, scale_init = 1., bias_init = 0.):
super().__init__()
self.eps = eps
self.scale = nn.Parameter(torch.ones(1))
self.bias = nn.Parameter(torch.zeros(1))
scale = torch.zeros(1).fill_(scale_init)
bias = torch.zeros(1).fill_(bias_init)

self.scale = nn.Parameter(scale)
self.bias = nn.Parameter(bias)

def forward(self, coors):
norm = coors.norm(dim = -1, keepdim = True)
Expand All @@ -110,6 +113,8 @@ def __init__(
init_eps = 1e-3,
norm_feats = False,
norm_coors = False,
norm_coors_scale_init = 1.,
norm_coors_bias_init = 0.,
update_feats = True,
update_coors = True,
only_sparse_neighbors = False,
Expand Down Expand Up @@ -141,7 +146,7 @@ def __init__(
) if soft_edges else None

self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity()
self.coors_norm = CoorsNorm() if norm_coors else nn.Identity()
self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init, bias_init = norm_coors_bias_init) if norm_coors else nn.Identity()

self.m_pool_method = m_pool_method

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'egnn-pytorch',
packages = find_packages(),
version = '0.2.0',
version = '0.2.1',
license='MIT',
description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down

0 comments on commit ce278a5

Please sign in to comment.