
Commit

add model training tutorial
samuelstanton committed Feb 17, 2024
1 parent 1f8df67 commit acfe98b
Showing 12 changed files with 213 additions and 61 deletions.
16 changes: 2 additions & 14 deletions cortex/cmdline/train_cortex_model.py
@@ -8,7 +8,6 @@
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.trainer.supporters import CombinedLoader

from cortex.logging import wandb_setup

@@ -73,21 +72,10 @@ def execute(cfg):
    model = hydra.utils.instantiate(cfg.tree)
    model.build_tree(cfg, skip_task_setup=False)

    # set up dataloaders
    leaf_train_loaders = {}
    task_test_loaders = {}
    for l_key in model.leaf_nodes:
        task_key, _ = l_key.rsplit("_", 1)
        leaf_train_loaders[l_key] = model.task_dict[task_key].data_module.train_dataloader()
        if task_key not in task_test_loaders:
            task_test_loaders[task_key] = model.task_dict[task_key].data_module.test_dataloader()

    trainer.fit(
        model,
        train_dataloaders=CombinedLoader(leaf_train_loaders, mode="min_size"),
        val_dataloaders=CombinedLoader(
            task_test_loaders, mode="max_size_cycle"
        ),  # change to max_size when lightning upgraded to >1.9.5
        train_dataloaders=model.get_dataloader(split="train"),
        val_dataloaders=model.get_dataloader(split="val"),
    )

    # save model
1 change: 0 additions & 1 deletion cortex/config/hydra/train_protein_model.yaml
@@ -43,7 +43,6 @@ tree:
    num_training_steps: ${trainer.max_epochs}

tasks:

  protein_property:
    log_fluorescence:
      # ensemble_size: ${ensemble_size}
17 changes: 15 additions & 2 deletions cortex/logging/_wandb_setup.py
@@ -7,19 +7,32 @@


def wandb_setup(cfg: DictConfig):
    if not hasattr(cfg, "wandb_host"):
        cfg["wandb_host"] = "https://api.wandb.ai"

    if not hasattr(cfg, "wandb_mode"):
        cfg["wandb_mode"] = "online"

    if not hasattr(cfg, "project_name"):
        cfg["project_name"] = "cortex"

    if not hasattr(cfg, "exp_name"):
        cfg["exp_name"] = "default_group"

    wandb.login(host=cfg.wandb_host)

    wandb.init(
        project=cfg.project_name,
        mode=cfg.wandb_mode,
        group=cfg.exp_name,
    )
    cfg["job_name"] = wandb.run.name
    cfg["__version__"] = cortex.__version__
    log_cfg = flatten_config(OmegaConf.to_container(cfg, resolve=True), sep="/")
    log_cfg = flatten_config(OmegaConf.to_container(cfg, resolve=True))
    wandb.config.update(log_cfg)


def flatten_config(d: DictConfig, parent_key="", sep="_"):
def flatten_config(d: DictConfig, parent_key="", sep="/"):
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
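For context, the new default separator means nested config keys are logged to W&B as path-style names. Below is a minimal standalone sketch of that flattening behavior (hypothetical helper name; the recursion into nested mappings is assumed, since the hunk is truncated here):

def flatten_config_sketch(d, parent_key="", sep="/"):
    # build "parent/child" keys from a nested dict, mirroring the sep="/" default above
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, dict):
            # assumed: recurse into nested mappings
            items.extend(flatten_config_sketch(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

print(flatten_config_sketch({"tree": {"optimizer": {"lr": 5e-3}}}))
# {'tree/optimizer/lr': 0.005}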
18 changes: 18 additions & 0 deletions cortex/model/tree/_seq_model_tree.py
@@ -9,6 +9,7 @@
import torch
from botorch.models.transforms.outcome import OutcomeTransform
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.trainer.supporters import CombinedLoader
from torch import nn

from cortex.model import online_weight_update_
@@ -70,6 +71,23 @@ def eval(self, *args, **kwargs):
        self.load_state_dict(self._eval_state_dict)
        return super().eval(*args, **kwargs)

    def get_dataloader(self, split="train"):
        loaders = {}
        for l_key in self.leaf_nodes:
            task_key, _ = l_key.rsplit("_", 1)
            if split == "train":
                loaders[l_key] = self.task_dict[task_key].data_module.train_dataloader()
            elif split == "val" and task_key not in loaders:
                loaders[task_key] = self.task_dict[task_key].data_module.test_dataloader()
            elif split == "val":
                pass
            else:
                raise ValueError(f"Invalid split {split}")

        # change val to max_size when lightning upgraded to >1.9.5
        mode = "min_size" if split == "train" else "max_size_cycle"
        return CombinedLoader(loaders, mode=mode)

    def training_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int] = None):
        leaf_keys = list(batch.keys())
        rng = np.random.default_rng()
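The new get_dataloader method keeps the CombinedLoader wiring inside the tree, so training code reduces to a plain Lightning fit call. A minimal usage sketch, assuming `model` and `trainer` were instantiated from the tutorial config as in the notebook below:

# `model` is a SequenceModelTree built with model.build_tree(cfg, skip_task_setup=False)
# and `trainer` is a lightning.Trainer, both instantiated from the Hydra config.
trainer.fit(
    model,
    train_dataloaders=model.get_dataloader(split="train"),  # per-leaf train loaders, mode="min_size"
    val_dataloaders=model.get_dataloader(split="val"),      # per-task test loaders, mode="max_size_cycle"
)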
95 changes: 95 additions & 0 deletions tutorials/3_training_a_neural_tree.ipynb
@@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training a Neural Tree\n",
"\n",
"So far we've learned the basic structure of a `NeuralTree` and seen how task objects are used to interface with datasets.\n",
"Now we'll see how a `NeuralTree` is trained.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the configuration file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from omegaconf import OmegaConf\n",
"import hydra\n",
"\n",
"with hydra.initialize(config_path=\"./hydra\"):\n",
" cfg = hydra.compose(config_name=\"3_training_a_neural_tree\")\n",
" OmegaConf.set_struct(cfg, False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import lightning as L\n",
"\n",
"# set random seed\n",
"L.seed_everything(seed=cfg.random_seed, workers=True)\n",
"\n",
"# instantiate model\n",
"model = hydra.utils.instantiate(cfg.tree)\n",
"model.build_tree(cfg, skip_task_setup=False)\n",
"\n",
"# instantiate trainer, set logger\n",
"trainer = hydra.utils.instantiate(cfg.trainer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(\n",
" model,\n",
" train_dataloaders=model.get_dataloader(split=\"train\"),\n",
" val_dataloaders=model.get_dataloader(split=\"val\"),\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cortex-public",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
55 changes: 11 additions & 44 deletions tutorials/hydra/2_defining_a_task.yaml
@@ -1,46 +1,13 @@
defaults:
  - tree: sequence_model
  - roots: [protein_seq]
  - trunk: sum_trunk
  - branches: [protein_property]
  - tasks:
      - log_fluorescence

feature_dim: 32
kernel_size: 3
data_dir: ./cache

tree:
  _target_: cortex.model.tree.SequenceModelTree
  roots:
    protein_seq:
      _target_: cortex.model.root.Conv1dRoot
      tokenizer_transform:
        _target_: cortex.transforms.HuggingFaceTokenizerTransform
        tokenizer:
          _target_: cortex.tokenization.ProteinSequenceTokenizerFast
      max_len: 256
      embed_dim: ${feature_dim}
      channel_dim: ${feature_dim}
      out_dim: ${feature_dim}
      num_blocks: 2
      kernel_size: ${kernel_size}
  trunk:
    _target_: cortex.model.trunk.SumTrunk
  branches:
    protein_property:
      _target_: cortex.model.branch.Conv1dBranch
      out_dim: 8
      channel_dim: ${feature_dim}
      num_blocks: 0
      kernel_size: ${kernel_size}
tasks:
  protein_property:
    log_fluorescence:
      _target_: cortex.task.RegressionTask
      input_map:
        protein_seq: ['tokenized_seq']
      outcome_cols: ['log_fluorescence']
      root_key: protein_seq
      ensemble_size: 1
      data_module:
        _target_: cortex.data.data_module.TaskDataModule
        _recursive_: false
        batch_size: 2
        dataset_config:
          _target_: cortex.data.dataset.TAPEFluorescenceDataset
          root: ${data_dir}
          download: true
          train: ???
batch_size: 2
max_epochs: 0
data_dir: ./.cache
20 changes: 20 additions & 0 deletions tutorials/hydra/3_training_a_neural_tree.yaml
@@ -0,0 +1,20 @@
defaults:
  - tree: sequence_model
  - roots: [protein_seq]
  - trunk: sum_trunk
  - branches: [protein_property]
  - tasks:
      - log_fluorescence

feature_dim: 32
kernel_size: 3
batch_size: 32
max_epochs: 2
data_dir: ./.cache
wandb_mode: offline
random_seed: 42

trainer:
  _target_: lightning.Trainer
  max_epochs: ${max_epochs}
  num_sanity_val_steps: 1
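This top-level tutorial config pulls in the tree, root, branch, and task groups through the defaults list above. In the tutorial notebook it is composed with Hydra's compose API; a minimal sketch of that composition (paths relative to the tutorials directory, as in the notebook):

import hydra
from omegaconf import OmegaConf

# compose tutorials/hydra/3_training_a_neural_tree.yaml, as in the tutorial notebook
with hydra.initialize(config_path="./hydra"):
    cfg = hydra.compose(config_name="3_training_a_neural_tree")
    OmegaConf.set_struct(cfg, False)  # allow new keys to be added at runtime

print(cfg.trainer.max_epochs)  # 2, resolved from ${max_epochs}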
6 changes: 6 additions & 0 deletions tutorials/hydra/branches/protein_property.yaml
@@ -0,0 +1,6 @@
protein_property:
  _target_: cortex.model.branch.Conv1dBranch
  out_dim: 8
  channel_dim: ${feature_dim}
  num_blocks: 0
  kernel_size: ${kernel_size}
12 changes: 12 additions & 0 deletions tutorials/hydra/roots/protein_seq.yaml
@@ -0,0 +1,12 @@
protein_seq:
  _target_: cortex.model.root.Conv1dRoot
  tokenizer_transform:
    _target_: cortex.transforms.HuggingFaceTokenizerTransform
    tokenizer:
      _target_: cortex.tokenization.ProteinSequenceTokenizerFast
  max_len: 256
  embed_dim: ${feature_dim}
  channel_dim: ${feature_dim}
  out_dim: ${feature_dim}
  num_blocks: 2
  kernel_size: ${kernel_size}
17 changes: 17 additions & 0 deletions tutorials/hydra/tasks/log_fluorescence.yaml
@@ -0,0 +1,17 @@
protein_property:
  log_fluorescence:
    _target_: cortex.task.RegressionTask
    input_map:
      protein_seq: ['tokenized_seq']
    outcome_cols: ['log_fluorescence']
    root_key: protein_seq
    ensemble_size: 1
    data_module:
      _target_: cortex.data.data_module.TaskDataModule
      _recursive_: false
      batch_size: ${batch_size}
      dataset_config:
        _target_: cortex.data.dataset.TAPEFluorescenceDataset
        root: ${data_dir}
        download: true
        train: ???
16 changes: 16 additions & 0 deletions tutorials/hydra/tree/sequence_model.yaml
@@ -0,0 +1,16 @@
_target_: cortex.model.tree.SequenceModelTree
_recursive_: false
fit_cfg:
  reinitialize_roots: false
  linear_probing: false
  weight_averaging: null
optimizer:
  _target_: torch.optim.Adam
  lr: 5e-3
  weight_decay: 0.
  betas: [0.99, 0.999]
  fused: false
lr_scheduler:
  _target_: transformers.get_cosine_schedule_with_warmup
  num_warmup_steps: 1
  num_training_steps: ${max_epochs}
1 change: 1 addition & 0 deletions tutorials/hydra/trunk/sum_trunk.yaml
@@ -0,0 +1 @@
_target_: cortex.model.trunk.SumTrunk
