start testing on rfp dataset
samuelstanton committed Feb 13, 2024
1 parent e2448e9 commit 48a9f66
Showing 32 changed files with 986 additions and 110 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
 .vscode
+.DS_Store
6 changes: 6 additions & 0 deletions cortex/__init__.py
@@ -0,0 +1,6 @@
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("pytorch-cortex")
except PackageNotFoundError:
    __version__ = "unknown version"
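
A minimal usage sketch (not part of the diff), assuming the package is importable: the fallback means a source checkout whose "pytorch-cortex" distribution metadata is not installed simply reports an unknown version.

import cortex

print(cortex.__version__)  # "unknown version" if the pytorch-cortex metadata is missing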
@@ -1,4 +1,4 @@
-binding:
+protein_property:
   _target_: cortex.model.branch.Conv1dBranch
   out_dim: 8
   embed_dim: ${channel_dim}
3 changes: 3 additions & 0 deletions cortex/config/hydra/general_settings/default.yaml
@@ -0,0 +1,3 @@
# @package _global_
dtype: float
seed: 0 # random seed, set to null to use random seed
17 changes: 17 additions & 0 deletions cortex/config/hydra/logging/default.yaml
@@ -0,0 +1,17 @@
# @package _global_
# Directories for loading and storing data
ckpt_name: ${job_name}
ckpt_file: ${ckpt_name}.pt
ckpt_cfg: ${ckpt_name}.yaml
save_ckpt: true

data_dir: /home/stantos5/scratch/code/remote/prescient-github/cortex/temp
project_name: cortex
__version__: null
exp_name: dry_run
job_name: null
timestamp: ${now:%Y-%m-%d_%H-%M-%S}
log_dir: ${data_dir}/${exp_name}/${job_name}/${timestamp} # use this directory for local output
wandb_mode: online
wandb_host: https://api.wandb.ai
warnings_filter: ignore
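
For illustration (not from the commit), a hedged sketch of how the ${...} interpolations in this file resolve; it assumes OmegaConf is installed, registers a stand-in for the ${now:...} resolver that Hydra normally provides at runtime, and uses placeholder values rather than the paths above.

from datetime import datetime

from omegaconf import OmegaConf

# stand-in for Hydra's built-in ${now:...} resolver
OmegaConf.register_new_resolver("now", lambda fmt: datetime.now().strftime(fmt))

cfg = OmegaConf.create(
    {
        "data_dir": "/tmp/cortex",  # placeholder
        "exp_name": "dry_run",
        "job_name": "demo",
        "timestamp": "${now:%Y-%m-%d_%H-%M-%S}",
        "log_dir": "${data_dir}/${exp_name}/${job_name}/${timestamp}",
    }
)
print(cfg.log_dir)  # e.g. /tmp/cortex/dry_run/demo/2024-02-13_12-00-00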
7 changes: 7 additions & 0 deletions cortex/config/hydra/model_globals/default.yaml
@@ -0,0 +1,7 @@
# @package _global_
channel_dim: 128
embed_dim: 32
ensemble_size: 4
dropout_prob: 0.0
kernel_size: 5
pooling_type: mean
@@ -1,16 +1,16 @@
-full_seq:
+protein_seq:
   _target_: cortex.model.root.Conv1dRoot
   corruption_process:
     _target_: cortex.corruption.MaskCorruptionProcess
   tokenizer_transform:
-    _target_: cortex.transform.HuggingFaceTokenizerTransform
+    _target_: cortex.transforms.HuggingFaceTokenizerTransform
     tokenizer:
-      _target_: cortex.tokenization.CachedBertTokenizerFast
-  max_len: 512
+      _target_: cortex.tokenization.ProteinSequenceTokenizerFast
+  max_len: 256
   out_dim: ${embed_dim}
   embed_dim: ${embed_dim}
   channel_dim: ${channel_dim}
-  num_blocks: 4
+  num_blocks: 2
   kernel_size: ${kernel_size}
   dropout_prob: ${dropout_prob}
   layernorm: true
23 changes: 23 additions & 0 deletions cortex/config/hydra/tasks/protein_property/delta_g.yaml
@@ -0,0 +1,23 @@
delta_g:
  _target_: cortex.task.RegressionTask
  input_map:
    protein_seq: ['tokenized_seq']
  outcome_cols: ['foldx_total_energy']
  corrupt_train_inputs: true
  corrupt_inference_inputs: false
  root_key: protein_seq
  nominal_label_var: 0.01
  data_module:
    _target_: cortex.data.data_module.TaskDataModule
    _recursive_: false
    batch_size: ${fit.batch_size}
    balance_train_partition: null
    drop_last: true
    lengths: [1.0, 0.0]
    train_on_everything: false
    num_workers: ${num_workers}
    dataset_config:
      _target_: cortex.data.dataset.RedFluorescentProteinDataset
      root: ${dataset_root_dir}
      download: ${download_datasets}
      train: ???
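
A note with a hedged sketch (not from the commit): `???` is OmegaConf's marker for a mandatory value, so `train` must be filled in before the dataset config is instantiated, which is what the data module's setup() does in the `_task_data_module.py` diff further down. Roughly, with placeholder values standing in for the ${dataset_root_dir} and ${download_datasets} interpolations:

import hydra
from omegaconf import OmegaConf

dataset_config = OmegaConf.create(
    {
        "_target_": "cortex.data.dataset.RedFluorescentProteinDataset",
        "root": "/path/to/datasets",  # placeholder
        "download": True,
        "train": "???",  # mandatory: reading it unresolved raises MissingMandatoryValue
    }
)
dataset_config.train = True  # the data module sets this for the "fit" stage
dataset = hydra.utils.instantiate(dataset_config)  # builds the train split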
Empty file.
56 changes: 56 additions & 0 deletions cortex/config/hydra/train_protein_model.yaml
@@ -0,0 +1,56 @@
defaults:
  - general_settings: default
  - logging: default
  - model_globals: default
  - roots: [protein_seq]
  - trunk: default
  - branches: [protein_property]
  - tree: protein_model
  - tasks:
      - protein_property/delta_g
  - _self_

fit:
  batch_size: 32

trainer:
  _target_: lightning.Trainer
  accelerator: gpu
  max_epochs: 64
  devices: 1
  # devices: 8
  # strategy: ddp
  num_sanity_val_steps: 1


tree:
  _recursive_: false
  fit_cfg:
    reinitialize_roots: true
    linear_probing: false
    weight_averaging: null
    optimizer:
      _target_: torch.optim.Adam
      lr: 5e-3
      weight_decay: 0.
      betas: [0.99, 0.999]
      fused: false
    lr_scheduler:
      _target_: transformers.get_cosine_schedule_with_warmup
      num_warmup_steps: 10
      num_training_steps: ${trainer.max_epochs}

tasks:

  protein_property:
    delta_g:
      # ensemble_size: ${ensemble_size}
      ensemble_size: 1

train_on_everything: false
linear_probing: false
dataset_root_dir: /home/stantos5/scratch/datasets
download_datasets: true
num_workers: 2

ckpt_name: ${exp_name}_${job_name}
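
For context, a hedged sketch (not included in this commit) of how a training entry point might consume this config; the module and function names are assumptions, and only the config path and name come from the files shown here.

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="cortex/config/hydra", config_name="train_protein_model", version_base=None)
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))  # inspect the composed config
    trainer = hydra.utils.instantiate(cfg.trainer)  # lightning.Trainer(accelerator="gpu", max_epochs=64, ...)
    # building the model tree from cfg.tree / cfg.roots / cfg.branches / cfg.tasks and
    # calling trainer.fit(...) is cortex-specific and omitted from this sketch


if __name__ == "__main__":
    main()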
1 change: 1 addition & 0 deletions cortex/config/hydra/tree/protein_model.yaml
@@ -0,0 +1 @@
_target_: cortex.model.tree.SequenceModelTree
18 changes: 2 additions & 16 deletions cortex/data/data_module/_task_data_module.py
@@ -68,11 +68,10 @@ def __init__(

     def setup(self, stage=None):
         if stage == "fit":
-            self._dataset_config.base_dataset.train = True
+            self._dataset_config.train = True
             train_val = hydra.utils.instantiate(self._dataset_config)
             if self._train_on_everything:
                 train_val.df = pd.concat([train_val.df, self.datasets["test"].df], ignore_index=True)
-            # columns = train_val.columns
             train_dataset, val_dataset = random_split(
                 train_val,
                 lengths=self._lengths,
@@ -83,21 +82,14 @@ def setup(self, stage=None):
             self.datasets["train"] = train_dataset
             self.datasets["val"] = val_dataset
         if stage == "test":
-            self._dataset_config.base_dataset.train = False
+            self._dataset_config.train = False
             test_dataset = hydra.utils.instantiate(self._dataset_config)
-            # columns = test_dataset.columns
             self.datasets["test"] = test_dataset
         if stage == "predict":
-            self._dataset_config.base_dataset.train = False
             predict_dataset = hydra.utils.instantiate(self._dataset_config)
-            # columns = predict_dataset.columns
             self.datasets["predict"] = predict_dataset

-        # there's probably a cleaner way to do this, the problem is you don't know the
-        # columns until you instantiate the dataset
-        # self._collate_fn = self._collate_fn or ordered_dict_collator
-        # self._dataloader_kwargs["collate_fn"] = self._collate_fn

     def train_dataloader(self):
         return self.get_dataloader(split="train")

@@ -145,10 +137,4 @@ def _partition_train_indices(self):
         partition = list(self._balance_train_partition)

         index_list = list(train_df.groupby(partition).indices.values())
-        # partition_col = self._balance_train_partition
-        # index_list = []
-        # for partition_val in train_df[partition_col].unique():
-        #     partition_df = train_df[train_df[partition_col] == partition_val]
-        #     partition_indices = partition_df.index.values.tolist()
-        #     index_list.append(partition_indices)
         return index_list
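
To illustrate the effect of this change (a hedged sketch, not from the commit): the module now toggles `train` directly on a flat dataset config instead of a nested `base_dataset` block. The keyword arguments below mirror the delta_g task config; the exact TaskDataModule constructor signature is an assumption.

from omegaconf import OmegaConf

from cortex.data.data_module import TaskDataModule

dataset_config = OmegaConf.create(
    {
        "_target_": "cortex.data.dataset.RedFluorescentProteinDataset",
        "root": "/path/to/datasets",  # placeholder
        "download": True,
        "train": "???",
    }
)
dm = TaskDataModule(batch_size=32, dataset_config=dataset_config, lengths=[1.0, 0.0])
dm.setup(stage="fit")   # sets dataset_config.train = True, then splits into train/val
dm.setup(stage="test")  # sets dataset_config.train = False for the held-out split
train_loader = dm.train_dataloader()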
10 changes: 9 additions & 1 deletion cortex/data/dataset/__init__.py
@@ -1,2 +1,10 @@
-from ._dataframe_dataset import DataFrameDataset, ordered_dict_collator
+from ._data_frame_dataset import DataFrameDataset, ordered_dict_collator
+from ._rfp_dataset import RedFluorescentProteinDataset
 from ._transformed_dataset import TransformedDataset
+
+__all__ = [
+    "DataFrameDataset",
+    "TransformedDataset",
+    "RedFluorescentProteinDataset",
+    "ordered_dict_collator",
+]
110 changes: 110 additions & 0 deletions cortex/data/dataset/_data_frame_dataset.py
@@ -0,0 +1,110 @@
import os
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional, TypeVar, Union

import pandas as pd
from pandas import DataFrame
from torch.utils.data import Dataset

from cortex.io import download_and_extract_archive
from cortex.transforms import Transform

T = TypeVar("T")


class DataFrameDataset(Dataset):
    _data: DataFrame
    _name: str = "temp"
    _target: str = "data.csv"
    columns = None

    def __init__(
        self,
        root: Union[str, Path],
        *,
        download: bool = False,
        download_source: Optional[str] = None,
        dedup: bool = True,
        train: bool = True,
        random_seed: int = 0xDEADBEEF,
        **kwargs: Any,
    ) -> None:
        """
        :param root: Root directory where the dataset subdirectory exists or,
            if :attr:`download` is ``True``, the directory where the dataset
            subdirectory will be created and the dataset downloaded.
        """
        if isinstance(root, str):
            root = Path(root).resolve()
        self._root = root

        path = self._root / self._name

        if os.path.exists(path / self._target):
            pass
        elif download:
            if download_source is None:
                raise ValueError("If `download` is `True`, `download_source` must be provided.")
            download_and_extract_archive(
                resource=download_source,
                source=path,
                destination=path,
                name=f"{self._name}.tar.gz",
                remove_archive=True,
            )
        else:
            raise ValueError(
                f"Dataset not found at {path}. " "If `download` is `True`, the dataset will be downloaded."
            )

        if self._target.endswith(".csv"):
            data = pd.read_csv(path / self._target, **kwargs)
        elif self._target.endswith(".parquet"):
            data = pd.read_parquet(path / self._target, **kwargs)
        else:
            raise ValueError(f"Unsupported file format: {self._target}")

        if self.columns is None:
            self.columns = list(data.columns)

        if dedup:
            data.drop_duplicates(inplace=True)

        # split data into train and test using random seed
        train_indices = data.sample(frac=0.8, random_state=random_seed).index
        test_indices = data.index.difference(train_indices)

        select_indices = train_indices if train else test_indices
        self._data = data.loc[select_indices].reset_index(drop=True)

    def __len__(self) -> int:
        return len(self._data)

    def _fetch_item(self, index) -> pd.DataFrame:
        # check if int or slice
        if isinstance(index, int):
            item = self._data.iloc[index : index + 1]
        else:
            item = self._data.iloc[index]
        return item

    def _format_item(self, item: pd.DataFrame) -> OrderedDict[str, Any]:
        if len(item) == 1:
            return OrderedDict([(c, item[c].iloc[0]) for c in self.columns])
        return OrderedDict([(c, item[c]) for c in self.columns])

    def __getitem__(self, index) -> OrderedDict[str, Any]:
        item = self._fetch_item(index)
        return self._format_item(item)


def ordered_dict_collator(
    batch: list[OrderedDict[str, Any]],
) -> OrderedDict[str, Any]:
    """
    Collates a batch of OrderedDicts into a single OrderedDict.
    """
    res = OrderedDict([(key, [item[key] for item in batch]) for key in batch[0].keys()])
    res["batch_size"] = len(batch)
    return res
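
A short usage sketch (not part of the commit) showing how a concrete dataset and the collator fit together; the subclass below is hypothetical, and only attributes and methods visible in the file above are relied on.

from torch.utils.data import DataLoader

from cortex.data.dataset import DataFrameDataset, ordered_dict_collator


class ToyCSVDataset(DataFrameDataset):
    # hypothetical subclass: expects <root>/toy/data.csv on disk
    _name = "toy"
    _target = "data.csv"


train_data = ToyCSVDataset(root="/path/to/datasets", train=True)   # 80% split by default seed
test_data = ToyCSVDataset(root="/path/to/datasets", train=False)   # remaining 20%

loader = DataLoader(train_data, batch_size=8, collate_fn=ordered_dict_collator)
batch = next(iter(loader))  # OrderedDict of column name -> list of values, plus "batch_size"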
50 changes: 0 additions & 50 deletions cortex/data/dataset/_dataframe_dataset.py

This file was deleted.


