From 3e5c8b62de621c4630f4335e7b25c9acde7e35c8 Mon Sep 17 00:00:00 2001
From: Samuel Stanton
Date: Tue, 13 Feb 2024 13:12:07 -0500
Subject: [PATCH 1/9] add functional inference API

---
 cortex/model/__init__.py          |  2 +
 cortex/model/_infer_with_model.py | 70 +++++++++++++++++++++++++++++++
 2 files changed, 72 insertions(+)
 create mode 100644 cortex/model/_infer_with_model.py

diff --git a/cortex/model/__init__.py b/cortex/model/__init__.py
index 2395843..9d48108 100644
--- a/cortex/model/__init__.py
+++ b/cortex/model/__init__.py
@@ -1,5 +1,7 @@
+from ._infer_with_model import infer_with_model
 from ._weight_averaging import online_weight_update_
 
 __all__ = [
+    "infer_with_model",
     "online_weight_update_",
 ]
diff --git a/cortex/model/_infer_with_model.py b/cortex/model/_infer_with_model.py
new file mode 100644
index 0000000..e6bf42e
--- /dev/null
+++ b/cortex/model/_infer_with_model.py
@@ -0,0 +1,70 @@
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+import torch
+
+from cortex.io import load_hydra_config, load_model_checkpoint
+
+
+def infer_with_model(
+    data: pd.DataFrame,
+    cfg_fpath: str,
+    weight_fpath: str,
+    batch_limit: int = 32,
+    cpu_offload: bool = True,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> dict[str, np.ndarray]:
+    """
+    A functional interface for inference with a cortex model.
+
+    Usage:
+
+    ```python title="Example of inference with a cortex model checkpoint."
+    from cortex.model import infer_with_model
+
+    ckpt_dir = "path/to/checkpoint_dir"  # placeholder, replace with your checkpoint location
+    ckpt_name = "model_checkpoint"  # placeholder, replace with your checkpoint name
+    predictions = infer_with_model(
+        data=df,
+        cfg_fpath=f"{ckpt_dir}/{ckpt_name}.yaml",
+        weight_fpath=f"{ckpt_dir}/{ckpt_name}.pt",
+    )
+    ```
+
+    Args:
+        data (pd.DataFrame): A dataframe containing the sequences to predict on.
+        cfg_fpath (str): The local or S3 path to the Hydra config file.
+        weight_fpath (str): The local or S3 path to the PyTorch model weights.
+        batch_limit (int, optional): The maximum number of sequences to predict on at once. Defaults to 32.
+        cpu_offload (bool, optional): Whether to run prediction with CPU offload.
+            Defaults to True.
+        device (torch.device, optional): The device to run the model on. Defaults to None.
+        dtype (torch.dtype, optional): The dtype to run the model on. Defaults to None.
+
+    Returns:
+        dict[str, np.ndarray]: A dict of NumPy arrays of the predictions.
+ """ + # set default device and dtype + if device is None: + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + if dtype is None and torch.cuda.is_available(): + dtype = torch.float32 + elif dtype is None: + dtype = torch.float64 + + # load Hydra config from s3 or locally + cfg = load_hydra_config(cfg_fpath) + + # load model checkpoint from s3 or locally + model, _ = load_model_checkpoint(cfg, weight_fpath, device=device, dtype=dtype) + + # model forward pass + with torch.inference_mode(): + return model.predict( + data=data, + batch_limit=batch_limit, + predict_tasks=None, + cpu_offload=cpu_offload, + ) From 9e17e6b71a963da7636dd5c0a2e9c0bb93c620e0 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Tue, 13 Feb 2024 20:44:57 -0500 Subject: [PATCH 2/9] add GraphNEI acquisition fn --- cortex/acquisition/__init__.py | 7 ++ cortex/acquisition/_graph_nei.py | 150 +++++++++++++++++++++++++++++++ cortex/model/tree/__init__.py | 3 +- 3 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 cortex/acquisition/_graph_nei.py diff --git a/cortex/acquisition/__init__.py b/cortex/acquisition/__init__.py index e69de29..449f0aa 100644 --- a/cortex/acquisition/__init__.py +++ b/cortex/acquisition/__init__.py @@ -0,0 +1,7 @@ +from ._graph_nei import GraphNEI, get_graph_nei_runtime_kwargs, get_obj_vals + +__all__ = [ + "get_graph_nei_runtime_kwargs", + "get_obj_vals", + "GraphNEI", +] diff --git a/cortex/acquisition/_graph_nei.py b/cortex/acquisition/_graph_nei.py new file mode 100644 index 0000000..3d1fa93 --- /dev/null +++ b/cortex/acquisition/_graph_nei.py @@ -0,0 +1,150 @@ +import numpy as np +import torch +from botorch.acquisition.logei import qLogExpectedImprovement +from botorch.acquisition.multi_objective.logei import \ + qLogExpectedHypervolumeImprovement +from botorch.acquisition.objective import IdentityMCObjective +from botorch.utils.multi_objective.box_decompositions import \ + FastNondominatedPartitioning +from botorch.utils.multi_objective.hypervolume import infer_reference_point +from botorch.utils.multi_objective.pareto import is_non_dominated + +from cortex.model.tree import NeuralTree, NeuralTreeOutput, fetch_task_outputs + +GRAPH_OBJECTIVES = ["stability", "log_fluorescence"] +GRAPH_CONSTRAINTS = {} +# rescale stability and log_fluorescence to [0, 1] +GRAPH_OBJ_TRANSFORM = { + "stability": {"scale": 1 / 2.0, "shift": 2.0}, + "log_fluorescence": {"scale": 1 / 7.0, "shift": -4.0}, +} + + +def get_obj_vals( + tree_output: NeuralTreeOutput, + objectives: list[str], + constraints: dict[str, list[str]], + scaling: dict[str, dict[str, float]], +): + res = {} + obj_vals = [] + for obj in objectives: + obj_outputs = fetch_task_outputs(tree_output, obj) + pred_means = obj_outputs["loc"].squeeze(-1) + scaled_pred_means = (pred_means + scaling[obj]["shift"]) * scaling[obj]["scale"] + obj_vals.append(scaled_pred_means) + res[obj] = pred_means + res[f"{obj}_scaled"] = scaled_pred_means + obj_vals = torch.stack(obj_vals, dim=-1) # (num_samples, num_objectives) + + if len(constraints) == 0: + constraint_vals = torch.ones_like(obj_vals) + else: + constraint_vals = [] + for obj in constraints: + _current = [] + for const in constraints[obj]: + const_outputs = fetch_task_outputs(tree_output, const) + satisfied_prob = const_outputs["logits"].softmax(-1)[..., 1] + _current.append(satisfied_prob) + res[const] = satisfied_prob + constraint_vals.append(torch.stack(_current, dim=-1).prod(-1)) + constraint_vals = torch.stack(constraint_vals, dim=-1) # 
(num_samples, num_objectives) + + res["joint"] = obj_vals * constraint_vals + + return res + + +def get_graph_nei_runtime_kwargs( + model: NeuralTree, + candidate_points: np.ndarray, + objectives: list[str] = GRAPH_OBJECTIVES, + constraints: dict[str, list[str]] = GRAPH_CONSTRAINTS, + scaling: dict[str, dict[str, float]] = GRAPH_OBJ_TRANSFORM, +): + with torch.inference_mode(): + tree_output = model.call_from_str_array(candidate_points, corrupt_frac=0.0) + seed_preds = get_obj_vals( + tree_output, + objectives=objectives, + constraints=constraints, + scaling=scaling, + ) + + print("==== constructing acquisition function ====") + f_baseline = seed_preds["joint"] # (num_samples, num_baseline, num_objectives) + f_baseline_flat = f_baseline.reshape(-1, len(objectives)) + f_baseline_non_dom = f_baseline_flat[is_non_dominated(f_baseline_flat)] + print(f_baseline_non_dom) + f_ref = infer_reference_point(f_baseline_non_dom) + print(f"reference point: {f_ref}") + res = { + "f_ref": f_ref, + "f_baseline": f_baseline, + } + return res + + +class GraphNEI(object): + def __init__( + self, + objectives: list, + constraints: dict, + scaling: dict, + f_ref: torch.Tensor, # (num_objectives,) + f_baseline: torch.Tensor, # (num_samples, num_baseline, num_objectives) + ) -> None: + """ + Very simple implementation of PropertyDAG + NEHVI + """ + self.objectives = objectives + self.constraints = constraints + self.scaling = scaling + + f_non_dom = [] + for f in f_baseline: + f_non_dom.append(f[is_non_dominated(f)]) + + self._obj_dim = len(objectives) + if self._obj_dim == 1: + f_best = f_baseline.max(dim=-2).values.squeeze(-1) + self.acq_functions = [ + qLogExpectedImprovement( + model=None, + best_f=f, + objective=IdentityMCObjective(), + ) + for f in f_best + ] + else: + self.acq_functions = [ + qLogExpectedHypervolumeImprovement( + model=None, + ref_point=f_ref, + partitioning=FastNondominatedPartitioning(f_ref, f), + ) + for f in f_non_dom + ] + self.has_pointwise_reference = False + + def get_objective_vals(self, tree_output): + return get_obj_vals(tree_output, self.objectives, self.constraints, self.scaling) + + def __call__(self, tree_output, pointwise=True): + obj_vals = self.get_objective_vals(tree_output)["joint"] + + if pointwise: + obj_vals = obj_vals.unsqueeze(-2) # (num_samples, num_designs, 1, num_objectives) + + # assumes the first dimension of obj_vals corresponds to the qEHVI partitions + if self._obj_dim == 1: + acq_vals = torch.stack( + [fn._sample_forward(vals) for fn, vals in zip(self.acq_functions, obj_vals.squeeze(-1))] + ).squeeze(-1) + else: + acq_vals = torch.stack( + [fn._compute_log_qehvi(vals.unsqueeze(0)) for fn, vals in zip(self.acq_functions, obj_vals)] + ) + + return acq_vals.mean(0) diff --git a/cortex/model/tree/__init__.py b/cortex/model/tree/__init__.py index 553eb27..b998e19 100644 --- a/cortex/model/tree/__init__.py +++ b/cortex/model/tree/__init__.py @@ -1,7 +1,8 @@ -from ._abstract_tree import NeuralTree, NeuralTreeOutput +from ._abstract_tree import NeuralTree, NeuralTreeOutput, fetch_task_outputs from ._seq_model_tree import SequenceModelTree __all__ = [ + "fetch_task_outputs", "NeuralTree", "NeuralTreeOutput", "SequenceModelTree", From 663b578e66d32d3d3a6ad5c17bed840b3bb74f39 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Wed, 14 Feb 2024 16:58:01 -0500 Subject: [PATCH 3/9] fix failing test --- cortex/data/dataset/_tape_fluorescence.py | 10 ---------- tests/cortex/data/dataset/test_rfp_dataset.py | 11 +---------- 2 files changed, 1 insertion(+), 20 deletions(-) 
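
Usage sketch for the acquisition utilities added in PATCH 2/9 (a sketch, not
part of the applied diff): it assumes `model` is a trained `NeuralTree` and
`candidates` is a NumPy string array of seed designs; the `GRAPH_*` defaults
ship with `cortex.acquisition._graph_nei`.

```python
import torch

from cortex.acquisition import GraphNEI, get_graph_nei_runtime_kwargs
from cortex.acquisition._graph_nei import (
    GRAPH_CONSTRAINTS,
    GRAPH_OBJ_TRANSFORM,
    GRAPH_OBJECTIVES,
)

# infer the reference point and baseline objective values from the seed designs
runtime_kwargs = get_graph_nei_runtime_kwargs(model=model, candidate_points=candidates)
acq_fn = GraphNEI(
    objectives=GRAPH_OBJECTIVES,
    constraints=GRAPH_CONSTRAINTS,
    scaling=GRAPH_OBJ_TRANSFORM,
    **runtime_kwargs,
)

# score designs by the ensemble-averaged log-EHVI acquisition value
with torch.inference_mode():
    tree_output = model.call_from_str_array(candidates, corrupt_frac=0.0)
    acq_vals = acq_fn(tree_output)
```
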
diff --git a/cortex/data/dataset/_tape_fluorescence.py b/cortex/data/dataset/_tape_fluorescence.py index b74030d..21b25f8 100644 --- a/cortex/data/dataset/_tape_fluorescence.py +++ b/cortex/data/dataset/_tape_fluorescence.py @@ -49,14 +49,4 @@ def _read_data(self, path: str, dedup: bool, train: bool, random_seed: int, **kw data.loc[:, "log_fluorescence"] = np.array([val[0] for val in data["log_fluorescence"].values]) data = tokenize_gfp_df(data) - # import pdb; pdb.set_trace() - - # if self.columns is None: - # self.columns = list(data.columns) - - # if dedup: - # data.drop_duplicates(inplace=True, ignore_index=True) - - # import pdb; pdb.set_trace() - return data diff --git a/tests/cortex/data/dataset/test_rfp_dataset.py b/tests/cortex/data/dataset/test_rfp_dataset.py index 28d1caa..0865089 100644 --- a/tests/cortex/data/dataset/test_rfp_dataset.py +++ b/tests/cortex/data/dataset/test_rfp_dataset.py @@ -7,16 +7,7 @@ def test_rfp_dataset(): # make temp root dir root = "./temp/" os.makedirs(root, exist_ok=True) - dataset = RedFluorescentProteinDataset( + _ = RedFluorescentProteinDataset( root=root, download=True, ) - - item = dataset[0] - - assert ( - item["tokenized_seq"].replace(" ", "") - == "LSKHGLTKDMTMKYRMEGCVDGHKFVITGHGNGSPFEGKQTINLCVVEGGPLPFSEDILSAVFNRVFTDYPQGMVDFFKNSCPAGYTWQRSLLFEDGAVCTASADITVSVEENCFYHESKFHGVNFPADGPVMKKMTINWEPCCEKIIPVPRQGILKGDVAMYLLLKDGGRYRCQFDTVYKAKTDSKKMPEWHFIQHKLTREDRSDAKNQKWQLAEHSVASRSALA" - ) - assert item["foldx_total_energy"] == -39.8155 - assert item["SASA"] == 11189.00587945787 From e571697096c5c46b1213caaf89deb4574ac65f1a Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Wed, 14 Feb 2024 17:47:07 -0500 Subject: [PATCH 4/9] fix ruff error --- cortex/acquisition/_graph_nei.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cortex/acquisition/_graph_nei.py b/cortex/acquisition/_graph_nei.py index 3d1fa93..706fe67 100644 --- a/cortex/acquisition/_graph_nei.py +++ b/cortex/acquisition/_graph_nei.py @@ -1,11 +1,9 @@ import numpy as np import torch from botorch.acquisition.logei import qLogExpectedImprovement -from botorch.acquisition.multi_objective.logei import \ - qLogExpectedHypervolumeImprovement +from botorch.acquisition.multi_objective.logei import qLogExpectedHypervolumeImprovement from botorch.acquisition.objective import IdentityMCObjective -from botorch.utils.multi_objective.box_decompositions import \ - FastNondominatedPartitioning +from botorch.utils.multi_objective.box_decompositions import FastNondominatedPartitioning from botorch.utils.multi_objective.hypervolume import infer_reference_point from botorch.utils.multi_objective.pareto import is_non_dominated From a431d4a5d1576c8e034227f1b3a56a586a2f0c2c Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Mon, 19 Feb 2024 16:17:51 -0500 Subject: [PATCH 5/9] add occlusion attribution, initial sequence selection --- cortex/acquisition/__init__.py | 4 +- cortex/acquisition/_graph_nei.py | 181 +++++++++++++++++++++++-------- cortex/attribution/__init__.py | 3 + cortex/attribution/_occlusion.py | 20 ++++ cortex/optim/_initialization.py | 34 ++++++ 5 files changed, 194 insertions(+), 48 deletions(-) create mode 100644 cortex/attribution/_occlusion.py create mode 100644 cortex/optim/_initialization.py diff --git a/cortex/acquisition/__init__.py b/cortex/acquisition/__init__.py index 449f0aa..b5078aa 100644 --- a/cortex/acquisition/__init__.py +++ b/cortex/acquisition/__init__.py @@ -1,7 +1,7 @@ -from ._graph_nei import GraphNEI, get_graph_nei_runtime_kwargs, get_obj_vals +from 
._graph_nei import GraphNEI, get_graph_nei_runtime_kwargs, get_joint_objective_values __all__ = [ "get_graph_nei_runtime_kwargs", - "get_obj_vals", + "get_joint_objective_values", "GraphNEI", ] diff --git a/cortex/acquisition/_graph_nei.py b/cortex/acquisition/_graph_nei.py index 706fe67..3eb9f8d 100644 --- a/cortex/acquisition/_graph_nei.py +++ b/cortex/acquisition/_graph_nei.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import torch from botorch.acquisition.logei import qLogExpectedImprovement @@ -6,6 +8,7 @@ from botorch.utils.multi_objective.box_decompositions import FastNondominatedPartitioning from botorch.utils.multi_objective.hypervolume import infer_reference_point from botorch.utils.multi_objective.pareto import is_non_dominated +from torch import Tensor from cortex.model.tree import NeuralTree, NeuralTreeOutput, fetch_task_outputs @@ -18,40 +21,114 @@ } -def get_obj_vals( - tree_output: NeuralTreeOutput, +def get_joint_objective_values( + inputs: dict[str, Tensor], objectives: list[str], - constraints: dict[str, list[str]], - scaling: dict[str, dict[str, float]], -): - res = {} - obj_vals = [] + constraints: Optional[dict[str, list[str]]] = None, + scaling: Optional[dict[str, dict[str, float]]] = None, +) -> Tensor: + """Get joint objective values from predicted properties based on objectives and constraints. + + Parameters + ---------- + inputs : dict[str, Tensor] + dictionary of predicted properties. Each key is a property name and each value is a tensor of shape (ensemble size, batch_size) + objectives : list[str] + list of objective names. Each objective name must be a key in inputs. + constraints : Optional[dict[str, list[str]]], optional + dictionary of constraints. Each key is a constraint name and each value is a list of objective names that are constrained by the constraint. + scaling : Optional[dict[str, dict[str, float]]], optional + dictionary of scaling parameters. Each key is a property name and each value is a dictionary with keys "scale" and "shift". 
+ + Returns + ------- + Tensor + Joint objective values of shape (ensemble size, batch_size, num_objectives) + + """ + + if not all([obj in inputs for obj in objectives]): + raise ValueError(f"Not all objectives {objectives} in predicted_properties {inputs.keys()}") + + objective_values: list[Tensor] = [] + for obj in objectives: - obj_outputs = fetch_task_outputs(tree_output, obj) - pred_means = obj_outputs["loc"].squeeze(-1) - scaled_pred_means = (pred_means + scaling[obj]["shift"]) * scaling[obj]["scale"] - obj_vals.append(scaled_pred_means) - res[obj] = pred_means - res[f"{obj}_scaled"] = scaled_pred_means - obj_vals = torch.stack(obj_vals, dim=-1) # (num_samples, num_objectives) - - if len(constraints) == 0: - constraint_vals = torch.ones_like(obj_vals) - else: - constraint_vals = [] - for obj in constraints: - _current = [] - for const in constraints[obj]: - const_outputs = fetch_task_outputs(tree_output, const) - satisfied_prob = const_outputs["logits"].softmax(-1)[..., 1] - _current.append(satisfied_prob) - res[const] = satisfied_prob - constraint_vals.append(torch.stack(_current, dim=-1).prod(-1)) - constraint_vals = torch.stack(constraint_vals, dim=-1) # (num_samples, num_objectives) - - res["joint"] = obj_vals * constraint_vals + pred_means = inputs[obj] - return res + if scaling is not None and obj in scaling: + pred_means = scale_value(pred_means, shift=scaling[obj]["shift"], scale=scaling[obj]["scale"]) + + objective_values.append(pred_means) + + objective_values = torch.stack(objective_values, dim=-1) + + if constraints is None: + return objective_values + + constraint_values: list[Tensor] = [] + + for obj in objectives: + if obj in constraints: + constraint_list = constraints[obj] + _current = [inputs[const] for const in constraint_list] + constraint_values.append(torch.stack(_current, dim=-1).prod(-1)) + + constraint_values = torch.stack(constraint_values, dim=-1) + + objective_values = objective_values * constraint_values + + return objective_values + + +def scale_value(value: Tensor, *, shift: float, scale: float) -> Tensor: + return (value + shift) * scale + + +def tree_output_to_dict( + tree_output: NeuralTreeOutput, + objectives: list[str], + constraints: Optional[dict[str, list[str]]] = None, + scaling: Optional[dict[str, dict[str, float]]] = None, +) -> dict[str, Tensor]: + """Convert tree output to dictionary of tensors. + + Parameters + ---------- + tree_output : NeuralTreeOutput + Tree output + objectives : list[str] + list of objective names. Each objective adds a key to the output dictionary. + constraints : Optional[dict[str, list[str]]], optional + Optional dictionary of constraints. Each key is added to the output dictionary. + scaling : Optional[dict[str, dict[str, float]]], optional + Optional dictionary of scaling parameters. Must be a subset of objectives and each value is a dictionary with keys "scale" and "shift". + + Returns + ------- + dict[str, Tensor] + dictionary of tensors with keys corresponding to objectives and constraints. 
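+
+    For example, with objectives=["log_fluorescence"] and scaling provided for
+    that objective, the result contains keys "log_fluorescence" and
+    "log_fluorescence_scaled"; each constraint name adds one more key.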
+ """ + + result: dict[str, Tensor] = {} + + for objective in objectives: + result[objective] = fetch_task_outputs(tree_output, objective)["loc"] + + if scaling is not None and objective in scaling: + result[f"{objective}_scaled"] = scale_value( + value=result[objective], + shift=scaling[objective]["shift"], + scale=scaling[objective]["scale"], + ) + + if constraints is not None: + for constraint in constraints: + constraint_values = fetch_task_outputs(tree_output, constraint)["logits"] + constraint_values = constraint_values.softmax(dim=-1)[..., 1] + + result[constraint] = constraint_values + + return result def get_graph_nei_runtime_kwargs( @@ -61,17 +138,18 @@ def get_graph_nei_runtime_kwargs( constraints: dict[str, list[str]] = GRAPH_CONSTRAINTS, scaling: dict[str, dict[str, float]] = GRAPH_OBJ_TRANSFORM, ): + print("==== predicting baseline point objective values ====") with torch.inference_mode(): tree_output = model.call_from_str_array(candidate_points, corrupt_frac=0.0) - seed_preds = get_obj_vals( - tree_output, + + tree_output_dict = tree_output_to_dict(tree_output, objectives=objectives, constraints=constraints, scaling=scaling) + f_baseline = get_joint_objective_values( + input=tree_output_dict, objectives=objectives, constraints=constraints, scaling=scaling, - ) + ) # (num_samples, num_baseline, num_objectives) - print("==== constructing acquisition function ====") - f_baseline = seed_preds["joint"] # (num_samples, num_baseline, num_objectives) f_baseline_flat = f_baseline.reshape(-1, len(objectives)) f_baseline_non_dom = f_baseline_flat[is_non_dominated(f_baseline_flat)] print(f_baseline_non_dom) @@ -87,9 +165,9 @@ def get_graph_nei_runtime_kwargs( class GraphNEI(object): def __init__( self, - objectives: list, - constraints: dict, - scaling: dict, + objectives: list[str], + constraints: dict[str, list[str]], + scaling: dict[str, dict[str, float]], f_ref: torch.Tensor, # (num_objectives,) f_baseline: torch.Tensor, # (num_samples, num_baseline, num_objectives) ) -> None: @@ -126,23 +204,34 @@ def __init__( ] self.has_pointwise_reference = False - def get_objective_vals(self, tree_output): - return get_obj_vals(tree_output, self.objectives, self.constraints, self.scaling) + def get_objective_vals(self, tree_output: NeuralTreeOutput): + if isinstance(tree_output, NeuralTreeOutput): + tree_output_dict = tree_output_to_dict(tree_output, self.objectives, self.constraints, self.scaling) + return get_joint_objective_values( + tree_output_dict, + self.objectives, + self.constraints, + self.scaling, + ) + + def __call__(self, input: NeuralTreeOutput | torch.Tensor, pointwise=True): + if isinstance(input, NeuralTreeOutput): + obj_val_samples = self.get_objective_vals(input) - def __call__(self, tree_output, pointwise=True): - obj_vals = self.get_objective_vals(tree_output)["joint"] + else: + obj_val_samples = input if pointwise: - obj_vals = obj_vals.unsqueeze(-2) # (num_samples, num_designs, 1, num_objectives) + obj_val_samples = obj_val_samples.unsqueeze(-2) # (num_samples, num_designs, 1, num_objectives) # assumes the first dimension of obj_vals corresponds to the qEHVI partitions if self._obj_dim == 1: acq_vals = torch.stack( - [fn._sample_forward(vals) for fn, vals in zip(self.acq_functions, obj_vals.squeeze(-1))] + [fn._sample_forward(vals) for fn, vals in zip(self.acq_functions, obj_val_samples.squeeze(-1))] ).squeeze(-1) else: acq_vals = torch.stack( - [fn._compute_log_qehvi(vals.unsqueeze(0)) for fn, vals in zip(self.acq_functions, obj_vals)] + 
[fn._compute_log_qehvi(vals.unsqueeze(0)) for fn, vals in zip(self.acq_functions, obj_val_samples)] ) return acq_vals.mean(0) diff --git a/cortex/attribution/__init__.py b/cortex/attribution/__init__.py index e69de29..893979b 100644 --- a/cortex/attribution/__init__.py +++ b/cortex/attribution/__init__.py @@ -0,0 +1,3 @@ +from ._occlusion import occlusion + +__all__ = ["occlusion"] diff --git a/cortex/attribution/_occlusion.py b/cortex/attribution/_occlusion.py new file mode 100644 index 0000000..cd829fa --- /dev/null +++ b/cortex/attribution/_occlusion.py @@ -0,0 +1,20 @@ +from typing import Optional + +import torch + + +def occlusion( + score_fn: callable, + tok_idxs: torch.LongTensor, + null_value: int, + is_excluded: Optional[torch.BoolTensor] = None, +): + scores = [] + for i in range(tok_idxs.size(-1)): + if torch.all(is_excluded[..., i]): + scores.append(torch.full_like(tok_idxs[..., 0].float(), -float("inf"))) + continue + occluded = tok_idxs.clone() + occluded[..., i] = null_value + scores.append(score_fn(occluded)) + return torch.stack(scores, dim=-1) diff --git a/cortex/optim/_initialization.py b/cortex/optim/_initialization.py new file mode 100644 index 0000000..ae12ca5 --- /dev/null +++ b/cortex/optim/_initialization.py @@ -0,0 +1,34 @@ +import pandas as pd +from botorch.utils.multi_objective.pareto import is_non_dominated + +from cortex.acquisition import get_joint_objective_values +from cortex.model import infer_with_model + + +def select_initial_sequences( + data: pd.DataFrame, + ckpt_dir: str, + ckpt_name: str, + graph_objectives: list[str], + graph_constraints: dict[str, list[str]], + graph_obj_transform: dict[str, dict[str, float]], +): + predictions = infer_with_model( + data=data, + cfg_fpath=f"{ckpt_dir}/{ckpt_name}.yaml", + weight_fpath=f"{ckpt_dir}/{ckpt_name}.pt", + ) + + obj_vals = get_joint_objective_values( + input=predictions, + objectives=graph_objectives, + constraints=graph_constraints, + scaling=graph_obj_transform, + ) + + non_dom_seeds = [] + for obj_val_sample in obj_vals: + is_non_dom = is_non_dominated(obj_val_sample) + non_dom_seeds.append(data.loc[is_non_dom, :]) + + return pd.concat(non_dom_seeds, ignore_index=True) From 2c120f52321345a54bf61c02bee3c4c6feb7bc7a Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Mon, 19 Feb 2024 22:08:29 -0500 Subject: [PATCH 6/9] get guided diffusion running on example problem --- cortex/acquisition/_graph_nei.py | 4 +- cortex/attribution/__init__.py | 7 +- cortex/attribution/_occlusion.py | 22 +++++ cortex/metrics/__init__.py | 3 + cortex/metrics/_edit_dist.py | 8 ++ cortex/model/_infer_with_model.py | 19 ++-- cortex/model/root/_conv1d_root.py | 3 + cortex/model/tree/_abstract_tree.py | 3 + cortex/model/tree/_seq_model_tree.py | 44 +++++++++ cortex/optim/__init__.py | 8 ++ cortex/optim/_coordinate_selection.py | 130 +++++++++++++++++++++++++ cortex/optim/_initialization.py | 7 +- tutorials/hydra/roots/protein_seq.yaml | 2 + 13 files changed, 246 insertions(+), 14 deletions(-) create mode 100644 cortex/metrics/_edit_dist.py create mode 100644 cortex/optim/_coordinate_selection.py diff --git a/cortex/acquisition/_graph_nei.py b/cortex/acquisition/_graph_nei.py index 3eb9f8d..8efd8cc 100644 --- a/cortex/acquisition/_graph_nei.py +++ b/cortex/acquisition/_graph_nei.py @@ -112,7 +112,7 @@ def tree_output_to_dict( result: dict[str, Tensor] = {} for objective in objectives: - result[objective] = fetch_task_outputs(tree_output, objective)["loc"] + result[objective] = fetch_task_outputs(tree_output, 
objective)["loc"].squeeze(-1) if scaling is not None and objective in scaling: result[f"{objective}_scaled"] = scale_value( @@ -144,7 +144,7 @@ def get_graph_nei_runtime_kwargs( tree_output_dict = tree_output_to_dict(tree_output, objectives=objectives, constraints=constraints, scaling=scaling) f_baseline = get_joint_objective_values( - input=tree_output_dict, + inputs=tree_output_dict, objectives=objectives, constraints=constraints, scaling=scaling, diff --git a/cortex/attribution/__init__.py b/cortex/attribution/__init__.py index 893979b..90e5e94 100644 --- a/cortex/attribution/__init__.py +++ b/cortex/attribution/__init__.py @@ -1,3 +1,6 @@ -from ._occlusion import occlusion +from ._occlusion import approximate_occlusion, occlusion -__all__ = ["occlusion"] +__all__ = [ + "approximate_occlusion", + "occlusion", +] diff --git a/cortex/attribution/_occlusion.py b/cortex/attribution/_occlusion.py index cd829fa..e35e6a8 100644 --- a/cortex/attribution/_occlusion.py +++ b/cortex/attribution/_occlusion.py @@ -18,3 +18,25 @@ def occlusion( occluded[..., i] = null_value scores.append(score_fn(occluded)) return torch.stack(scores, dim=-1) + + +def approximate_occlusion( + score_fn: callable, + tok_embeddings: torch.FloatTensor, + null_embedding: torch.FloatTensor, + is_excluded: Optional[torch.BoolTensor] = None, +): + """ + First-order Taylor expansion of the occlusion score. + """ + tok_embeddings = torch.nn.Parameter(tok_embeddings) + score = score_fn(tok_embeddings).sum() + score.backward() + emb_grad = tok_embeddings.grad + + perturbation = null_embedding - tok_embeddings + + score_delta = (emb_grad * perturbation).sum(-1) + + score_delta = torch.where(is_excluded, torch.full_like(score_delta, -float("inf")), score_delta) + return score_delta diff --git a/cortex/metrics/__init__.py b/cortex/metrics/__init__.py index f4c5175..1b6a2e2 100644 --- a/cortex/metrics/__init__.py +++ b/cortex/metrics/__init__.py @@ -1 +1,4 @@ +from ._edit_dist import edit_dist from ._spearman_rho import spearman_rho + +__all__ = ["spearman_rho", "edit_dist"] diff --git a/cortex/metrics/_edit_dist.py b/cortex/metrics/_edit_dist.py new file mode 100644 index 0000000..98525c6 --- /dev/null +++ b/cortex/metrics/_edit_dist.py @@ -0,0 +1,8 @@ +import edlib + + +def edit_dist(x: str, y: str): + """ + Computes the edit distance between two strings. 
+ """ + return edlib.align(x, y)["editDistance"] diff --git a/cortex/model/_infer_with_model.py b/cortex/model/_infer_with_model.py index e6bf42e..4cfcc68 100644 --- a/cortex/model/_infer_with_model.py +++ b/cortex/model/_infer_with_model.py @@ -9,8 +9,9 @@ def infer_with_model( data: pd.DataFrame, - cfg_fpath: str, - weight_fpath: str, + model: Optional[torch.nn.Module] = None, + cfg_fpath: Optional[str] = None, + weight_fpath: Optional[str] = None, batch_limit: int = 32, cpu_offload: bool = True, device: Optional[torch.device] = None, @@ -54,11 +55,17 @@ def infer_with_model( elif dtype is None: dtype = torch.float64 - # load Hydra config from s3 or locally - cfg = load_hydra_config(cfg_fpath) + if model is None and (cfg_fpath is None or weight_fpath is None): + raise ValueError("Either model or cfg_fpath and weight_fpath must be provided") - # load model checkpoint from s3 or locally - model, _ = load_model_checkpoint(cfg, weight_fpath, device=device, dtype=dtype) + if model is not None and (cfg_fpath is not None or weight_fpath is not None): + raise ValueError("Only one of model or cfg_fpath and weight_fpath must be provided") + + if model is None: + # load Hydra config from s3 or locally + cfg = load_hydra_config(cfg_fpath) + # load model checkpoint from s3 or locally + model, _ = load_model_checkpoint(cfg, weight_fpath, device=device, dtype=dtype) # model forward pass with torch.inference_mode(): diff --git a/cortex/model/root/_conv1d_root.py b/cortex/model/root/_conv1d_root.py index d7474f6..046dd54 100644 --- a/cortex/model/root/_conv1d_root.py +++ b/cortex/model/root/_conv1d_root.py @@ -121,6 +121,9 @@ def initialize_weights(self, **kwargs): # default random initialization pass + def get_token_embedding(self, tok_idx: int): + return self.tok_encoder(torch.tensor(tok_idx, device=self.device)) + @property def device(self): return self.tok_encoder.weight.device diff --git a/cortex/model/tree/_abstract_tree.py b/cortex/model/tree/_abstract_tree.py index cda69f8..3ee437e 100644 --- a/cortex/model/tree/_abstract_tree.py +++ b/cortex/model/tree/_abstract_tree.py @@ -153,3 +153,6 @@ def add_leaf( warnings.warn(msg, stacklevel=2) else: self.leaf_nodes[leaf_key] = leaf_node + + def call_from_trunk_output(self, trunk_output, leaf_keys: Optional[list[str]] = None, **kwargs): + return self(root_inputs=None, trunk_outputs=trunk_output, leaf_keys=leaf_keys, **kwargs) diff --git a/cortex/model/tree/_seq_model_tree.py b/cortex/model/tree/_seq_model_tree.py index 7aaabef..37b21d7 100644 --- a/cortex/model/tree/_seq_model_tree.py +++ b/cortex/model/tree/_seq_model_tree.py @@ -528,6 +528,50 @@ def format_task_outputs(self, task_out, task_keys, task_leaves): return predict_out + def call_from_str_array( + self, str_array, root_key: Optional[str] = None, leaf_keys: Optional[list[str]] = None, **kwargs + ): + if root_key is None: + root_key = _infer_root_key(self.root_nodes) + root_inputs = {root_key: {"seq_array": str_array, **kwargs}} + return self(root_inputs=root_inputs, leaf_keys=leaf_keys) + + def call_from_tok_idxs( + self, + tok_idxs: torch.LongTensor, + root_key: Optional[str] = None, + leaf_keys: Optional[list[str]] = None, + **kwargs, + ): + if root_key is None: + root_key = _infer_root_key(self.root_nodes) + root_inputs = {root_key: {"tgt_tok_idxs": tok_idxs, **kwargs}} + return self(root_inputs=root_inputs, leaf_keys=leaf_keys) + + def call_from_tok_embs( + self, + tok_embs: torch.FloatTensor, + root_key: Optional[str] = None, + leaf_keys: Optional[list[str]] = None, + **kwargs, + ): + if 
root_key is None: + root_key = _infer_root_key(self.root_nodes) + root_inputs = {root_key: {"src_tok_embs": tok_embs, **kwargs}} + return self(root_inputs=root_inputs, leaf_keys=leaf_keys) + + def get_tokenizer(self, root_key: Optional[str] = None): + if root_key is None: + root_key = _infer_root_key(self.root_nodes) + return self.root_nodes[root_key].tokenizer + + +def _infer_root_key(root_nodes): + if len(root_nodes) == 1: + return list(root_nodes.keys())[0] + else: + raise ValueError("root_key must be provided when there are multiple root nodes") + def get_param_prefixes(tree_outputs): param_prefixes = [] diff --git a/cortex/optim/__init__.py b/cortex/optim/__init__.py index e69de29..2d50d69 100644 --- a/cortex/optim/__init__.py +++ b/cortex/optim/__init__.py @@ -0,0 +1,8 @@ +from ._coordinate_selection import NOSCoordinateScore, greedy_occlusion_selection +from ._initialization import select_initial_sequences + +__all__ = [ + "greedy_occlusion_selection", + "NOSCoordinateScore", + "select_initial_sequences", +] diff --git a/cortex/optim/_coordinate_selection.py b/cortex/optim/_coordinate_selection.py new file mode 100644 index 0000000..ee113ae --- /dev/null +++ b/cortex/optim/_coordinate_selection.py @@ -0,0 +1,130 @@ +from typing import Optional + +import torch + +from cortex.attribution import occlusion +from cortex.model.tree import NeuralTree, NeuralTreeOutput, fetch_task_outputs + + +def greedy_occlusion_selection( + tok_idxs: torch.LongTensor, + score_fn: callable, + null_value: int, + num_coordinates: int, + is_excluded: Optional[torch.BoolTensor] = None, + take_second_prob: float = 0.5, +): + """ + Greedy coordinate selection based on sensitivity of `score_fn` to pointwise occlusion. + `score_fn` should be a callable that takes a tensor of token indices and returns a batch of scalar scores. + At each iteration, each coordinate is occluded and the score_fn is evaluated on the resulting tensor. + For each element in the batch, the coordinate with the highest score is selected and remains occluded. + This process is repeated until `num_coordinates` coordinates are selected. + Returns a tensor of indices of selected coordinates. 
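+
+    Shape sketch: for `tok_idxs` of shape (batch_size, seq_len), the returned
+    tensor of selected positions has shape (batch_size, num_coordinates).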
+ """ + num_feasible = (~is_excluded).float().sum(-1) + assert torch.all(num_feasible >= num_coordinates), "Not enough feasible coordinates" + is_selected = torch.zeros_like(tok_idxs, dtype=torch.bool) + for _ in range(num_coordinates): + scores = occlusion( + score_fn=score_fn, tok_idxs=tok_idxs, null_value=null_value, is_excluded=is_excluded + is_selected + ) + # don't select already selected coordinates + scores = scores.masked_fill(is_selected, -float("inf")) + # don't select excluded coordinates + if is_excluded is not None: + scores = scores.masked_fill(is_excluded, -float("inf")) + + _, sorted_idxs = torch.sort(scores, dim=-1, descending=True) + best_coord = sorted_idxs[..., 0] + second_best = sorted_idxs[..., 1] + second_available = (scores > -float("inf")).sum(-1) > 1 + take_second = (torch.rand_like(best_coord.float()) < take_second_prob) * second_available + best_coord = torch.where(take_second, second_best, best_coord) + + is_selected.scatter_(-1, best_coord.unsqueeze(-1), True) + tok_idxs = torch.where(is_selected, null_value, tok_idxs) + select_coord = torch.where(is_selected)[1].view(*tok_idxs.shape[:-1], num_coordinates) + print(select_coord) + return select_coord + + +class NOSCoordinateScore(object): + r""" + Wrapper for a `cortex` model that computes the following score: + $$ s_i = (1 - \lambda) (v(x_{-i}) - v(x)) - \lambda \logp(x_i \mid x_{-i}) $$ + where $x_i$ is the $i$th token in the input sequence $x$, $x_{-i}$ is the sequence with the $i$th token occluded, + $v$ is a value function, and $\lambda$ is a hyperparameter. + """ + + def __init__( + self, + model: NeuralTree, + value_fn, + logp_fn, + x_instances, + lambda_val: float, + root_key: str, + ): + self.model = model + self.value_fn = value_fn + self.logp_fn = logp_fn + self._x_instances = x_instances + with torch.inference_mode(): + self._ref_values = value_fn(model(x_instances)) + self._lambda_val = lambda_val + self._root_key = root_key + + def __call__(self, x_occluded): + with torch.inference_mode(): + model_output = self.model(x_occluded) + values = self.value_fn(model_output) + values = values - self._ref_values # want to change positions with large change in value + if self._lambda_val == 0.0: + return values + logp = self.logp_fn( + model_output, self._x_instances, x_occluded, self._root_key + ) # want to change positions with low probability + return (1 - self._lambda_val) * values - self._lambda_val * logp + + +def mlm_conditional_log_likelihood( + tree_output: NeuralTreeOutput, + x_instances, + x_occluded, + root_key: str, +): + """ + Compute the MLM conditional log-likelihood of the masked tokens in `x_occluded` given the unmasked tokens in `x_instances`. 
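+
+    Returns a tensor of shape (batch_size,): the mean conditional log-likelihood
+    over the occluded positions of each sequence, averaged over the ensemble.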
+ """ + task_outputs = fetch_task_outputs(tree_output, root_key) + token_probs = task_outputs["logits"].log_softmax(-1) # (ensemble_size, batch_size, seq_len, vocab_size) + is_occluded = x_instances != x_occluded + token_cll = token_probs.gather(-1, x_instances[None, ..., None]).squeeze(-1) # (ensemble_size, batch_size, seq_len) + infill_cll = (token_cll * is_occluded).sum(-1) / is_occluded.sum(-1) # (ensemble_size, batch_size) + return infill_cll.mean(0) + + +def mlm_pseudo_log_likelihood( + tok_idxs: torch.LongTensor, + null_value: int, + model: NeuralTree, + root_key: str, + is_excluded: Optional[torch.BoolTensor] = None, +): + """ + Compute the MLM pseudo-log-likelihood of the full tok_idxs sequence + """ + scores = [] + for coord_idx in range(tok_idxs.size(-1)): + if is_excluded is not None and torch.all(is_excluded[..., coord_idx]): + scores.append(torch.full_like(tok_idxs[..., 0].float(), 0.0)) + continue + occluded = tok_idxs.clone() + occluded[..., coord_idx] = null_value + with torch.inference_mode(): + model_output = model(occluded, leaf_keys=[f"{root_key}_0"]) + scores.append(mlm_conditional_log_likelihood(model_output, tok_idxs, occluded, root_key)) + is_included = 1.0 - is_excluded.float() if is_excluded is not None else torch.ones_like(scores) + scores = is_included * torch.stack(scores, dim=-1) + return scores.sum(-1) / is_included.sum(-1) diff --git a/cortex/optim/_initialization.py b/cortex/optim/_initialization.py index ae12ca5..35d3a05 100644 --- a/cortex/optim/_initialization.py +++ b/cortex/optim/_initialization.py @@ -3,20 +3,19 @@ from cortex.acquisition import get_joint_objective_values from cortex.model import infer_with_model +from cortex.model.tree import NeuralTree def select_initial_sequences( data: pd.DataFrame, - ckpt_dir: str, - ckpt_name: str, + model: NeuralTree, graph_objectives: list[str], graph_constraints: dict[str, list[str]], graph_obj_transform: dict[str, dict[str, float]], ): predictions = infer_with_model( data=data, - cfg_fpath=f"{ckpt_dir}/{ckpt_name}.yaml", - weight_fpath=f"{ckpt_dir}/{ckpt_name}.pt", + model=model, ) obj_vals = get_joint_objective_values( diff --git a/tutorials/hydra/roots/protein_seq.yaml b/tutorials/hydra/roots/protein_seq.yaml index a20274c..03cbbcd 100644 --- a/tutorials/hydra/roots/protein_seq.yaml +++ b/tutorials/hydra/roots/protein_seq.yaml @@ -1,5 +1,7 @@ protein_seq: _target_: cortex.model.root.Conv1dRoot + corruption_process: + _target_: cortex.corruption.MaskCorruptionProcess tokenizer_transform: _target_: cortex.transforms.HuggingFaceTokenizerTransform tokenizer: From ff09e0fa688c3a0035f57df0cc8f81fe48bd342b Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Mon, 19 Feb 2024 22:09:30 -0500 Subject: [PATCH 7/9] add guided diffusion tutorial --- tutorials/4_guided_diffusion.ipynb | 215 ++++++++++++++++++ tutorials/hydra/4_guided_diffusion.yaml | 25 ++ .../hydra/branches/protein_generation.yaml | 6 + .../guidance_objective/log_fluorescence.yaml | 12 + tutorials/hydra/optim/lambo.yaml | 11 + tutorials/hydra/tasks/protein_seq.yaml | 23 ++ 6 files changed, 292 insertions(+) create mode 100644 tutorials/4_guided_diffusion.ipynb create mode 100644 tutorials/hydra/4_guided_diffusion.yaml create mode 100644 tutorials/hydra/branches/protein_generation.yaml create mode 100644 tutorials/hydra/guidance_objective/log_fluorescence.yaml create mode 100644 tutorials/hydra/optim/lambo.yaml create mode 100644 tutorials/hydra/tasks/protein_seq.yaml diff --git a/tutorials/4_guided_diffusion.ipynb b/tutorials/4_guided_diffusion.ipynb 
new file mode 100644 index 0000000..f655f63 --- /dev/null +++ b/tutorials/4_guided_diffusion.ipynb @@ -0,0 +1,215 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Protein Design with Guided Discrete Diffusion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "import hydra\n", + "from cortex.logging import wandb_setup\n", + "\n", + "with hydra.initialize(config_path=\"./hydra\"):\n", + " cfg = hydra.compose(config_name=\"4_guided_diffusion\")\n", + " OmegaConf.set_struct(cfg, False)\n", + "\n", + "wandb_setup(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from cortex.data.dataset import TAPEFluorescenceDataset\n", + "\n", + "\n", + "dataset = TAPEFluorescenceDataset(\n", + " root='./.cache',\n", + " download=True,\n", + " train=True,\n", + ")\n", + "\n", + "med_idx = len(dataset) // 2\n", + "\n", + "init_df = dataset._data.sort_values(\"log_fluorescence\").iloc[med_idx : med_idx + 1]\n", + "init_df = init_df.sample(n=cfg.optim.max_num_solutions, replace=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import lightning as L\n", + "\n", + "# set random seed\n", + "L.seed_everything(seed=cfg.random_seed, workers=True)\n", + "\n", + "# instantiate model\n", + "model = hydra.utils.instantiate(cfg.tree)\n", + "model.build_tree(cfg, skip_task_setup=False)\n", + "\n", + "# instantiate trainer, set logger\n", + "trainer = hydra.utils.instantiate(cfg.trainer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(\n", + " model,\n", + " train_dataloaders=model.get_dataloader(split=\"train\"),\n", + " val_dataloaders=model.get_dataloader(split=\"val\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# construct guidance objective\n", + "initial_solution = init_df[\"tokenized_seq\"].values\n", + "acq_fn_runtime_kwargs = hydra.utils.call(\n", + " cfg.guidance_objective.runtime_kwargs, model=model, candidate_points=initial_solution\n", + ")\n", + "acq_fn = hydra.utils.instantiate(cfg.guidance_objective.static_kwargs, **acq_fn_runtime_kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_transform = model.root_nodes[\"protein_seq\"].eval_transform\n", + "tokenizer = tokenizer_transform[0].tokenizer\n", + "\n", + "tok_idxs = tokenizer_transform(initial_solution)\n", + "is_mutable = tokenizer.get_corruptible_mask(tok_idxs)\n", + "is_mutable\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "with torch.inference_mode():\n", + " tree_output = model.call_from_str_array(initial_solution, corrupt_frac=0.0)\n", + " init_obj_vals = acq_fn.get_objective_vals(tree_output)\n", + "init_obj_vals" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "optimizer = hydra.utils.instantiate(\n", + " cfg.optim,\n", + " params=tok_idxs,\n", + " is_mutable=is_mutable,\n", + " model=model,\n", + " objective=acq_fn,\n", + " constraint_fn=None,\n", + ")\n", + "for _ in range(cfg.num_steps):\n", + " optimizer.step()\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_designs = optimizer.get_best_solutions()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.inference_mode():\n", + " tree_output = model.call_from_str_array(new_designs[\"protein_seq\"].values, corrupt_frac=0.0)\n", + " final_obj_vals = acq_fn.get_objective_vals(tree_output)\n", + "final_obj_vals" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "history = optimizer._buffer\n", + "\n", + "med_obj_val = history.groupby(\"iteration\").obj_val.median()\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "sns.set_theme(style=\"whitegrid\", font_scale=1.75)\n", + "\n", + "plt.plot(med_obj_val)\n", + "plt.xlabel(\"Diffusion Iteration\")\n", + "plt.ylabel(\"Median Acq. Value\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(final_obj_vals.view(-1), fill=True, alpha=0.5, cut=0)\n", + "ylim = plt.ylim()\n", + "plt.vlines(init_obj_vals[0], *ylim, color=\"black\", linestyle=\"--\", label=\"Initial Value\")\n", + "plt.xlabel(\"Predicted Log Fluorescence\")\n", + "plt.legend()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cortex-public", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/hydra/4_guided_diffusion.yaml b/tutorials/hydra/4_guided_diffusion.yaml new file mode 100644 index 0000000..e9cbe45 --- /dev/null +++ b/tutorials/hydra/4_guided_diffusion.yaml @@ -0,0 +1,25 @@ +defaults: + - tree: sequence_model + - roots: [protein_seq] + - trunk: sum_trunk + - branches: [protein_property, protein_generation] + - tasks: + - log_fluorescence + - protein_seq + - guidance_objective: log_fluorescence + - optim: lambo + +feature_dim: 32 +kernel_size: 3 +batch_size: 32 +max_epochs: 2 +data_dir: ./.cache +wandb_mode: offline +random_seed: 42 +num_steps: 8 +num_samples: 16 + +trainer: + _target_: lightning.Trainer + max_epochs: ${max_epochs} + num_sanity_val_steps: 1 diff --git a/tutorials/hydra/branches/protein_generation.yaml b/tutorials/hydra/branches/protein_generation.yaml new file mode 100644 index 0000000..21e9dc7 --- /dev/null +++ b/tutorials/hydra/branches/protein_generation.yaml @@ -0,0 +1,6 @@ +protein_generation: + _target_: cortex.model.branch.Conv1dBranch + out_dim: 32 + channel_dim: ${feature_dim} + num_blocks: 0 + kernel_size: ${kernel_size} diff --git a/tutorials/hydra/guidance_objective/log_fluorescence.yaml b/tutorials/hydra/guidance_objective/log_fluorescence.yaml new file mode 100644 index 0000000..d7cacf6 --- /dev/null +++ b/tutorials/hydra/guidance_objective/log_fluorescence.yaml @@ -0,0 +1,12 @@ +tag: log_fluorescence +static_kwargs: + _target_: cortex.acquisition.GraphNEI + objectives: + - log_fluorescence + constraints: null + scaling: null +runtime_kwargs: + _target_: cortex.acquisition.get_graph_nei_runtime_kwargs + objectives: ${guidance_objective.static_kwargs.objectives} + constraints: ${guidance_objective.static_kwargs.constraints} + scaling: ${guidance_objective.static_kwargs.scaling} diff --git 
a/tutorials/hydra/optim/lambo.yaml b/tutorials/hydra/optim/lambo.yaml new file mode 100644 index 0000000..44d2f7b --- /dev/null +++ b/tutorials/hydra/optim/lambo.yaml @@ -0,0 +1,11 @@ +_target_: cortex.optim.generative.LaMBO +max_num_solutions: ${num_samples} +num_mutations_per_step: 8 +max_guidance_updates: 4 +guidance_step_size: 0.1 +guidance_layer: trunk +kl_weight: 0.25 +feature_attr_temp: 1.0 +domain_name: protein_seq +exclude_initial_solution: false +resample_edit_positions: true diff --git a/tutorials/hydra/tasks/protein_seq.yaml b/tutorials/hydra/tasks/protein_seq.yaml new file mode 100644 index 0000000..7bf3713 --- /dev/null +++ b/tutorials/hydra/tasks/protein_seq.yaml @@ -0,0 +1,23 @@ +protein_generation: + protein_seq: + _target_: cortex.task.DenoisingLanguageModelTask + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + input_map: + protein_seq: ['tokenized_seq'] + root_key: protein_seq + ensemble_size: 1 + data_module: + _target_: cortex.data.data_module.TaskDataModule + _recursive_: false + batch_size: ${batch_size} + balance_train_partition: null + drop_last: true + lengths: [1.0, 0.0] + train_on_everything: false + num_workers: 1 + dataset_config: + _target_: cortex.data.dataset.TAPEFluorescenceDataset + root: ${data_dir} + download: true + train: ??? From 1ad2ce0fefc4ad1e784d5bef531f29962de9059e Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Mon, 19 Feb 2024 22:10:35 -0500 Subject: [PATCH 8/9] add LaMBO optimizer --- cortex/optim/generative/__init__.py | 5 + cortex/optim/generative/_lambo.py | 420 ++++++++++++++++++++++++++++ 2 files changed, 425 insertions(+) create mode 100644 cortex/optim/generative/__init__.py create mode 100644 cortex/optim/generative/_lambo.py diff --git a/cortex/optim/generative/__init__.py b/cortex/optim/generative/__init__.py new file mode 100644 index 0000000..5fb4340 --- /dev/null +++ b/cortex/optim/generative/__init__.py @@ -0,0 +1,5 @@ +from ._lambo import LaMBO + +__all__ = [ + "LaMBO", +] diff --git a/cortex/optim/generative/_lambo.py b/cortex/optim/generative/_lambo.py new file mode 100644 index 0000000..9b9f3db --- /dev/null +++ b/cortex/optim/generative/_lambo.py @@ -0,0 +1,420 @@ +import math +import warnings +from typing import Callable, Optional + +import numpy as np +import pandas as pd +import torch +import wandb +from torch.distributions.kl import kl_divergence + +from cortex.attribution import approximate_occlusion +from cortex.corruption import GaussianCorruptionProcess, MaskCorruptionProcess +from cortex.optim._coordinate_selection import ( + mlm_pseudo_log_likelihood, +) + + +class LaMBO(object): + def __init__( + self, + params: torch.LongTensor, + is_mutable: torch.BoolTensor, + model, + objective, + max_num_solutions: int, + num_mutations_per_step: Optional[int] = 2, + max_guidance_updates: int = 16, + guidance_step_size: float = 1.0, + guidance_layer: str = "trunk", + kl_weight: float = 0.5, + feature_attr_temp: float = 0.125, + constraint_fn: Optional[Callable[[str], bool]] = None, + domain_name: Optional[str] = None, + exclude_initial_solution: bool = False, + resample_edit_positions: bool = False, + ) -> None: + self.model = model + self.objective = objective + self.max_num_solutions = max_num_solutions + self.num_mutations_per_step = num_mutations_per_step + self.max_guidance_updates = max_guidance_updates + self.guidance_step_size = guidance_step_size + self.guidance_layer = guidance_layer + self.kl_weight = kl_weight + self.feature_attr_temp = feature_attr_temp + self.constraint_fn = 
constraint_fn + self.domain_name = "iterate" if domain_name is None else domain_name + self.resample_edit_positions = resample_edit_positions + + self.initial_solution = np.array([self.tokenizer.decode(t_idxs) for t_idxs in params]) + + initial_obj_val, _ = self.score_sequences(self.initial_solution) + print(initial_obj_val) + self.initial_obj_val = initial_obj_val + self.is_mutable = is_mutable + + self._step_count = 0 + self._buffer = pd.DataFrame( + { + "iteration": self._step_count, + domain_name: self.initial_solution, + "obj_val": self.initial_obj_val.cpu().numpy(), + }, + index=list(range(initial_obj_val.size(0))), + ) + self.exclude_initial_solution = exclude_initial_solution + + @property + def tokenizer(self): + return self.model.root_nodes[self.domain_name].tokenizer + + @property + def tokens_to_long_tensor(self): + return self.model.root_nodes[self.domain_name].eval_transform + + def step(self) -> None: + self.model.eval() + self.model.requires_grad_(False) + + print(f"==== LaMBO-2 Step {self._step_count + 1} ====") + prev_frame = self._buffer.iloc[-self.initial_solution.shape[0] :] + base_solutions = prev_frame[self.domain_name].values + + # TODO figure out cleaner structure + base_tok_idxs = self.tokens_to_long_tensor(base_solutions) + base_output = self.model.call_from_tok_idxs(base_tok_idxs, corrupt_frac=0.0) + base_obj_vals = self.objective(base_output) + base_tok_embs = base_output.root_outputs[self.domain_name].src_tok_embs + base_padding_mask = base_output.root_outputs[self.domain_name].padding_mask + + # tensor of token indices that are not viable sample output + non_viable_idxs = torch.tensor( + [self.tokenizer.vocab[token] for token in self.tokenizer.sampling_vocab_excluded] + ) + + # check feasibility constraints + base_is_feasible = self.check_constraints(base_tok_idxs).to(base_tok_idxs.device) + if not base_is_feasible.any(): + msg = "No feasible starting points, search may fail to find feasible solutions. Try changing the initial solution or the constraint function." 
+ warnings.warn(msg, stacklevel=2) + + print(f"Optimizing {base_obj_vals.size(0)} solutions") + + # create mutable tensors for optimization + tgt_tok_idxs = base_tok_idxs.clone() + tgt_tok_embs = base_tok_embs.clone() + tgt_padding_mask = base_padding_mask.clone() + + # choose edit positions + is_corruptible = self.tokenizer.get_corruptible_mask(tgt_tok_idxs) + self._coordinate_selection(tgt_tok_idxs, tgt_tok_embs, tgt_padding_mask, is_corruptible) + + # score current solutions + if self._step_count == 0 and self.exclude_initial_solution: + tgt_obj_vals = torch.full_like(base_obj_vals, float("-inf")) + else: + tgt_obj_vals = base_obj_vals + if self.kl_weight > 0.0: + tgt_obj_vals *= 1 - self.kl_weight + tgt_obj_vals += self.kl_weight * mlm_pseudo_log_likelihood( + tgt_tok_idxs, + null_value=self.tokenizer.masking_idx, + model=self.model.call_from_tok_idxs, + root_key=self.domain_name, + is_excluded=~self.tokenizer.get_corruptible_mask(tgt_tok_idxs), + ) + + # set up forwards pass inputs + generation_inputs = self._set_up_root_inputs(tgt_tok_idxs, tgt_tok_embs, tgt_padding_mask) + is_corrupted = generation_inputs[self.domain_name]["is_corrupted"] + + # get latent features, make them leaf variables + activations, trunk_outputs = self._get_latent_variables(generation_inputs) + + delta = torch.nn.Parameter(torch.zeros_like(activations)) + optimizer = torch.optim.Adam([delta], lr=self.guidance_step_size) + + print("\n") + for lang_step in range(self.max_guidance_updates): + delta.grad = None + + # forward pass from modified activations + if self.guidance_layer == "root": + masked_delta = torch.where(is_corrupted[..., None], delta, 0.0) + generation_inputs[self.domain_name]["src_tok_embs"] = activations + masked_delta + tree_output = self.model(root_inputs=generation_inputs) + trunk_outputs = tree_output.trunk_outputs + elif self.guidance_layer == "trunk": + masked_delta = torch.where(is_corrupted[..., None], delta, 0.0) + trunk_outputs.trunk_features = activations + masked_delta + tree_output = self.model.call_from_trunk_output(trunk_outputs) + + # guided token distribution logits + token_logits = tree_output.leaf_outputs[f"{self.domain_name}_0"].logits + + # fix unguided reference token distribution + if lang_step == 0: + adj_logits = token_logits.scatter( + dim=-1, + index=non_viable_idxs.expand(*token_logits.shape[:-1], -1).to(token_logits.device), + value=float("-inf"), + ) + base_probs = adj_logits.detach().clone().softmax(-1).clamp_min(1e-6) + base_dist = torch.distributions.Categorical(probs=base_probs) + + # compute guidance loss + guided_dist = torch.distributions.Categorical(logits=token_logits) + entropy = torch.masked_select(guided_dist.entropy(), is_corrupted).mean() + kl_div = torch.masked_select(kl_divergence(guided_dist, base_dist), is_corrupted).mean() + obj_loss = -1.0 * self.objective(tree_output).mean() + design_loss = self.kl_weight * kl_div + (1.0 - self.kl_weight) * obj_loss + design_loss.backward() + feature_grad = delta.grad.detach().clone() + optimizer.step() + + # update solution + tgt_tok_idxs, tgt_obj_vals = self._update_solution( + trunk_outputs, + activations, + delta, + tgt_tok_idxs, + tgt_obj_vals, + is_corrupted, + self.tokenizer, + non_viable_idxs, + ) + + grad_norm = feature_grad.norm(dim=(-2, -1), keepdim=True) + print( + tgt_obj_vals.median().item(), + design_loss.item(), + obj_loss.item(), + kl_div.item(), + entropy.item(), + ) + + self._step_count += 1 + + tgt_str_array = [self.tokenizer.decode(t_idxs) for t_idxs in tgt_tok_idxs] + df = 
pd.DataFrame({self.domain_name: tgt_str_array}) + df.loc[:, "obj_val"] = tgt_obj_vals.cpu().numpy() + df.loc[:, "iteration"] = self._step_count + + self._buffer = pd.concat([self._buffer, df], ignore_index=True) + + metrics = { + "step": self._step_count, + "masked_design_loss": design_loss.item(), + "masked_design_loss_grad_norm": grad_norm.mean().item(), + "masked_token_loss": kl_div.item(), + "masked_obj_loss": obj_loss.item(), + "token_entropy": entropy.item(), + } + wandb.log(metrics) + + def _coordinate_selection( + self, + tok_idxs: torch.LongTensor, + tok_embeddings: torch.FloatTensor, + padding_mask: torch.BoolTensor, + is_corruptible: torch.BoolTensor, + ): + def coord_score(tok_embeddings): + tree_output = self.model.call_from_tok_embs( + tok_embeddings, root_key=self.domain_name, corrupt_frac=0.0, padding_mask=padding_mask + ) + return self.objective(tree_output) + + null_embedding = self.model.root_nodes[self.domain_name].get_token_embedding(self.tokenizer.masking_idx) + # model_call = partial(self.model.call_from_tok_embs, root_key=self.domain_name) + # self._coordinate_score = NOSCoordinateScore( + # model=model_call, + # value_fn=self.objective, + # logp_fn=mlm_conditional_log_likelihood, + # x_instances=tok_idxs, + # lambda_val=0.0, + # root_key=self.domain_name, + # ) + + # edit_idxs are all corruptible and mutable positions + pos_is_feasible = is_corruptible * self.is_mutable + if self.num_mutations_per_step is None: + self._corruption_allowed = pos_is_feasible + elif self._step_count == 0 or self.resample_edit_positions: + position_scores = approximate_occlusion( + coord_score, + tok_embeddings, + null_embedding, + is_excluded=~pos_is_feasible, + ) + position_probs = (position_scores * self.feature_attr_temp).softmax(-1) + edit_idxs = torch.multinomial(position_probs, self.num_mutations_per_step, replacement=False) + edit_idxs = edit_idxs.sort(dim=-1).values + + # edit_probs = self.model.edit_probs( + # acq_fn=self.objective, + # base_tok_idxs=tgt_tok_idxs, + # temp=self.feature_attr_temp, + # is_mutable=self.is_mutable, + # ) + # edit_idxs = torch.multinomial( + # edit_probs, self.num_mutations_per_step, replacement=False + # ) + # edit_idxs = edit_idxs.sort(dim=-1).values + + # edit_idxs = greedy_occlusion_selection( + # tok_idxs=tok_idxs, + # score_fn=self._coordinate_score, + # num_coordinates=self.num_mutations_per_step, + # null_value=self.tokenizer.masking_idx, + # is_excluded=~pos_is_feasible, + # ) + self._corruption_allowed = torch.zeros_like(tok_idxs) + self._corruption_allowed = self._corruption_allowed.scatter(dim=-1, index=edit_idxs, value=1).bool() + print(f"Selected edit positions: {edit_idxs}") + + def _get_latent_variables( + self, + generation_inputs: dict, + ): + with torch.no_grad(): + tree_outputs = self.model(generation_inputs, leaf_keys=[f"{self.domain_name}_0"]) + if self.guidance_layer == "root": + activations = tree_outputs.root_outputs[self.domain_name].src_tok_embs + + elif self.guidance_layer == "trunk": + activations = tree_outputs.trunk_outputs.trunk_features + + trunk_outputs = tree_outputs.trunk_outputs + + return activations, trunk_outputs + + def _update_solution( + self, + trunk_outputs, + activations, + delta, + tgt_tok_idxs, + tgt_obj_vals, + is_corrupted, + tokenizer, + non_viable_idxs, + ): + # update latent features only at masked locations + with torch.no_grad(): + new_activations = torch.where(is_corrupted[..., None], activations + delta, activations) + activations.copy_(new_activations) + # compute token logits from updated 
+    def _update_solution(
+        self,
+        trunk_outputs,
+        activations,
+        delta,
+        tgt_tok_idxs,
+        tgt_obj_vals,
+        is_corrupted,
+        tokenizer,
+        non_viable_idxs,
+    ):
+        # update latent features only at masked locations
+        with torch.no_grad():
+            new_activations = torch.where(is_corrupted[..., None], activations + delta, activations)
+            activations.copy_(new_activations)
+
+        # compute token logits from updated features and sample candidate tokens
+        sample_tok_idxs = self.decode(trunk_outputs, non_viable_idxs)
+        sample_tok_idxs = torch.where(is_corrupted, sample_tok_idxs, tgt_tok_idxs)
+
+        sample_obj_vals, sample_tok_embs = self.score_sequences(sample_tok_idxs)
+        # out-of-place ops: score_sequences returns inference-mode tensors,
+        # which cannot be updated in place outside inference mode
+        sample_obj_vals = sample_obj_vals * (1 - self.kl_weight)
+        if self.kl_weight > 0.0:
+            sample_obj_vals = sample_obj_vals + self.kl_weight * mlm_pseudo_log_likelihood(
+                sample_tok_idxs,
+                null_value=tokenizer.masking_idx,
+                model=self.model.call_from_tok_idxs,
+                root_key=self.domain_name,
+                is_excluded=~tokenizer.get_corruptible_mask(tgt_tok_idxs),
+            )
+
+        if tgt_obj_vals is None:
+            tgt_obj_vals = sample_obj_vals
+            sample_is_improved = torch.ones_like(sample_obj_vals).bool()
+        else:
+            sample_is_improved = sample_obj_vals >= tgt_obj_vals
+
+        sample_is_feasible = self.check_constraints(sample_tok_idxs)
+        print(f"Feasible samples: {sample_is_feasible.sum()}/{sample_is_feasible.size(0)}")
+
+        # keep only improved, feasible sequences
+        replace_mask = sample_is_improved * sample_is_feasible.to(sample_is_improved)
+        tgt_tok_idxs = torch.where(replace_mask[..., None], sample_tok_idxs, tgt_tok_idxs)
+        tgt_obj_vals = torch.where(replace_mask, sample_obj_vals, tgt_obj_vals)
+
+        return tgt_tok_idxs, tgt_obj_vals
+
+    def _set_up_root_inputs(
+        self,
+        tgt_tok_idxs,
+        tgt_tok_embs,
+        tgt_padding_mask,
+    ):
+        # anneal the corruption rate as optimization progresses
+        corrupt_frac = 1.0 / math.sqrt(1 + self._step_count)
+
+        root_inputs = {self.domain_name: {}}
+        corrupt_kwargs = {
+            "corrupt_frac": corrupt_frac,
+        }
+        corruption_process = self.model.root_nodes[self.domain_name].corruption_process
+
+        # corrupt a random subset of positions where self._corruption_allowed is True
+        if isinstance(corruption_process, MaskCorruptionProcess):
+            corrupt_kwargs["x_start"] = tgt_tok_idxs
+            corrupt_kwargs["corruption_allowed"] = self._corruption_allowed
+            corrupt_kwargs["mask_val"] = self.tokenizer.masking_idx
+            src_tok_idxs, is_corrupted = corruption_process(**corrupt_kwargs)
+            root_inputs[self.domain_name]["tgt_tok_idxs"] = src_tok_idxs
+            root_inputs[self.domain_name]["is_corrupted"] = is_corrupted
+
+        # corrupt all positions where self._corruption_allowed is True
+        elif isinstance(corruption_process, GaussianCorruptionProcess):
+            corrupt_kwargs["x_start"] = tgt_tok_embs
+            corrupt_kwargs["corruption_allowed"] = self._corruption_allowed[..., None]
+            src_tok_embs, is_corrupted = corruption_process(**corrupt_kwargs)
+            is_corrupted = is_corrupted.sum(-1).bool()
+            root_inputs[self.domain_name]["src_tok_embs"] = src_tok_embs
+            root_inputs[self.domain_name]["is_corrupted"] = is_corrupted
+            root_inputs[self.domain_name]["padding_mask"] = tgt_padding_mask
+        else:
+            raise NotImplementedError(f"Unsupported corruption process: {type(corruption_process)}")
+
+        return root_inputs
+
+    def check_constraints(self, sample_tok_idxs):
+        tokenizer = self.model.root_nodes[self.domain_name].tokenizer
+        if self.constraint_fn is not None:
+            sample_seqs = np.array([tokenizer.decode(t_idxs) for t_idxs in sample_tok_idxs])
+            sample_is_feasible = self.constraint_fn(sample_seqs)
+        else:
+            sample_is_feasible = np.ones(len(sample_tok_idxs), dtype=bool)
+        return torch.from_numpy(sample_is_feasible)
+
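Note the corruption schedule in `_set_up_root_inputs`: the corrupted fraction decays as 1/sqrt(1 + step), so early steps re-mask aggressively while later steps make increasingly local refinements. A minimal sketch of a mask-style corruption restricted to allowed positions — `toy_mask_corruption` merely mirrors the keyword arguments used above and is not the real `MaskCorruptionProcess`:

```python
import math
import torch

def toy_mask_corruption(x_start, corruption_allowed, mask_val, corrupt_frac):
    """Randomly replace roughly corrupt_frac of the allowed positions with mask_val."""
    is_corrupted = (torch.rand(x_start.shape) < corrupt_frac) & corruption_allowed
    corrupted = torch.where(is_corrupted, torch.full_like(x_start, mask_val), x_start)
    return corrupted, is_corrupted

tok_idxs = torch.randint(4, 24, (2, 10))  # toy token ids
allowed = torch.zeros(2, 10, dtype=torch.bool)
allowed[:, 3:7] = True  # e.g. positions chosen by coordinate selection

for step in range(3):
    corrupt_frac = 1.0 / math.sqrt(1 + step)  # annealed corruption rate
    src_tok_idxs, is_corrupted = toy_mask_corruption(
        tok_idxs, allowed, mask_val=3, corrupt_frac=corrupt_frac
    )
```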
+    def score_sequences(self, sequences):
+        with torch.inference_mode():
+            if isinstance(sequences, np.ndarray):
+                tree_output = self.model.call_from_str_array(sequences, corrupt_frac=0.0)
+            elif isinstance(sequences, torch.Tensor):
+                tree_output = self.model.call_from_tok_idxs(sequences, corrupt_frac=0.0)
+            else:
+                raise ValueError(f"Expected np.ndarray or torch.Tensor, got {type(sequences)}")
+            sample_tok_embs = tree_output.root_outputs[self.domain_name].src_tok_embs
+            sample_obj_vals = self.objective(tree_output)
+
+        return sample_obj_vals, sample_tok_embs
+
+    def get_best_solutions(self) -> pd.DataFrame:
+        # the buffer is appended once per step, so the last block holds the current solutions
+        res = self._buffer.iloc[-self.initial_solution.shape[0] :].copy()
+        res["obj_val_init"] = self.initial_obj_val.cpu().numpy()
+        return res
+
+    def decode(self, trunk_outputs, non_viable_idxs):
+        leaf_key = f"{self.domain_name}_0"
+        with torch.no_grad():
+            tree_output = self.model(root_inputs=None, trunk_outputs=trunk_outputs, leaf_keys=[leaf_key])
+            logits = tree_output.leaf_outputs[leaf_key].logits
+
+        # adjust logits to prevent sampling utility tokens
+        adj_logits = logits.scatter(
+            dim=-1,
+            index=non_viable_idxs.expand(*logits.shape[:-1], -1).to(logits.device),
+            value=float("-inf"),
+        )
+
+        # sample new tokens at masked locations
+        sample_dist = torch.distributions.Categorical(logits=adj_logits)
+        return sample_dist.sample()
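The `decode` method relies on a standard constrained-sampling trick: scatter negative infinity into the logits of disallowed token indices so the categorical distribution assigns them zero probability. In isolation, with hypothetical shapes and a hypothetical vocabulary:

```python
import torch

logits = torch.randn(2, 10, 24)               # (batch, seq_len, vocab)
non_viable_idxs = torch.tensor([0, 1, 2, 3])  # e.g. pad/mask/utility tokens

adj_logits = logits.scatter(
    dim=-1,
    index=non_viable_idxs.expand(*logits.shape[:-1], -1),
    value=float("-inf"),
)
samples = torch.distributions.Categorical(logits=adj_logits).sample()
assert (samples >= 4).all()  # masked indices are never sampled
```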
From f3b50a55fe1caa89efc327ba1c0ef6942212b662 Mon Sep 17 00:00:00 2001
From: Samuel Stanton
Date: Tue, 20 Feb 2024 09:46:45 -0500
Subject: [PATCH 9/9] update README

---
 README.md                | 53 +++++++++++++++++++++-------------------
 cortex/utils/__init__.py |  0
 2 files changed, 28 insertions(+), 25 deletions(-)
 delete mode 100644 cortex/utils/__init__.py

diff --git a/README.md b/README.md
index cccc317..6ad19b0 100644
--- a/README.md
+++ b/README.md
@@ -26,60 +26,63 @@ Deep learning is easy to learn and difficult to master. Seemingly insignificant
 
 ## Installation
 
-1. Create a new conda environment.
-
 ```bash
 conda create --name cortex-env python=3.10 -y && conda activate cortex-env
+python -m pip install -r requirements.in
+pip install -e .
 ```
 
-2. (optional) If desired install dependencies from frozen requirements files.
-
-   `pip install -r requirements.txt -r requirements-dev.txt`
-
-   These files fix the exact version of all dependencies and therefore should create a known good environment.
-   However, this is likely more stringent than strictly necessary and can make it difficult to work in environments with multiple projects installed.
-   If you skip this step, all dependencies will be fetched during package installation based on `requirements.in` which attempts to be as loose as possible in specifying compatible package versions.
-
-   To update the frozen dependencies run
-   `pip-compile --resolver=backtracking requirements.in`.
+If you run into package version issues, we provide pinned versions of all dependencies in `requirements.txt`.
+To update the frozen dependencies, run:
 
-3. Install cortex.
+```bash
+pip-compile --resolver=backtracking requirements.in
+```
 
-   `pip install -e .[dev]`
 
 ## Running
 
 Use `cortex_train_model --config-name <config_name>` to train, e.g.:
 
 ```
-cortex_train_model --config-name train_ab_seqcnn wandb_mode=offline fit=smoke_test
+cortex_train_model --config-name train_protein_model wandb_mode=offline
 ```
 
-Supported configs are
-
-- `train_ab_seqcnn` to train a SeqCNN from scratch.
-
-## How to launch a WANDB sweep on a cluster
+## How to launch a WANDB sweep
 
 1. Configure the sweep `.yaml`, e.g. `./wandb_config/ab_model_sweep.yaml`
 2. Run `wandb sweep wandb_config/ab_model_sweep.yaml`
-3. Copy the sweep id to `scripts/wandb_agent_array.bsub`
-4. Run `bsub < scripts/wandb_agent_array.bsub`
+3. Launch the wandb agents using a scheduler of your choice, e.g. SLURM or LSF
+
 
 ## Contributing
 
-Contributions are welcome, especially tutorials and documentation.
+Contributions are welcome!
 
 ### Install dev requirements and pre-commit hooks
+
+```bash
+python -m pip install -r requirements-dev.in
+pre-commit install
+```
+
 ### Testing
 
-`pytest -v --cov-report term-missing --cov=./cortex ./tests`
+```bash
+pytest -v --cov-report term-missing --cov=./cortex ./tests
+```
+
+### Build and browse docs locally
+
+```bash
+make -C docs html
+cd docs/build/html
+python -m http.server
+```
+
+Then open `http://localhost:8000` in your browser.
+
 ### Maintainers
diff --git a/cortex/utils/__init__.py b/cortex/utils/__init__.py
deleted file mode 100644
index e69de29..0000000