Skip to content

Commit

Permalink
hotfix
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelstanton committed Feb 29, 2024
1 parent eb93797 commit a69bb63
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ Rather than tack on auxiliary abstractions to a single input --> single task mod

```bash
conda create --name cortex-env python=3.10 -y && conda activate cortex-env
python -m pip install -r requirements.in
pip install -e .
python -m pip install pytorch-cortex
```


Expand Down
5 changes: 2 additions & 3 deletions cortex/acquisition/_graph_nei.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ def __init__(
self.has_pointwise_reference = False

def get_objective_vals(self, tree_output: NeuralTreeOutput):
if isinstance(tree_output, NeuralTreeOutput):
tree_output_dict = tree_output_to_dict(tree_output, self.objectives, self.constraints, self.scaling)
tree_output_dict = tree_output_to_dict(tree_output, self.objectives, self.constraints, self.scaling)
return get_joint_objective_values(
tree_output_dict,
self.objectives,
Expand All @@ -219,7 +218,7 @@ def get_objective_vals(self, tree_output: NeuralTreeOutput):
)

def __call__(self, input: NeuralTreeOutput | torch.Tensor, pointwise=True):
if isinstance(input, NeuralTreeOutput):
if not torch.is_tensor(input):
obj_val_samples = self.get_objective_vals(input)

else:
Expand Down
9 changes: 7 additions & 2 deletions tutorials/4_guided_diffusion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@
"with torch.inference_mode():\n",
" tree_output = model.call_from_str_array(initial_solution, corrupt_frac=0.0)\n",
" init_obj_vals = acq_fn.get_objective_vals(tree_output)\n",
"init_obj_vals"
" init_acq_vals = acq_fn(tree_output)\n",
"print(init_acq_vals)\n",
"print(init_obj_vals)"
]
},
{
Expand Down Expand Up @@ -238,8 +240,11 @@
"sns.set_theme(style=\"whitegrid\", font_scale=1.75)\n",
"\n",
"plt.plot(med_obj_val)\n",
"xlim = plt.xlim()\n",
"plt.hlines(init_acq_vals.median(), *xlim, label=\"Initial Value\", color=\"black\", linestyle=\"--\")\n",
"plt.xlabel(\"Diffusion Iteration\")\n",
"plt.ylabel(\"Median Acq. Value\")"
"plt.ylabel(\"Median Acq. Value\")\n",
"plt.legend(loc=\"center right\")"
]
},
{
Expand Down

0 comments on commit a69bb63

Please sign in to comment.