add more unit tests
samuelstanton committed Feb 12, 2024
1 parent c2ee29f commit d815089
Showing 58 changed files with 780 additions and 309 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -43,3 +43,4 @@ repos:
hooks:
- id: ruff
name: ruff
args: [--config=pyproject.toml]
Empty file.
@@ -0,0 +1,7 @@
{
"cls_token": "<cls>",
"eos_token": "<eos>",
"mask_token": "<mask>",
"pad_token": "<pad>",
"unk_token": "<unk>"
}
6 changes: 6 additions & 0 deletions cortex/assets/protein_seq_tokenizer_32/tokenizer_config.json
@@ -0,0 +1,6 @@
{
"clean_up_tokenization_spaces": true,
"do_lower_case": false,
"model_max_length": 1024,
"tokenizer_class": "PmlmTokenizer"
}
32 changes: 32 additions & 0 deletions cortex/assets/protein_seq_tokenizer_32/vocab.txt
@@ -0,0 +1,32 @@
<cls>
<pad>
<eos>
<unk>
L
A
G
V
S
E
R
T
I
D
P
K
Q
N
F
Y
M
H
W
C
B
U
Z
O
.
-
<null_1>
<mask>
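
The vocab is one token per line: four special tokens first, then the 20 canonical amino acids, the ambiguity/rare codes (B, U, Z, O), separator and gap characters, <null_1>, and <mask>. A minimal loading sketch, assuming ids follow zero-based line order (the usual plain-text vocab convention; the actual PmlmTokenizer loader is not part of this diff):

from pathlib import Path

def load_vocab(path: str) -> dict[str, int]:
    # one token per line; id = zero-based line number (assumed convention)
    tokens = Path(path).read_text().splitlines()
    return {tok: idx for idx, tok in enumerate(tokens)}

vocab = load_vocab("cortex/assets/protein_seq_tokenizer_32/vocab.txt")
assert len(vocab) == 32
assert vocab["<cls>"] == 0 and vocab["<mask>"] == 31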
8 changes: 4 additions & 4 deletions cortex/cmdline/train_cortex_model.py
@@ -18,9 +18,7 @@ def main(cfg):
"""
general setup
"""
random.seed(
None
) # make sure random seed resets between Hydra multirun jobs for random job-name generation
random.seed(None) # make sure random seed resets between Hydra multirun jobs for random job-name generation

try:
with warnings.catch_warnings():
@@ -87,7 +85,9 @@ def execute(cfg):
trainer.fit(
model,
train_dataloaders=CombinedLoader(leaf_train_loaders, mode="min_size"),
val_dataloaders=CombinedLoader(task_test_loaders, mode="max_size_cycle"), # change to max_size when lightning upgraded to >1.9.5
val_dataloaders=CombinedLoader(
task_test_loaders, mode="max_size_cycle"
), # change to max_size when lightning upgraded to >1.9.5
)

# save model
25 changes: 25 additions & 0 deletions cortex/constants/__init__.py
@@ -0,0 +1,25 @@
ALIGNMENT_GAP_TOKEN = "-"
CANON_AMINO_ACIDS = [
"A",
"R",
"N",
"D",
"C",
"E",
"Q",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
COMPLEX_SEP_TOKEN = "."
NULL_TOKENS = ["<null_1>"]
8 changes: 2 additions & 6 deletions cortex/corruption/_abstract_corruption.py
@@ -15,9 +15,7 @@ class CorruptionProcess(ABC):
the corruption interface.
"""

def __init__(
self, schedule: str = "cosine", max_steps: int = 1000, *args, **kwargs
):
def __init__(self, schedule: str = "cosine", max_steps: int = 1000, *args, **kwargs):
betas = get_named_beta_schedule(schedule, max_steps)

# Use float64 for accuracy.
@@ -68,9 +66,7 @@ def __call__(
is_corrupted = torch.full_like(x_start, False, dtype=torch.bool)
return x_start, is_corrupted

x_corrupt, is_corrupted = self._corrupt(
x_start, corrupt_frac=corrupt_frac, *args, **kwargs
)
x_corrupt, is_corrupted = self._corrupt(x_start, corrupt_frac=corrupt_frac, *args, **kwargs)
# only change values where corruption_allowed is True
if corruption_allowed is not None:
corruption_allowed = corruption_allowed.to(x_start.device)
6 changes: 2 additions & 4 deletions cortex/corruption/_diffusion_noise_schedule.py
@@ -6,9 +6,9 @@
"""

import math
from typing import Callable

import numpy as np
from typing import Callable


def get_named_beta_schedule(schedule_name: str, num_diffusion_timesteps: int) -> np.ndarray:
@@ -53,9 +53,7 @@ def get_named_beta_schedule(schedule_name: str, num_diffusion_timesteps: int) ->
beta_mid = scale * 0.0001 # scale * 0.02
beta_end = scale * 0.02
first_part = np.linspace(beta_start, beta_mid, 10, dtype=np.float64)
second_part = np.linspace(
beta_mid, beta_end, num_diffusion_timesteps - 10, dtype=np.float64
)
second_part = np.linspace(beta_mid, beta_end, num_diffusion_timesteps - 10, dtype=np.float64)
return np.concatenate([first_part, second_part])
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
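
A brief usage sketch of the schedule helper (the cosine branch lies outside this hunk, but the abstract corruption class above defaults to schedule="cosine", so the name is assumed valid):

import numpy as np
from cortex.corruption._diffusion_noise_schedule import get_named_beta_schedule

betas = get_named_beta_schedule("cosine", num_diffusion_timesteps=1000)
assert betas.shape == (1000,)
# cumulative products of (1 - beta) give the alpha-bar decay terms
alpha_bars = np.cumprod(1.0 - betas)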
8 changes: 2 additions & 6 deletions cortex/corruption/_gaussian_corruption.py
@@ -16,12 +16,8 @@ def __init__(self, noise_variance: float = 10.0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.noise_variance = noise_variance

def _corrupt(
self, x_start: torch.Tensor, corrupt_frac: float, *args, **kwargs
) -> tuple[torch.Tensor]:
def _corrupt(self, x_start: torch.Tensor, corrupt_frac: float, *args, **kwargs) -> tuple[torch.Tensor]:
noise_scale = corrupt_frac * math.sqrt(self.noise_variance)
x_corrupt = (1.0 - corrupt_frac) * x_start + noise_scale * torch.randn_like(
x_start
)
x_corrupt = (1.0 - corrupt_frac) * x_start + noise_scale * torch.randn_like(x_start)
is_corrupted = torch.ones_like(x_start, dtype=torch.bool)
return x_corrupt, is_corrupted
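
In effect, with corrupt_frac g and noise variance s^2, the corrupted sample is (1 - g) * x_start + g * s * eps with eps ~ N(0, I), and every element is flagged as corrupted. A hedged usage sketch (class name, export path, and keyword-only call assumed from the abstract interface above):

import torch
from cortex.corruption import GaussianCorruptionProcess  # export path assumed

corruption = GaussianCorruptionProcess(noise_variance=10.0)
x_start = torch.randn(8, 16)
# corrupt_frac passed by keyword; timestep assumed to default to None
x_corrupt, is_corrupted = corruption(x_start, corrupt_frac=0.5)
assert is_corrupted.all()  # Gaussian corruption flags every element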
10 changes: 1 addition & 9 deletions cortex/corruption/_mask_corruption.py
@@ -21,15 +21,7 @@ def __call__(
*args,
**kwargs
):
return super().__call__(
x_start,
timestep,
corrupt_frac,
corruption_allowed,
mask_val=mask_val,
*args,
**kwargs
)
return super().__call__(x_start, timestep, corrupt_frac, corruption_allowed, mask_val=mask_val, *args, **kwargs)

def _corrupt(
self, x_start: torch.Tensor, corrupt_frac: float, mask_val: int, *args, **kwargs
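
A companion sketch for the mask variant on integer token ids (class name and export path assumed; mask_val=31 matches the <mask> position in the 32-token vocab above):

import torch
from cortex.corruption import MaskCorruptionProcess  # export path assumed

corruption = MaskCorruptionProcess()
token_ids = torch.randint(4, 24, (8, 128))  # ids within the residue range
# timestep/corruption_allowed assumed optional, per the abstract interface
x_corrupt, is_corrupted = corruption(token_ids, corrupt_frac=0.15, mask_val=31)
assert (x_corrupt[is_corrupted] == 31).all()  # corrupted positions -> mask id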
2 changes: 1 addition & 1 deletion cortex/data/data_module/__init__.py
@@ -1 +1 @@
from ._task_data_module import TaskDataModule
from ._task_data_module import TaskDataModule
14 changes: 3 additions & 11 deletions cortex/data/data_module/_task_data_module.py
@@ -71,9 +71,7 @@ def setup(self, stage=None):
self._dataset_config.base_dataset.train = True
train_val = hydra.utils.instantiate(self._dataset_config)
if self._train_on_everything:
train_val.df = pd.concat(
[train_val.df, self.datasets["test"].df], ignore_index=True
)
train_val.df = pd.concat([train_val.df, self.datasets["test"].df], ignore_index=True)
# columns = train_val.columns
train_dataset, val_dataset = random_split(
train_val,
@@ -131,9 +129,7 @@ def get_dataloader(self, split: str = "train"):
# Full batch for evaluation on the test set
if split == "test":
self._dataloader_kwargs["batch_size"] = len(self.datasets[split])
dataloader = DataLoader(
self.datasets[split], shuffle=True, drop_last=True, **self._dataloader_kwargs
)
dataloader = DataLoader(self.datasets[split], shuffle=True, drop_last=True, **self._dataloader_kwargs)
if split == "test":
self._dataloader_kwargs["batch_size"] = self._batch_size
return dataloader
@@ -142,11 +138,7 @@ def _partition_train_indices(self):
if self._balance_train_partition is None:
return [list(range(len(self.datasets["train"])))]

train_df = (
self.datasets["train_val"]
._data.iloc[self.datasets["train"].indices]
.reset_index(drop=True)
)
train_df = self.datasets["train_val"]._data.iloc[self.datasets["train"].indices].reset_index(drop=True)
if isinstance(self._balance_train_partition, str):
partition = [self._balance_train_partition]
else:
14 changes: 6 additions & 8 deletions cortex/data/dataset/_dataframe_dataset.py
@@ -2,11 +2,11 @@
Simple Pytorch Dataset for reading from a dataframe.
"""

from collections import OrderedDict
from typing import Any

import pandas as pd
from torch.utils.data import Dataset
from collections import OrderedDict
from typing import Any


class DataFrameDataset(Dataset):
@@ -23,14 +23,14 @@ def __init__(self, data: pd.DataFrame, columns: list[str], dedup: bool = True):

def __len__(self):
return len(self.df)

def _fetch_item(self, index) -> pd.DataFrame:
# check if int or slice
if isinstance(index, int):
item = self._data.iloc[index : index + 1]
else:
item = self._data.iloc[index]
return item

def _format_item(self, item: pd.DataFrame) -> OrderedDict[str, Any]:
if len(item) == 1:
return OrderedDict([(c, item[c].iloc[0]) for c in self.columns])
@@ -39,14 +39,12 @@ def _format_item(self, item: pd.DataFrame) -> OrderedDict[str, Any]:
def __getitem__(self, index) -> OrderedDict[str, Any]:
item = self._fetch_item(index)
return self._format_item(item)


def ordered_dict_collator(
batch: list[OrderedDict[str, Any]],
) -> OrderedDict[str, Any]:
"""
Collates a batch of OrderedDicts into a single OrderedDict.
"""
return OrderedDict(
[(key, [item[key] for item in batch]) for key in batch[0].keys()]
)
return OrderedDict([(key, [item[key] for item in batch]) for key in batch[0].keys()])
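
A usage sketch of the collator (import path assumed from this file's location): it turns a list of per-example OrderedDicts into one OrderedDict of column lists, so it can serve as a DataLoader collate_fn.

from collections import OrderedDict
from cortex.data.dataset._dataframe_dataset import ordered_dict_collator

batch = [OrderedDict(seq="ARNDC", label=0.1), OrderedDict(seq="EQGHI", label=0.7)]
print(ordered_dict_collator(batch))
# OrderedDict([('seq', ['ARNDC', 'EQGHI']), ('label', [0.1, 0.7])])
# e.g. DataLoader(dataset, batch_size=32, collate_fn=ordered_dict_collator)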
10 changes: 4 additions & 6 deletions cortex/data/dataset/_transformed_dataset.py
@@ -1,14 +1,14 @@
from typing import Optional
from collections import OrderedDict
from typing import Any, Optional

import hydra
import pandas as pd
from omegaconf import DictConfig
from torch.nn import Sequential
from torch.utils.data import ConcatDataset, Dataset

from cortex.data.dataset._dataframe_dataset import DataFrameDataset

from collections import OrderedDict
from typing import Any

class TransformedDataset(DataFrameDataset):
def __init__(
@@ -21,9 +21,7 @@ def __init__(
if isinstance(base_dataset, DictConfig):
base_dataset = hydra.utils.instantiate(base_dataset)
if isinstance(base_dataset, ConcatDataset):
data = pd.concat(
[dataset._data for dataset in base_dataset.datasets], ignore_index=True
)
data = pd.concat([dataset._data for dataset in base_dataset.datasets], ignore_index=True)
else:
data = base_dataset._data.reset_index(drop=True)

2 changes: 1 addition & 1 deletion cortex/logging/__init__.py
@@ -1 +1 @@
from ._wandb_setup import wandb_setup
from ._wandb_setup import wandb_setup
4 changes: 1 addition & 3 deletions cortex/metrics/_spearman_rho.py
@@ -11,7 +11,5 @@ def spearman_rho(scores: np.ndarray, targets: np.ndarray):
return stats.spearmanr(targets, scores).correlation
spearman_rho = 0.0
for idx in range(targets.shape[-1]):
spearman_rho += (
stats.spearmanr(targets[..., idx], scores[..., idx]).correlation / targets.shape[-1]
)
spearman_rho += stats.spearmanr(targets[..., idx], scores[..., idx]).correlation / targets.shape[-1]
return spearman_rho
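
A usage sketch of both branches (import path assumed; the 1-D early return is truncated above but implied by the per-column loop):

import numpy as np
from cortex.metrics._spearman_rho import spearman_rho  # import path assumed

rng = np.random.default_rng(0)
targets = rng.normal(size=100)
scores = targets + 0.1 * rng.normal(size=100)
print(spearman_rho(scores, targets))  # single rho, close to 1.0

targets_2d = rng.normal(size=(100, 3))
scores_2d = targets_2d + 0.1 * rng.normal(size=(100, 3))
print(spearman_rho(scores_2d, targets_2d))  # mean of the 3 per-column rhos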
5 changes: 5 additions & 0 deletions cortex/model/__init__.py
@@ -0,0 +1,5 @@
from ._weight_averaging import online_weight_update_

__all__ = [
"online_weight_update_",
]
28 changes: 28 additions & 0 deletions cortex/model/_weight_averaging.py
@@ -0,0 +1,28 @@
from typing import Iterable, Optional

import torch


def online_weight_update_(
src_state_dict: dict[str, torch.Tensor],
tgt_state_dict: dict[str, torch.Tensor],
decay_rate: float,
param_prefixes: Optional[Iterable[str]] = None,
):
if param_prefixes is None:
param_keys = src_state_dict.keys()
else:
param_keys = [k for k in src_state_dict.keys() if any(k.startswith(prefix) for prefix in param_prefixes)]

for param_key in param_keys:
param_src = src_state_dict[param_key]
param_tgt = tgt_state_dict[param_key]
if torch.is_tensor(param_tgt) and param_tgt.dtype is not torch.bool and param_tgt.dtype is not torch.long:
param_tgt.mul_(decay_rate)
param_tgt.data.add_(param_src.data * (1.0 - decay_rate))
elif torch.is_tensor(param_tgt):
param_tgt.copy_(param_src)
else:
raise RuntimeError("Parameter {} is not a tensor.".format(param_key))

return None
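
This is an exponential-moving-average (EMA) update: floating-point entries follow tgt <- decay * tgt + (1 - decay) * src, while bool and long buffers are copied verbatim. A minimal maintenance sketch, relying on the fact that state_dict() tensors share storage with the module's parameters:

import torch
from torch import nn
from cortex.model import online_weight_update_

model = nn.Linear(4, 2)
ema_model = nn.Linear(4, 2)
ema_model.load_state_dict(model.state_dict())

# after each optimizer step, pull the EMA copy toward the live weights
online_weight_update_(model.state_dict(), ema_model.state_dict(), decay_rate=0.999)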
4 changes: 4 additions & 0 deletions cortex/model/block/__init__.py
@@ -1 +1,5 @@
from ._conv1d_resid_block import Conv1dResidBlock

__all__ = [
"Conv1dResidBlock",
]
6 changes: 2 additions & 4 deletions cortex/model/block/_conv1d_resid_block.py
@@ -1,4 +1,4 @@
from torch import nn, Tensor
from torch import Tensor, nn

from cortex.model.elemental import MaskLayerNorm1d, swish

@@ -56,9 +56,7 @@ def __init__(
self.act_fn = nn.ReLU(inplace=True)

if not in_channels == out_channels:
self.proj = nn.Conv1d(
in_channels, out_channels, kernel_size=1, padding="same", stride=1
)
self.proj = nn.Conv1d(in_channels, out_channels, kernel_size=1, padding="same", stride=1)
else:
self.proj = None

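
The 1x1 projection above exists so the residual skip connection can still be added when in_channels != out_channels. A generic sketch of the pattern in plain PyTorch (not the actual Conv1dResidBlock API, whose full signature lies outside this hunk):

import torch
from torch import nn

class TinyResidBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=5, padding="same")
        # 1x1 conv aligns channel counts for the skip connection
        self.proj = None if in_channels == out_channels else nn.Conv1d(in_channels, out_channels, kernel_size=1, padding="same")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        resid = x if self.proj is None else self.proj(x)
        return torch.relu(self.conv(x)) + resid

block = TinyResidBlock(32, 64)
out = block(torch.randn(8, 32, 100))  # (batch, channels, length)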
9 changes: 8 additions & 1 deletion cortex/model/branch/__init__.py
@@ -1,2 +1,9 @@
from ._abstract_branch import BranchNode, BranchNodeOutput
from ._conv1d_branch import Conv1dBranch
from ._conv1d_branch import Conv1dBranch, Conv1dBranchOutput

__all__ = [
"BranchNode",
"BranchNodeOutput",
"Conv1dBranch",
"Conv1dBranchOutput",
]