Skip to content

Commit

Permalink
Created using Colaboratory
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton committed Oct 5, 2022
1 parent 4e96536 commit 694c00b
Showing 1 changed file with 112 additions and 71 deletions.
183 changes: 112 additions & 71 deletions mpnn/examples/afdesign_and_proteinmpnn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "-AXy0s_4cKaK"
},
"outputs": [],
Expand Down Expand Up @@ -84,66 +83,83 @@
" os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n",
" return f\"{pdb_code}.pdb\"\n",
"\n",
"def setup_mpnn(self, precompute=True, entropy=False, backprop=False):\n",
" mpnn_atom_idx = tuple(residue_constants.atom_order[k] for k in [\"N\",\"CA\",\"C\",\"O\"])\n",
" \n",
" def loss_callback(inputs, aux, opt, key):\n",
"class setup_mpnn:\n",
" def __init__(self, af_model, precompute=True, conditional=False, replace=0.01):\n",
" self.af = af_model\n",
" self.mpnn = mk_mpnn_model()\n",
" self.atom_idx = tuple(residue_constants.atom_order[k] for k in [\"N\",\"CA\",\"C\",\"O\"])\n",
" self.replace = replace\n",
" self.conditional = conditional\n",
"\n",
" if precompute:\n",
" logits = opt[\"mpnn\"]\n",
" \n",
" self.precompute()\n",
" else:\n",
" I = {\"X\": aux[\"atom_positions\"][None,:,mpnn_atom_idx],\n",
" \"mask\": aux[\"atom_mask\"][None,:,1],\n",
" \"residue_idx\": inputs[\"residue_index\"][None],\n",
" \"chain_idx\": inputs[\"asym_id\"][None],\n",
" \"key\": key} \n",
" \n",
" if self.protocol == \"binder\":\n",
" L = self._target_len\n",
" logits = mk_mpnn_model().get_logits(**I)[0,L:]\n",
" else:\n",
" L = self._params[\"seq\"].shape[1]\n",
" logits = mk_mpnn_model().get_logits(**I)[0,:L]\n",
" self.af._callbacks[\"design\"][\"post\"].append(self._design_callback)\n",
"\n",
" logits = aux[\"mpnn\"] = logits if backprop else jax.lax.stop_gradient(logits)\n",
" self.af._callbacks[\"model\"][\"loss\"].append(self._loss_callback)\n",
" self.af.opt[\"weights\"][\"mpnn_loss\"] = 1.0\n",
" self.af.opt[\"weights\"][\"mpnn_ent\"] = 0.0\n",
"\n",
" # define loss function\n",
" log_q = jax.nn.log_softmax(logits)\n",
" if entropy:\n",
"      # minimize entropy of mpnn output (aka increase confidence of mpnn)\n",
" q = jax.nn.softmax(logits)\n",
" mpnn_loss = -(q * log_q).sum(-1).mean()\n",
" else:\n",
"      # maximize similarity to mpnn output (minimize cross-entropy with input sequence)\n",
" p = inputs[\"seq\"][\"soft\"]\n",
" mpnn_loss = -(p * log_q).sum(-1).mean()\n",
" def run(self, seq, atom_positions, atom_mask, residue_index, asym_id, key):\n",
" # INPUTS\n",
" I = {\"X\": atom_positions[None,:,self.atom_idx],\n",
" \"mask\": atom_mask[None,:,1],\n",
" \"residue_idx\": residue_index[None],\n",
" \"chain_idx\": asym_id[None],\n",
" \"key\": key}\n",
" if self.conditional:\n",
" I[\"S\"] = seq[:1]\n",
" I[\"ar_mask\"] = 1 - np.eye(I[\"S\"].shape[1])[None]\n",
" if self.af.protocol == \"binder\":\n",
" L = self.af._target_len\n",
" I[\"ar_mask\"][:,L:,L:] = 0\n",
"\n",
" return {\"mpnn_loss\":mpnn_loss} \n",
" \n",
" if precompute:\n",
" inputs = self._inputs\n",
" I = {\"X\": inputs[\"batch\"][\"all_atom_positions\"][None,:,mpnn_atom_idx],\n",
" \"mask\": inputs[\"batch\"][\"all_atom_mask\"][None,:,1],\n",
" \"residue_idx\": inputs[\"residue_index\"][None],\n",
" \"chain_idx\": inputs[\"asym_id\"][None],\n",
" \"key\": self.key()}\n",
" if self.protocol == \"binder\":\n",
" L = self._target_len\n",
" logits = mk_mpnn_model().get_logits(**I)[0,L:]\n",
" # RUN \n",
" logits = self.mpnn.get_logits(**I)[0]\n",
" \n",
" # OUTPUTS\n",
" if self.af.protocol == \"binder\":\n",
" L = self.af._target_len\n",
" logits = logits[L:]\n",
" else:\n",
" L = self._params[\"seq\"].shape[1]\n",
" logits = mk_mpnn_model().get_logits(**I)[0,:L]\n",
" L = self.af._params[\"seq\"].shape[1]\n",
" logits = logits[:L]\n",
" return logits\n",
"\n",
" logits = np.asarray(logits)\n",
" self.opt[\"mpnn\"] = logits\n",
" def precompute(self):\n",
" inputs = self.af._inputs\n",
" logits = self.run(inputs[\"batch\"][\"aatype\"],\n",
" inputs[\"batch\"][\"all_atom_positions\"],\n",
" inputs[\"batch\"][\"all_atom_mask\"],\n",
" inputs[\"residue_index\"],\n",
" inputs[\"asym_id\"],\n",
" self.af.key())\n",
" self.af.opt[\"mpnn\"] = self.logits = logits\n",
" \n",
" def _design_callback(self, af_model):\n",
" self.logits = af_model.aux[\"mpnn\"]\n",
" af_model._inputs[\"bias\"] = (1-self.replace) * af_model._inputs[\"bias\"] + self.replace * af_model.aux[\"mpnn\"]\n",
"\n",
" else:\n",
" def design_callback(self):\n",
" self._inputs[\"bias\"] = 0.99 * self._inputs[\"bias\"] + 0.01 * self.aux[\"mpnn\"]\n",
" self._callbacks[\"design\"][\"post\"].append(design_callback)\n",
" def _loss_callback(self, inputs, aux, opt, seq, key):\n",
" if \"mpnn\" in opt:\n",
" logits = opt[\"mpnn\"]\n",
" else:\n",
" logits = self.run(seq[\"hard\"],\n",
" aux[\"atom_positions\"],\n",
" aux[\"atom_mask\"],\n",
" inputs[\"residue_index\"],\n",
" inputs[\"asym_id\"],\n",
" key) \n",
" aux[\"mpnn\"] = logits\n",
"\n",
" self._callbacks[\"model\"][\"loss\"].append(loss_callback)\n",
" self.opt[\"weights\"][\"mpnn_loss\"] = 1.0"
" # define loss function\n",
" log_q = jax.nn.log_softmax(logits)\n",
" q = jax.nn.softmax(logits)\n",
" p = inputs[\"seq\"][\"soft\"]\n",
" losses = {}\n",
" losses[\"mpnn_ent\"] = -(q * log_q).sum(-1).mean()\n",
" losses[\"mpnn_loss\"] = -(p * log_q).sum(-1).mean()\n",
" return losses"
]
},
{
Expand All @@ -165,10 +181,9 @@
"outputs": [],
"source": [
"clear_mem()\n",
"mpnn_model = mk_mpnn_model()\n",
"af_model = mk_af_model(protocol=\"fixbb\")\n",
"af_model.prep_inputs(pdb_filename=get_pdb(\"1TEN\"), chain=\"A\")\n",
"setup_mpnn(af_model, precompute=True)\n",
"mpnn_model = setup_mpnn(af_model, precompute=True)\n",
"\n",
"print(\"length\", af_model._len)\n",
"print(\"weights\", af_model.opt[\"weights\"])"
Expand All @@ -178,8 +193,8 @@
"cell_type": "code",
"source": [
"# precompute unconditional probabilities from mpnn\n",
"print(\"max_mpnn_loss\",-np.log(softmax(af_model.opt[\"mpnn\"],-1)).max(-1).mean())\n",
"plt.imshow(softmax(af_model.opt[\"mpnn\"],-1).T,vmin=0,vmax=1)"
"print(\"max_mpnn_loss\",-np.log(softmax(mpnn_model.logits,-1)).max(-1).mean())\n",
"plt.imshow(softmax(mpnn_model.logits,-1).T,vmin=0,vmax=1)"
],
"metadata": {
"id": "dm4AJIrU2VGD"
Expand All @@ -191,7 +206,7 @@
"cell_type": "code",
"source": [
"af_model.restart()\n",
"af_model.set_seq(bias=af_model.opt[\"mpnn\"])\n",
"af_model.set_seq(bias=mpnn_model.logits)\n",
"af_model.set_weights(mpnn_loss=0.1)\n",
"af_model.design_3stage(0,200,10)"
],
Expand Down Expand Up @@ -256,6 +271,24 @@
"id": "qLwS2s_xcjRI"
}
},
{
"cell_type": "code",
"source": [
"def rg_loss(inputs, outputs):\n",
" positions = outputs[\"structure_module\"][\"final_atom_positions\"]\n",
" ca = positions[:,residue_constants.atom_order[\"CA\"]]\n",
" center = ca.mean(0)\n",
" rg = jnp.sqrt(jnp.square(ca - center).sum(-1).mean() + 1e-8)\n",
" rg_th = 2.38 * ca.shape[0] ** 0.365\n",
" rg = jax.nn.elu(rg - rg_th)\n",
" return {\"rg\":rg}"
],
"metadata": {
"id": "pGVwRYqHqdb5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -265,9 +298,12 @@
"outputs": [],
"source": [
"clear_mem()\n",
"af_model = mk_af_model(protocol=\"hallucination\")\n",
"af_model = mk_af_model(protocol=\"hallucination\",\n",
" loss_callback=rg_loss) # add custom Radius of Gyration loss\n",
"af_model.prep_inputs(length=100)\n",
"setup_mpnn(af_model, precompute=False)\n",
"af_model.opt[\"weights\"][\"rg\"] = 0.1\n",
"mpnn_model = setup_mpnn(af_model, precompute=False)\n",
"mpnn_model.replace = 0.01 # rate at which to copy output mpnn logits to alphafold bias\n",
"\n",
"print(\"length\",af_model._len)\n",
"print(\"weights\",af_model.opt[\"weights\"])"
Expand All @@ -279,19 +315,22 @@
"# pre-design with gumbel initialization and softmax activation\n",
"af_model.restart()\n",
"af_model.set_seq(mode=\"gumbel\")\n",
"af_model.set_weights(mpnn_loss=0.01)\n",
"af_model.design_soft(100)"
"af_model.set_weights(mpnn_ent=0.1, # maximize confidence of mpnn output\n",
" mpnn_loss=0.01, # minimize difference between mpnn output and input sequence\n",
" helix=-0.1, # encourage non-helical content\n",
" ) \n",
"af_model.design_soft(100, verbose=10)"
],
"metadata": {
"id": "BjfCnSkFdtud"
"id": "Snim_g3ydgBX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# lets see what the PDB looks like (if you don't like rerun the cell before)\n",
"# let's see what the PDB looks like (if you don't like it, rerun the previous cell)\n",
"af_model.plot_pdb()"
],
"metadata": {
Expand All @@ -303,11 +342,13 @@
{
"cell_type": "code",
"source": [
"af_model.set_weights(mpnn_loss=0.1)\n",
"af_model.design_3stage(0, 100, 10)"
"# refinement round!\n",
"af_model.set_seq(seq=af_model.aux[\"seq\"][\"pseudo\"])\n",
"af_model.set_weights(mpnn_ent=1.0, mpnn_loss=1.0, helix=0, pae=0.1) # increase mpnn weights\n",
"af_model.design_3stage(100, 100, 10)"
],
"metadata": {
"id": "_WcwTwYc1lsB"
"id": "iKOcI_lxecYb"
},
"execution_count": null,
"outputs": []
Expand All @@ -316,10 +357,10 @@
"cell_type": "code",
"source": [
"af_model.save_pdb(f\"{af_model.protocol}.pdb\")\n",
"af_model.plot_pdb()"
"af_model.plot_pdb(color=\"pLDDT\")"
],
"metadata": {
"id": "BkP5Jcqyo9wO"
"id": "qe9c1W6ydgD-"
},
"execution_count": null,
"outputs": []
Expand Down Expand Up @@ -364,7 +405,7 @@
"mpnn_model = mk_mpnn_model()\n",
"af_model = mk_af_model(protocol=\"binder\")\n",
"af_model.prep_inputs(pdb_filename=get_pdb(\"4MZK\"), chain=\"A\", binder_len=18)\n",
"setup_mpnn(af_model, precompute=False)\n",
"setup_mpnn(af_model, precompute=False, conditional=False)\n",
"\n",
"print(\"target_length\",af_model._target_len)\n",
"print(\"binder_length\",af_model._binder_len)\n",
Expand All @@ -380,9 +421,9 @@
"cell_type": "code",
"source": [
"af_model.restart()\n",
"af_model.set_weights(mpnn_loss=0.01)\n",
"af_model.set_weights(mpnn_loss=0.01, mpnn_ent=0.01)\n",
"af_model.design_3stage(100,0,0)\n",
"af_model.set_weights(mpnn_loss=0.1)\n",
"af_model.set_weights(mpnn_loss=0.1, mpnn_ent=0.1)\n",
"af_model.design_3stage(0,100,10)"
],
"metadata": {
Expand Down

0 comments on commit 694c00b

Please sign in to comment.