From a69bb634b97328bc68396e0f4c1ee72c829e9928 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Thu, 29 Feb 2024 08:49:51 -0500 Subject: [PATCH] hotfix --- README.md | 3 +-- cortex/acquisition/_graph_nei.py | 5 ++--- tutorials/4_guided_diffusion.ipynb | 9 +++++++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index c5792a9..0004be2 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,7 @@ Rather than tack on auxiliary abstractions to a single input --> single task mod ```bash conda create --name cortex-env python=3.10 -y && conda activate cortex-env -python -m pip install -r requirements.in -pip install -e . +python -m pip install pytorch-cortex ``` diff --git a/cortex/acquisition/_graph_nei.py b/cortex/acquisition/_graph_nei.py index 52a31ed..def5b64 100644 --- a/cortex/acquisition/_graph_nei.py +++ b/cortex/acquisition/_graph_nei.py @@ -209,8 +209,7 @@ def __init__( self.has_pointwise_reference = False def get_objective_vals(self, tree_output: NeuralTreeOutput): - if isinstance(tree_output, NeuralTreeOutput): - tree_output_dict = tree_output_to_dict(tree_output, self.objectives, self.constraints, self.scaling) + tree_output_dict = tree_output_to_dict(tree_output, self.objectives, self.constraints, self.scaling) return get_joint_objective_values( tree_output_dict, self.objectives, @@ -219,7 +218,7 @@ def get_objective_vals(self, tree_output: NeuralTreeOutput): ) def __call__(self, input: NeuralTreeOutput | torch.Tensor, pointwise=True): - if isinstance(input, NeuralTreeOutput): + if not torch.is_tensor(input): obj_val_samples = self.get_objective_vals(input) else: diff --git a/tutorials/4_guided_diffusion.ipynb b/tutorials/4_guided_diffusion.ipynb index bc65484..c57a582 100644 --- a/tutorials/4_guided_diffusion.ipynb +++ b/tutorials/4_guided_diffusion.ipynb @@ -172,7 +172,9 @@ "with torch.inference_mode():\n", " tree_output = model.call_from_str_array(initial_solution, corrupt_frac=0.0)\n", " init_obj_vals = acq_fn.get_objective_vals(tree_output)\n", - "init_obj_vals" + " init_acq_vals = acq_fn(tree_output)\n", + "print(init_acq_vals)\n", + "print(init_obj_vals)" ] }, { @@ -238,8 +240,11 @@ "sns.set_theme(style=\"whitegrid\", font_scale=1.75)\n", "\n", "plt.plot(med_obj_val)\n", + "xlim = plt.xlim()\n", + "plt.hlines(init_acq_vals.median(), *xlim, label=\"Initial Value\", color=\"black\", linestyle=\"--\")\n", "plt.xlabel(\"Diffusion Iteration\")\n", - "plt.ylabel(\"Median Acq. Value\")" + "plt.ylabel(\"Median Acq. Value\")\n", + "plt.legend(loc=\"center right\")" ] }, {