Skip to content

Commit

Permalink
Address subset issue highlighted in #95 (#97)
Browse files Browse the repository at this point in the history
* fix: address subset issue highlighted in #95

* test: add checks that would have caught the test set subset issue
  • Loading branch information
CompRhys authored Jan 21, 2025
1 parent 63e0ea0 commit 181e2b2
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion aviary/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def evaluate(
metrics_str = " ".join(
f"{key} {val:<9.2f}" for key, val in avrg_metrics[target].items()
)
print(f"{action:>9}: {target} {metrics_str}")
print(f"{action:>9}: {target} N {len(data_loader):,} {metrics_str}")

return avrg_metrics

Expand Down
18 changes: 8 additions & 10 deletions aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,17 +316,13 @@ def train_ensemble(
when early stopping. Defaults to None.
verbose (bool, optional): Whether to show progress bars for each epoch.
"""
if isinstance(train_set, Subset):
train_set = train_set.dataset
if isinstance(val_set, Subset):
val_set = val_set.dataset

train_loader = DataLoader(train_set, **data_params)
print(f"Training on {len(train_set):,} samples")

if val_set is not None:
data_params.update({"batch_size": 16 * data_params["batch_size"]})
val_loader = DataLoader(val_set, **data_params)
print(f"Validating on {len(val_set):,} samples")
else:
val_loader = None

Expand Down Expand Up @@ -354,7 +350,13 @@ def train_ensemble(

for target, normalizer in normalizer_dict.items():
if normalizer is not None:
sample_target = Tensor(train_set.df[target].values)
if isinstance(train_set, Subset):
sample_target = Tensor(
train_set.dataset.df[target].iloc[train_set.indices].values
)
else:
sample_target = Tensor(train_set.df[target].values)

if not restart_params["resume"]:
normalizer.fit(sample_target)
print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}")
Expand Down Expand Up @@ -455,10 +457,6 @@ def results_multitask(
"------------Evaluate model on Test Set------------\n"
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
)

if isinstance(test_set, Subset):
test_set = test_set.dataset

test_loader = DataLoader(test_set, **data_params)
print(f"Testing on {len(test_set):,} samples")

Expand Down
1 change: 1 addition & 0 deletions tests/test_cgcnn_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,6 @@ def test_cgcnn_clf(df_matbench_phonons):

ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert ens_acc > 0.85
assert ens_roc_auc > 0.9
1 change: 1 addition & 0 deletions tests/test_cgcnn_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def test_cgcnn_regression(df_matbench_phonons):

mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert r2 > 0.7
assert mae < 150
assert rmse < 300
2 changes: 2 additions & 0 deletions tests/test_roost_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,7 @@ def test_roost_clf(df_matbench_phonons):

ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

assert len(logits) == ensemble
assert len(targets) == len(test_set) == len(test_idx)
assert ens_acc > 0.9
assert ens_roc_auc > 0.9
1 change: 1 addition & 0 deletions tests/test_roost_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def test_roost_regression(df_matbench_phonons):

mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert r2 > 0.7
assert mae < 150
assert rmse < 300
1 change: 1 addition & 0 deletions tests/test_wren_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,5 +146,6 @@ def test_wren_clf(df_matbench_phonons_wyckoff):

ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert ens_acc > 0.85
assert ens_roc_auc > 0.9
1 change: 1 addition & 0 deletions tests/test_wren_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def test_wren_regression(df_matbench_phonons_wyckoff):

mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert r2 > 0.7
assert mae < 150
assert rmse < 300

0 comments on commit 181e2b2

Please sign in to comment.