-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcompute_rosetta_standardization.py
95 lines (74 loc) · 4 KB
/
compute_rosetta_standardization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
""" compute standardization parameters for rosetta datasets """
import os
from os.path import join, dirname, isfile
from typing import Optional, cast
import argparse
import logging
import pandas as pd
try:
from . import split_dataset as sd
except ImportError:
import split_dataset as sd
# module-level logger, placed under the "METL." logger namespace; its own level is set to DEBUG
logger = logging.getLogger("METL.{}".format(__name__))
logger.setLevel(logging.DEBUG)
def _energy_columns_from(stats: pd.DataFrame, energies_start_col: str) -> pd.DataFrame:
    """ return the columns of `stats` starting at `energies_start_col` (inclusive) through the last column.

    raises ValueError naming the column if `energies_start_col` is not among the columns of `stats`
    (e.g. it is absent from the dataset, or it was dropped by numeric_only=True because it is not numeric) """
    columns = list(stats.columns)
    if energies_start_col not in columns:
        raise ValueError(
            "energies_start_col '{}' is not a numeric column of the dataset; "
            "available numeric columns: {}".format(energies_start_col, columns))
    # everything from the start column to the end is treated as an energy term
    return stats.iloc[:, columns.index(energies_start_col):]


def save_standardize_params(ds_fn: str,
                            split_dir: Optional[str] = None,
                            energies_start_col: str = "total_score") -> None:
    """ save the means and standard deviations of all rosetta energies in dataset.
    if there are multiple different PDBs, such as in the global rosetta dataset,
    then the means and standard deviations are computed for each PDB separately (grouped by "pdb_fn").

    :param ds_fn: path to the dataset in hdf5 format; the dataframe is read from key "variant"
    :param split_dir: optional split directory. if given, the params are computed only on the rows selected
        (positionally, via iloc) by the "train" entry of sd.load_split_dir(split_dir), and are written to
        <split_dir>/standardization_params with suffix "train". otherwise the full dataset is used and the
        params are written to <dirname(ds_fn)>/standardization_params with suffix "all"
    :param energies_start_col: name of the first energy column; every numeric column from this one
        through the last numeric column is treated as an energy term
    :raises FileExistsError: if energy_means_<suffix>.tsv or energy_stds_<suffix>.tsv already exists.
        this is checked before the dataset is read, so nothing is computed or created in that case
    :raises ValueError: if energies_start_col is not a numeric column of the dataset

    outputs are tab-separated files with one row per pdb_fn, floats written with 7 decimal places """
    # the output location depends only on the arguments, not on the data
    if split_dir is not None:
        # params computed on the training set are specific to this split, so they are stored with the split
        out_dir = join(split_dir, "standardization_params")
        out_suffix = "train"
    else:
        # default output directory for full-dataset standardization params
        out_dir = join(dirname(ds_fn), "standardization_params")
        out_suffix = "all"

    means_out_fn = join(out_dir, "energy_means_{}.tsv".format(out_suffix))
    stds_out_fn = join(out_dir, "energy_stds_{}.tsv".format(out_suffix))

    # refuse to overwrite existing params. done before reading the dataset and computing the
    # statistics so that a rerun fails immediately instead of after all the work is done
    if isfile(means_out_fn) or isfile(stds_out_fn):
        raise FileExistsError(
            "Standardization params output file(s) already exist: {} or {}".format(means_out_fn, stds_out_fn))

    ds = cast(pd.DataFrame, pd.read_hdf(ds_fn, key="variant"))

    if split_dir is not None:
        # given a split dir, so only compute the standardization parameters on the train set
        set_idxs = sd.load_split_dir(split_dir)["train"]
        ds = ds.iloc[set_idxs]
        logger.info("computing standardization params on training set only")
    else:
        logger.info("computing standardization params on full dataset")

    # standardization parameters are computed per-pdb
    g = ds.groupby("pdb_fn")
    g_mean = _energy_columns_from(g.mean(numeric_only=True), energies_start_col)
    # ddof=0 to match sklearn's StandardScaler (for a biased estimator of standard deviation)
    g_std = _energy_columns_from(g.std(ddof=0, numeric_only=True), energies_start_col)

    # create the output directory only once there is something to write, so a failure above
    # does not leave an empty standardization_params directory behind
    os.makedirs(out_dir, exist_ok=True)
    logger.info("saving standardization params to: {}".format(out_dir))
    g_mean.to_csv(means_out_fn, sep="\t", float_format="%.7f")
    g_std.to_csv(stds_out_fn, sep="\t", float_format="%.7f")
def main(args):
    """ command-line entry point: forward the parsed arguments (ds_fn_h5, split_dir, energies_start_col)
    to save_standardize_params """
    save_standardize_params(
        ds_fn=args.ds_fn_h5,
        split_dir=args.split_dir,
        energies_start_col=args.energies_start_col,
    )
if __name__ == "__main__":
    # help texts for the optional arguments, kept separate so the add_argument calls stay short
    split_dir_help = ("path to the split directory containing the train/val/test split indices. if provided, "
                      "the standardization parameters will be computed on the training set only. this is "
                      "necessary for training a source model.")
    energies_start_col_help = ("the column name of the first energy term in the dataset. default is 'total_score'. "
                               "this is used to determine which columns in the dataset are energy terms. "
                               "leave this as default unless for some reason total_score is not the first energy term.")

    cli_parser = argparse.ArgumentParser()
    # required positional: the input dataset
    cli_parser.add_argument("ds_fn_h5", type=str,
                            help="path to the rosetta dataset in hdf5 format")
    # optional: restrict the computation to a split's training set
    cli_parser.add_argument("--split_dir", type=str, required=False, default=None,
                            help=split_dir_help)
    # optional: override the first energy column
    cli_parser.add_argument("--energies_start_col", type=str, required=False, default="total_score",
                            help=energies_start_col_help)

    main(cli_parser.parse_args())