structural refinement must be carried out in float64
lucidrains committed Mar 1, 2021
1 parent 63fe134 commit 5e3be98
Showing 2 changed files with 30 additions and 22 deletions.
50 changes: 29 additions & 21 deletions alphafold2_pytorch/alphafold2.py
@@ -12,6 +12,7 @@
 from alphafold2_pytorch.reversible import ReversibleSequence

 from se3_transformer_pytorch import SE3Transformer
+from se3_transformer_pytorch.utils import torch_default_dtype

 # helpers

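The new import brings in torch_default_dtype from se3_transformer_pytorch.utils; the rest of the commit uses it to build and run the structure module in double precision. A minimal sketch of how such a default-dtype context manager can be written (an assumption about its behaviour, not the library's actual source):

    from contextlib import contextmanager
    import torch

    @contextmanager
    def torch_default_dtype(dtype):
        # temporarily swap torch's global default dtype, restoring it on exit
        prev_dtype = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        try:
            yield
        finally:
            torch.set_default_dtype(prev_dtype)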
@@ -450,19 +451,20 @@ def __init__(
         self.mds_iters = mds_iters
         self.structure_module_refinement_iters = structure_module_refinement_iters

-        self.structure_module_embeds = nn.Embedding(num_tokens, structure_module_dim)
-        self.to_refined_coords_delta = nn.Linear(structure_module_dim, 1)
-
-        self.structure_module = SE3Transformer(
-            dim = structure_module_dim,
-            depth = structure_module_depth,
-            input_degrees = 1,
-            num_degrees = 3,
-            output_degrees = 2,
-            heads = structure_module_heads,
-            num_neighbors = structure_module_knn,
-            differentiable_coors = True
-        )
+        with torch_default_dtype(torch.float64):
+            self.structure_module_embeds = nn.Embedding(num_tokens, structure_module_dim)
+            self.to_refined_coords_delta = nn.Linear(structure_module_dim, 1)
+
+            self.structure_module = SE3Transformer(
+                dim = structure_module_dim,
+                depth = structure_module_depth,
+                input_degrees = 1,
+                num_degrees = 3,
+                output_degrees = 2,
+                heads = structure_module_heads,
+                num_neighbors = structure_module_knn,
+                differentiable_coors = True
+            )

     def forward(
         self,
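Wrapping module construction in the context manager matters because modules such as nn.Embedding and nn.Linear allocate their parameters with torch's default dtype at construction time. A small check of that behaviour, assuming a torch_default_dtype context manager like the sketch above:

    import torch
    import torch.nn as nn

    with torch_default_dtype(torch.float64):
        emb = nn.Embedding(21, 32)
        proj = nn.Linear(32, 1)

    assert emb.weight.dtype == torch.float64
    assert proj.weight.dtype == torch.float64
    # the previous default (float32 in a fresh session) is restored afterwards
    assert torch.get_default_dtype() == torch.float32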
@@ -651,15 +653,21 @@ def forward(
         coords = torch.cat(coords, dim = 0)
         x = self.structure_module_embeds(seq)

-        for _ in range(self.structure_module_refinement_iters):
-            output = self.structure_module(x, coords, mask = mask)
-            x, refined_coords = output['0'], output['1']
+        original_dtype = coords.dtype
+
+        x, coords = map(lambda t: t.double(), (x, coords))
+
+        with torch_default_dtype(torch.float64):
+            for _ in range(self.structure_module_refinement_iters):
+                output = self.structure_module(x, coords, mask = mask)
+                x, refined_coords = output['0'], output['1']

-            refined_coords = rearrange(refined_coords, 'b n d c -> b n c d')
-            refined_coords = self.to_refined_coords_delta(refined_coords)
-            refined_coords = rearrange(refined_coords, 'b n c () -> b n c')
+                refined_coords = rearrange(refined_coords, 'b n d c -> b n c d')
+                refined_coords = self.to_refined_coords_delta(refined_coords)
+                refined_coords = rearrange(refined_coords, 'b n c () -> b n c')

-            x = rearrange(x, 'b n c () -> b n c')
-            coords = coords + refined_coords
+                x = rearrange(x, 'b n c () -> b n c')
+                coords = coords + refined_coords

+        coords = coords.type(original_dtype)
         return coords
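In the forward pass the same idea is applied at runtime: features and coordinates are promoted to float64, the refinement loop runs under the float64 default, and the result is cast back to the caller's dtype (Tensor.type returns a new tensor, so the cast has to be reassigned). A standalone sketch of that dtype round-trip, with refine_step as a hypothetical stand-in for one structure-module iteration:

    import torch

    def refine_step(x, coords):
        # hypothetical stand-in for one SE3Transformer refinement iteration
        return x, coords + 0.01 * torch.randn_like(coords)

    def refine_in_float64(x, coords, iters = 2):
        original_dtype = coords.dtype

        # promote to double precision for the numerically sensitive refinement
        x, coords = map(lambda t: t.double(), (x, coords))

        for _ in range(iters):
            x, coords = refine_step(x, coords)

        # Tensor.type is not in-place, so the cast back must be reassigned
        return coords.type(original_dtype)

    coords = torch.randn(2, 16, 3)   # (batch, residues, xyz) in float32
    x = torch.randn(2, 16, 8)        # per-residue features
    assert refine_in_float64(x, coords).dtype == torch.float32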
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'alphafold2-pytorch',
   packages = find_packages(),
-  version = '0.0.34',
+  version = '0.0.35',
   license='MIT',
   description = 'AlphaFold2 - Pytorch',
   author = 'Phil Wang, Eric Alcaide',
