V1.1.1 (#110)
* backup of original notebook

* major update to af_pseudo_diffusion_dgram notebook

* bugfix in recycle protocol

* bugfix (noise was not being decayed)

* updating default settings

* restoring old notebook for now... (still debuggin')

* adding i_cmap output, removing redundant opt (cmap_cutoff)

* Update design.py

* minor edits

* Update model.py

* remove assert to "fix" compatibility issue

collections.Iterable is now collections.abc.Iterable

* Update model.py

* replace np.int with np.int32 dtype to fix numpy error

* bugfix: design_semigreedy() wasn't using input bias when seq_logits was defined

* Update design.py

* Update design.py

* bugfix: bias was not being used in semigreedy protocol

* Update design.py
sokrypton authored Jan 17, 2023
1 parent f3fb796 commit d56087a
Showing 10 changed files with 459 additions and 69 deletions.
376 changes: 376 additions & 0 deletions af/examples/af_pseudo_diffusion_dgram_old.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions af/examples/af_pseudo_diffusion_recycle.ipynb
@@ -171,14 +171,14 @@
" use_dropout = False\n",
" save_best = True\n",
"\n",
" if k > 0:\n",
" af_model._inputs[\"prev\"][\"prev_msa_first_row\"] *= 0\n",
" af_model._inputs[\"prev\"][\"prev_pos\"] *= 0\n",
"\n",
" # denoise\n",
" aux = af_model.predict(return_aux=True, verbose=False,\n",
" dropout=use_dropout,\n",
" num_recycles=num_recycles)\n",
" af_model._inputs[\"prev\"] = aux[\"prev\"]\n",
" af_model._inputs[\"prev\"][\"prev_msa_first_row\"] *= 0\n",
" af_model._inputs[\"prev\"][\"prev_pos\"] *= 0\n",
"\n",
" # per position confidence\n",
" cmap = aux[\"cmap\"] * (np.abs(offset) > cmap_seqsep)\n",
@@ -251,8 +251,8 @@
"num_seqs = 16 #@param [\"8\", \"16\", \"32\", \"64\", \"128\", \"256\", \"512\", \"1024\"] {type:\"raw\"}\n",
"sampling_temp = 0.1 \n",
"#@markdown #### AlphaFold Options\n",
"alphafold_model = \"model_3_ptm\" #@param [\"model_1_ptm\", \"model_2_ptm\", \"model_3_ptm\", \"model_4_ptm\", \"model_5_ptm\"]\n",
"num_recycles = 1 #@param [\"0\", \"1\", \"2\", \"3\"] {type:\"raw\"}\n",
"alphafold_model = \"model_4_ptm\" #@param [\"model_1_ptm\", \"model_2_ptm\", \"model_3_ptm\", \"model_4_ptm\", \"model_5_ptm\"]\n",
"num_recycles = 3 #@param [\"0\", \"1\", \"2\", \"3\"] {type:\"raw\"}\n",
"import pandas as pd\n",
"\n",
"# zero out template inputs\n",
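Note on the recycle-protocol fix above: the "prev" carry-over is now refreshed from the prediction first, and only then are the single-representation and coordinate channels zeroed, so (presumably) only the pair-representation carry-over survives between iterations. A minimal sketch of the intended loop; the surrounding notebook variables (af_model, use_dropout, num_recycles, offset, cmap_seqsep, iterations) are assumed to be defined as in the notebook:

import numpy as np

for k in range(iterations):
    # predict, then feed the result back in as the "previous" recycle state
    aux = af_model.predict(return_aux=True, verbose=False,
                           dropout=use_dropout, num_recycles=num_recycles)
    af_model._inputs["prev"] = aux["prev"]
    # clear the single-row and coordinate channels between iterations
    af_model._inputs["prev"]["prev_msa_first_row"] *= 0
    af_model._inputs["prev"]["prev_pos"] *= 0

    # per-position confidence from the contact map, ignoring near-diagonal pairs
    cmap = aux["cmap"] * (np.abs(offset) > cmap_seqsep)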
4 changes: 2 additions & 2 deletions colabdesign/af/alphafold/common/residue_constants.py
@@ -770,10 +770,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int32)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
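Context for the dtype change: np.int was just an alias for Python's built-in int; NumPy deprecated it in 1.20 and removed it in 1.24, so dtype=np.int raises an AttributeError on current NumPy. A minimal check of the replacement:

import numpy as np

# fixed-width integer dtype instead of the removed np.int alias
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32)
print(restype_atom37_to_rigid_group.dtype)  # int32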
41 changes: 17 additions & 24 deletions colabdesign/af/alphafold/model/modules_multimer.py
@@ -611,29 +611,23 @@ def construct_input(query_embedding, template_batch, multichain_mask_2d):
to_concat.append((aatype[None, :, :], 1))
to_concat.append((aatype[:, None, :], 1))

if "template_dgram" in template_batch:
num_res = template_batch["template_aatype"].shape[0]
unit_vector = [jnp.zeros((num_res,num_res))] * 3
backbone_mask_2d = jnp.zeros((num_res,num_res))

else:
# Compute a feature representing the normalized vector between each
# backbone affine - i.e. in each residues local frame, what direction are
# each of the other residues.
raw_atom_pos = template_batch["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = folding_multimer.make_backbone_affine(
atom_pos,
template_batch["template_all_atom_mask"],
template_batch["template_aatype"])
points = rigid.translation
rigid_vec = rigid[:, None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]

backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
backbone_mask_2d *= multichain_mask_2d
unit_vector = [x*backbone_mask_2d for x in unit_vector]
# Compute a feature representing the normalized vector between each
# backbone affine - i.e. in each residues local frame, what direction are
# each of the other residues.
raw_atom_pos = template_batch["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = folding_multimer.make_backbone_affine(
atom_pos,
template_batch["template_all_atom_mask"],
template_batch["template_aatype"])
points = rigid.translation
rigid_vec = rigid[:, None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]

backbone_mask_2d = jnp.sqrt(backbone_mask[:,None] * backbone_mask[None,:])
backbone_mask_2d *= multichain_mask_2d
unit_vector = [x*backbone_mask_2d for x in unit_vector]

# Note that the backbone_mask takes into account C, CA and N (unlike
# pseudo beta mask which just needs CB) so we add both masks as features.
@@ -694,7 +688,6 @@ def template_iteration_fn(x):
act)
return act


class TemplateEmbeddingIteration(hk.Module):
"""Single Iteration of Template Embedding."""

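Two things change in construct_input above: the template_dgram short-circuit is gone, so the backbone unit-vector feature is always computed, and the pairwise backbone mask is now the square root of the outer product of the per-residue masks rather than the plain product. A toy illustration of that mask, outside the model (values are made up):

import jax.numpy as jnp

backbone_mask = jnp.array([1.0, 1.0, 0.0, 1.0])  # per-residue backbone validity
multichain_mask_2d = jnp.ones((4, 4))            # residue pairs allowed to see each other

# geometric mean of the per-residue masks, restricted to allowed pairs;
# for strictly 0/1 masks this matches the old product and only differs for soft masks
backbone_mask_2d = jnp.sqrt(backbone_mask[:, None] * backbone_mask[None, :])
backbone_mask_2d = backbone_mask_2d * multichain_mask_2d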
2 changes: 0 additions & 2 deletions colabdesign/af/alphafold/model/utils.py
@@ -52,8 +52,6 @@ def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
axis = [axis]
elif axis is None:
axis = list(range(len(mask_shape)))
assert isinstance(axis, collections.abc.Iterable), (
'axis needs to be either an iterable, integer or "None"')

broadcast_factor = 1.
for axis_ in axis:
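The dropped assert is the one the commit message refers to; the abstract base classes live under collections.abc, and the old collections.Iterable alias was removed entirely in Python 3.10, so any reinstated check would need the abc spelling, roughly:

import collections.abc

axis = [0, 1]
assert isinstance(axis, collections.abc.Iterable), (
    'axis needs to be either an iterable, integer or "None"')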
42 changes: 23 additions & 19 deletions colabdesign/af/design.py
@@ -251,33 +251,33 @@ def _save_results(self, aux=None, save_best=False,
if self._args["best_metric"] in ["plddt","ptm","i_ptm","seqid","composite"] or metric_higher_better:
metric = -metric
if "metric" not in self._tmp["best"] or metric < self._tmp["best"]["metric"]:
self._tmp["best"]["aux"] = aux
self._tmp["best"]["aux"] = copy_dict(aux)
self._tmp["best"]["metric"] = metric

if verbose and ((self._k+1) % verbose) == 0:
self._print_log(f"{self._k+1}", aux=aux)

def predict(self, seq=None,
def predict(self, seq=None, bias=None,
num_models=None, num_recycles=None, models=None, sample_models=False,
dropout=False, hard=True, soft=False, temp=1,
return_aux=False, verbose=True, seed=None, **kwargs):
'''predict structure for input sequence (if provided)'''

def load_settings():
if "save" in self._tmp:
(self.opt, self._args, self._params, self._inputs) = self._tmp.pop("save")
[self.opt, self._args, self._params, self._inputs] = self._tmp.pop("save")

def save_settings():
load_settings()
self._tmp["save"] = (copy_dict(x) for x in [self.opt, self._args, self._params, self._inputs])
self._tmp["save"] = [copy_dict(x) for x in [self.opt, self._args, self._params, self._inputs]]

save_settings()

# set seed if defined
if seed is not None: self.set_seed(seed)

# set [seq]uence/[opt]ions
if seq is not None: self.set_seq(seq=seq)
if seq is not None: self.set_seq(seq=seq, bias=bias)
self.set_opt(hard=hard, soft=soft, temp=temp, dropout=dropout, use_pssm=False)

# run
@@ -415,13 +415,12 @@ def design_semigreedy(self, iters=100, tries=10, dropout=False,

# get starting sequence
if hasattr(self,"aux"):
seq = self.aux["seq_pseudo"].argmax(-1)
seq = self.aux["seq"]["logits"].argmax(-1)
else:
seq = self._params["seq"].argmax(-1)
seq = (self._params["seq"] + self._inputs["bias"]).argmax(-1)

# bias sampling towards the defined bias
if seq_logits is None:
seq_logits = self._inputs["bias"].copy()
if seq_logits is None: seq_logits = 0

model_flags = {k:kwargs.pop(k,None) for k in ["num_models","sample_models","models"]}
verbose = kwargs.pop("verbose",1)
@@ -440,16 +439,16 @@
model_nums = self._get_model_nums(**model_flags)
num_tries = (tries+(e_tries-tries)*((i+1)/iters))
for t in range(int(num_tries)):
mut_seq = self._mutate(seq, plddt, logits=seq_logits)
aux = self.predict(mut_seq, return_aux=True, model_nums=model_nums,
verbose=False, **kwargs)
mut_seq = self._mutate(seq=seq, plddt=plddt,
logits=seq_logits + self._inputs["bias"])
aux = self.predict(seq=mut_seq, return_aux=True, model_nums=model_nums, verbose=False, **kwargs)
buff.append({"aux":aux, "seq":np.array(mut_seq)})

# accept best
losses = [x["aux"]["loss"] for x in buff]
best = buff[np.argmin(losses)]
self.aux, seq = best["aux"], jnp.array(best["seq"])
self.set_seq(seq=seq)
self.set_seq(seq=seq, bias=self._inputs["bias"])
self._save_results(save_best=save_best, verbose=verbose)

# update plddt
@@ -506,7 +505,8 @@ def _design_mcmc(self, steps=1000, half_life=200, T_init=0.01, mutation_rate=1,

# initialize
plddt, best_loss, current_loss = None, np.inf, np.inf
current_seq = self._params["seq"].argmax(-1)
current_seq = (self._params["seq"] + self._inputs["bias"]).argmax(-1)
if seq_logits is None: seq_logits = 0

# run!
if verbose: print("Running MCMC with simulated annealing...")
@@ -516,12 +516,16 @@
T = T_init * (np.exp(np.log(0.5) / half_life) ** i)

# mutate sequence
if i == 0: mut_seq = current_seq
else: mut_seq = self._mutate(current_seq, plddt, seq_logits, mutation_rate)
if i == 0:
mut_seq = current_seq
else:
mut_seq = self._mutate(seq=current_seq, plddt=plddt,
logits=seq_logits + self._inputs["bias"],
mutation_rate=mutation_rate)

# get loss
model_nums = self._get_model_nums(**model_flags)
aux = self.predict(mut_seq, return_aux=True, verbose=False, model_nums=model_nums, **kwargs)
aux = self.predict(seq=mut_seq, return_aux=True, verbose=False, model_nums=model_nums, **kwargs)
loss = aux["log"]["loss"]

# decide
Expand All @@ -536,5 +540,5 @@ def _design_mcmc(self, steps=1000, half_life=200, T_init=0.01, mutation_rate=1,

if loss < best_loss:
(best_loss, self._k) = (loss, i)
self.set_seq(seq=current_seq)
self._save_results(save_best=save_best, verbose=verbose)
self.set_seq(seq=current_seq, bias=self._inputs["bias"])
self._save_results(save_best=save_best, verbose=verbose)
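Net effect of the design.py changes: a per-position bias now reaches predict(), the semigreedy sampler, and the MCMC sampler (it is added to the mutation logits rather than silently ignored). A rough usage sketch; the model setup and the input file "target.pdb" are placeholders, not part of this commit:

import numpy as np
from colabdesign import mk_afdesign_model

af_model = mk_afdesign_model(protocol="fixbb")
af_model.prep_inputs(pdb_filename="target.pdb", chain="A")

# mild bias toward alanine (assumed to be index 0 of the 20-letter alphabet) at every position
bias = np.zeros((af_model._len, 20))
bias[:, 0] = 2.0
af_model.set_seq(bias=bias)

# the bias is folded into the sampling logits inside design_semigreedy / _design_mcmc
af_model.design_semigreedy(iters=50, tries=10)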
2 changes: 1 addition & 1 deletion colabdesign/af/inputs.py
@@ -12,7 +12,7 @@
############################################################################
class _af_inputs:

def _get_seq(self, inputs, aux, key):
def _get_seq(self, inputs, aux, key=None):
params, opt = inputs["params"], inputs["opt"]
'''get sequence features'''
seq = soft_seq(params["seq"], inputs["bias"], opt, key)
20 changes: 10 additions & 10 deletions colabdesign/af/model.py
@@ -26,9 +26,7 @@ def __init__(self, protocol="fixbb", num_seq=1,
recycle_mode="last", num_recycles=0,
use_templates=False, best_metric="loss",
model_names=None, optimizer="sgd", learning_rate=0.1,
use_openfold=False, use_alphafold=True,
use_multimer=False, use_mlm=False,
use_dgram=False,
use_openfold=False, use_alphafold=True, use_multimer=False,
pre_callback=None, post_callback=None,
pre_design_callback=None, post_design_callback=None,
loss_callback=None, traj_iter=1, traj_max=10000, debug=False, data_dir="."):
@@ -42,11 +40,11 @@ def __init__(self, protocol="fixbb", num_seq=1,
self.protocol = protocol
self._num = num_seq
self._args = {"use_templates":use_templates, "use_multimer":use_multimer,
"recycle_mode":recycle_mode, "use_mlm": use_mlm, "realign": True,
"recycle_mode":recycle_mode, "use_mlm": False, "realign": True,
"debug":debug, "repeat":False, "homooligomer":False, "copies":1,
"optimizer":optimizer, "best_metric":best_metric,
"traj_iter":traj_iter, "traj_max":traj_max,
"clear_prev": True, "use_dgram":use_dgram}
"clear_prev": True, "use_dgram":False, "shuffle_msa":True}

self.opt = {"dropout":True, "use_pssm":False, "learning_rate":learning_rate, "norm_seq_grad":True,
"num_recycles":num_recycles, "num_models":num_models, "sample_models":sample_models,
@@ -55,7 +53,7 @@
"i_con": {"num":1, "cutoff":21.6875, "binary":False, "num_pos":float("inf")},
"template": {"dropout":0.0, "rm_ic":False},
"weights": {"seq_ent":0.0, "plddt":0.0, "pae":0.0, "exp_res":0.0, "helix":0.0},
"cmap_cutoff": 10.0, "fape_cutoff":10.0}
"fape_cutoff":10.0}

if self._args["use_mlm"]:
self.opt["mlm_dropout"] = 0.0
Expand Down Expand Up @@ -86,7 +84,7 @@ def __init__(self, protocol="fixbb", num_seq=1,
if recycle_mode in ["average","first","last","sample"]: num_recycles = 0
cfg.model.num_recycle = num_recycles
cfg.model.global_config.use_remat = True
cfg.model.global_config.use_dgram = use_dgram
cfg.model.global_config.use_dgram = self._args["use_dgram"]

# setup model
self._cfg = cfg
@@ -143,7 +141,8 @@ def _model(params, model_params, inputs, key):
L = inputs["aatype"].shape[0]

# get sequence
seq = self._get_seq(inputs, aux, key())
seq_key = key() if a["shuffle_msa"] else None
seq = self._get_seq(inputs, aux, seq_key)

# update sequence features
pssm = jnp.where(opt["use_pssm"], seq["pssm"], seq["pseudo"])
@@ -159,9 +158,9 @@
inputs["seq"] = aux["seq"]

# update template features
inputs["mask_template_interchain"] = opt["template"]["rm_ic"]
if a["use_templates"]:
self._update_template(inputs, key())
inputs["mask_template_interchain"] = opt["template"]["rm_ic"]

# set dropout
inputs["dropout_scale"] = jnp.array(opt["dropout"], dtype=float)
@@ -191,7 +190,8 @@ def _model(params, model_params, inputs, key):
"pae": get_pae(outputs),
"ptm": get_ptm(inputs, outputs),
"i_ptm": get_ptm(inputs, outputs, interface=True),
"cmap": get_contact_map(outputs, opt["cmap_cutoff"]),
"cmap": get_contact_map(outputs, opt["con"]["cutoff"]),
"i_cmap": get_contact_map(outputs, opt["i_con"]["cutoff"]),
"prev": outputs["prev"]})

#######################################################################
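With cmap_cutoff removed, the contact maps returned in aux are thresholded by the same cutoffs the contact losses already use, and an inter-chain map is returned alongside the intra-chain one. A sketch of reading them back for an already-prepared af_model, assuming the usual set_opt pattern:

af_model.set_opt("con", cutoff=14.0)       # intra-chain cutoff, also used for aux["cmap"]
af_model.set_opt("i_con", cutoff=21.6875)  # inter-chain cutoff, used for aux["i_cmap"]

aux = af_model.predict(return_aux=True, verbose=False)
cmap, i_cmap = aux["cmap"], aux["i_cmap"]  # (L, L) contact probabilities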
22 changes: 18 additions & 4 deletions colabdesign/mpnn/model.py
@@ -102,8 +102,12 @@ def get_af_inputs(self, af):

self._inputs["residue_idx"] = af._inputs["residue_index"]
self._inputs["chain_idx"] = af._inputs["asym_id"]
self._inputs["bias"] = af._inputs["bias"]
self._inputs["lengths"] = np.array(self._lengths)

# set bias
L = sum(self._lengths)
self._inputs["bias"] = np.zeros((L,20))
self._inputs["bias"][-af._len:] = af._inputs["bias"]

if "offset" in af._inputs:
self._inputs["offset"] = af._inputs["offset"]
@@ -115,10 +119,17 @@ def get_af_inputs(self, af):
self._inputs["mask"] = batch["all_atom_mask"][:,1]
self._inputs["S"] = batch["aatype"]

if "fix_pos" in af.opt:
self._inputs["fix_pos"] = p = af.opt["fix_pos"]
# fix positions
if af.protocol == "binder":
p = np.arange(af._target_len)
else:
p = af.opt.get("fix_pos",None)

if p is not None:
self._inputs["fix_pos"] = p
self._inputs["bias"][p] = 1e7 * np.eye(21)[self._inputs["S"]][p,:20]

# tie positions
if af._args["homooligomer"]:
assert min(self._lengths) == max(self._lengths)
self._tied_lengths = True
@@ -190,9 +201,12 @@ def score(self, seq=None, **kwargs):
'''score sequence'''
I = copy_dict(self._inputs)
if seq is not None:
p = np.arange(I["S"].shape[0])
if self._tied_lengths and len(seq) == self._lengths[0]:
seq = seq * len(self._lengths)
I["S"] = np.array([aa_order.get(aa,-1) for aa in seq])
if "fix_pos" in I and len(seq) == (I["S"].shape[0] - I["fix_pos"].shape[0]):
p = np.delete(p,I["fix_pos"])
I["S"][p] = np.array([aa_order.get(aa,-1) for aa in seq])
I.update(kwargs)
key = I.pop("key",self.key())
O = jax.tree_map(np.array, self._score(**I, key=key))
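The fix_pos handling pins selected positions by giving the native residue an overwhelmingly large bias, so the sampler effectively cannot choose anything else there. A toy version of the indexing trick, detached from the class:

import numpy as np

L = 5
S = np.array([3, 0, 7, 7, 2])   # native sequence as integer aatypes (21-letter alphabet)
bias = np.zeros((L, 20))
fix_pos = np.array([0, 2])      # positions to keep fixed

# one-hot over 21 classes, truncated to the 20 design letters, scaled to dominate the logits
bias[fix_pos] = 1e7 * np.eye(21)[S][fix_pos, :20]
print(bias[0].argmax())  # 3 -> the native residue at position 0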
9 changes: 7 additions & 2 deletions colabdesign/shared/protein.py
@@ -34,8 +34,13 @@ def pdb_to_string(pdb_file, chains=None, models=None):
lines = []
seen = []
model = 1
for line in open(pdb_file,"rb"):
line = line.decode("utf-8","ignore").rstrip()

if "\n" in pdb_file:
old_lines = pdb_file.split("\n")
else:
with open(pdb_file,"rb") as f:
old_lines = [line.decode("utf-8","ignore").rstrip() for line in f]
for line in old_lines:
if line[:5] == "MODEL":
model = int(line[5:])
if models is None or model in models:
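After this change pdb_to_string accepts either a path or the PDB text itself, keyed on whether the argument contains a newline. A brief sketch of both call styles ("target.pdb" is a placeholder):

from colabdesign.shared.protein import pdb_to_string

# pass a filename on disk...
pdb_str = pdb_to_string("target.pdb", chains=["A"])

# ...or the file contents already read into memory
with open("target.pdb") as f:
    raw = f.read()
pdb_str = pdb_to_string(raw, chains=["A"])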
