Skip to content

Commit

Permalink
adding "seq" to inputs (for custom loss)
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton authored Aug 19, 2022
1 parent 2a508c5 commit 8632eb7
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions colabdesign/af/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def _model(params, model_params, inputs, key, opt):
batch.update(all_atom.atom37_to_frames(**batch))
else:
batch = None

inputs["batch"] = batch

#######################################################################
# OUTPUTS
#######################################################################
Expand All @@ -174,20 +177,22 @@ def _model(params, model_params, inputs, key, opt):
aux["pae"] = jnp.full((L,L),jnp.nan).at[p[:,None],p[None,:]].set(aux["pae"])

if self._args["recycle_mode"] == "average": aux["prev"] = outputs["prev"]

#######################################################################
# LOSS
#######################################################################
inputs["batch"] = batch
if self._args["debug"]: aux["debug"] = {"inputs":inputs, "outputs":outputs, "opt":opt}

aux["losses"] = {}
self._get_loss(inputs=inputs, outputs=outputs, opt=opt, aux=aux)

inputs["seq"] = aux["seq"]
if self._loss_callback is not None:
loss_fns = self._loss_callback if isinstance(self._loss_callback,list) else [self._loss_callback]
for loss_fn in loss_fns:
aux["losses"].update(loss_fn(inputs, outputs, opt))

if self._args["debug"]:
aux["debug"] = {"inputs":inputs, "outputs":outputs, "opt":opt}

# weighted loss
w = opt["weights"]
Expand Down

0 comments on commit 8632eb7

Please sign in to comment.