create an attention version of equivariant GNN
lucidrains committed Feb 26, 2021
1 parent e5c993c commit 46cc9e4
Showing 3 changed files with 92 additions and 2 deletions.
2 changes: 1 addition & 1 deletion egnn_pytorch/__init__.py
@@ -1 +1 @@
- from egnn_pytorch.egnn_pytorch import EGNN
+ from egnn_pytorch.egnn_pytorch import EGNN, EGAT
90 changes: 90 additions & 0 deletions egnn_pytorch/egnn_pytorch.py
@@ -1,6 +1,8 @@
import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper functions

@@ -79,3 +81,91 @@ def forward(self, feats, coors, edges = None):
        hidden_out = self.hidden_mlp(hidden_mlp_input)

        return hidden_out, coors_out

# attention version

class EGAT(nn.Module):
    def __init__(
        self,
        dim,
        edge_dim = 0,
        m_dim = 16,
        heads = 4,
        dim_head = 64,
        fourier_features = 0
    ):
        super().__init__()
        self.fourier_features = fourier_features

        attn_inner_dim = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, attn_inner_dim * 3, bias = False)
        self.to_out = nn.Linear(attn_inner_dim, dim)

        edge_input_dim = (fourier_features * 2) + (dim_head * 2) + edge_dim + 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            nn.ReLU(),
            nn.Linear(edge_input_dim * 2, m_dim)
        )

        self.to_attn_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            nn.ReLU(),
            nn.Linear(m_dim * 4, 1),
            Rearrange('... () -> ...')
        )

        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            nn.ReLU(),
            nn.Linear(m_dim * 4, 1),
            Rearrange('... () -> ...')
        )

    def forward(self, feats, coors, edges = None):
        b, n, d, h, fourier_features = *feats.shape, self.heads, self.fourier_features

        rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
        rel_dist = rel_coors.norm(dim = -1, keepdim = True)

        if fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = fourier_features)
            rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d')

        rel_dist = repeat(rel_dist, 'b i j d -> b h i j d', h = h)

        # derive queries keys and values

        q, k, v = self.to_qkv(feats).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # expand queries and keys for concatting

        q = repeat(q, 'b h i d -> b h i n d', n = n)
        k = repeat(k, 'b h j d -> b h n j d', n = n)

        edge_input = torch.cat((q, k, rel_dist), dim = -1)

        if exists(edges):
            edges = repeat(edges, 'b i j d -> b h i j d', h = h)
            edge_input = torch.cat((edge_input, edges), dim = -1)

        m_ij = self.edge_mlp(edge_input)

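        # weight relative coordinates per head, then sum across heads and neighbors for the coordinate update
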
        coor_weights = self.coors_mlp(m_ij)
        coors_out = einsum('b h i j, b i j c -> b i c', coor_weights, rel_coors)

        # derive attention

        sim = self.to_attn_mlp(m_ij)
        attn = sim.softmax(dim = -1)

        # weighted sum of values and combine heads

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        return out, coors_out
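
For orientation, a minimal usage sketch of the new attention layer (not part of the diff; the dimensions below are illustrative, and `fourier_encode_dist` / `exists` are helpers defined earlier in the file, outside the hunk shown above):

import torch
from egnn_pytorch import EGAT

layer = EGAT(dim = 64, edge_dim = 4, fourier_features = 2)

feats = torch.randn(1, 16, 64)       # node features (batch, nodes, dim)
coors = torch.randn(1, 16, 3)        # node coordinates (batch, nodes, 3)
edges = torch.randn(1, 16, 16, 4)    # optional pairwise edge features (batch, nodes, nodes, edge_dim)

feats_out, coors_out = layer(feats, coors, edges = edges)   # shapes (1, 16, 64) and (1, 16, 3)

Note that in this version `coors_out` is the weighted sum of relative coordinates itself; no residual addition of `coors` appears in the forward pass.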
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'egnn-pytorch',
  packages = find_packages(),
-  version = '0.0.4',
+  version = '0.0.5',
  license='MIT',
  description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
  author = 'Phil Wang',
