From 181e2b2b2d679a12f6dbb430853d92508e8d71f2 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 20 Jan 2025 21:06:31 -0500 Subject: [PATCH] Address subset issue highlighted in #95 (#97) * fix: address subset issue highlighted in #95 * test: add checks that would have caught the test set subset issue --- aviary/core.py | 2 +- aviary/utils.py | 18 ++++++++---------- tests/test_cgcnn_classification.py | 1 + tests/test_cgcnn_regression.py | 1 + tests/test_roost_classification.py | 2 ++ tests/test_roost_regression.py | 1 + tests/test_wren_classification.py | 1 + tests/test_wren_regression.py | 1 + 8 files changed, 16 insertions(+), 11 deletions(-) diff --git a/aviary/core.py b/aviary/core.py index f8a3ee7c..7425c98c 100644 --- a/aviary/core.py +++ b/aviary/core.py @@ -346,7 +346,7 @@ def evaluate( metrics_str = " ".join( f"{key} {val:<9.2f}" for key, val in avrg_metrics[target].items() ) - print(f"{action:>9}: {target} {metrics_str}") + print(f"{action:>9}: {target} N {len(data_loader):,} {metrics_str}") return avrg_metrics diff --git a/aviary/utils.py b/aviary/utils.py index bbe35667..6da411b9 100644 --- a/aviary/utils.py +++ b/aviary/utils.py @@ -316,17 +316,13 @@ def train_ensemble( when early stopping. Defaults to None. verbose (bool, optional): Whether to show progress bars for each epoch. """ - if isinstance(train_set, Subset): - train_set = train_set.dataset - if isinstance(val_set, Subset): - val_set = val_set.dataset - train_loader = DataLoader(train_set, **data_params) print(f"Training on {len(train_set):,} samples") if val_set is not None: data_params.update({"batch_size": 16 * data_params["batch_size"]}) val_loader = DataLoader(val_set, **data_params) + print(f"Validating on {len(val_set):,} samples") else: val_loader = None @@ -354,7 +350,13 @@ def train_ensemble( for target, normalizer in normalizer_dict.items(): if normalizer is not None: - sample_target = Tensor(train_set.df[target].values) + if isinstance(train_set, Subset): + sample_target = Tensor( + train_set.dataset.df[target].iloc[train_set.indices].values + ) + else: + sample_target = Tensor(train_set.df[target].values) + if not restart_params["resume"]: normalizer.fit(sample_target) print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}") @@ -455,10 +457,6 @@ def results_multitask( "------------Evaluate model on Test Set------------\n" "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n" ) - - if isinstance(test_set, Subset): - test_set = test_set.dataset - test_loader = DataLoader(test_set, **data_params) print(f"Testing on {len(test_set):,} samples") diff --git a/tests/test_cgcnn_classification.py b/tests/test_cgcnn_classification.py index e75d0e5c..0c8d10cb 100644 --- a/tests/test_cgcnn_classification.py +++ b/tests/test_cgcnn_classification.py @@ -136,5 +136,6 @@ def test_cgcnn_clf(df_matbench_phonons): ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values() + assert len(targets) == len(test_set) == len(test_idx) assert ens_acc > 0.85 assert ens_roc_auc > 0.9 diff --git a/tests/test_cgcnn_regression.py b/tests/test_cgcnn_regression.py index 51d14b92..f6c908c6 100644 --- a/tests/test_cgcnn_regression.py +++ b/tests/test_cgcnn_regression.py @@ -135,6 +135,7 @@ def test_cgcnn_regression(df_matbench_phonons): mae, rmse, r2 = get_metrics(targets, y_ens, task).values() + assert len(targets) == len(test_set) == len(test_idx) assert r2 > 0.7 assert mae < 150 assert rmse < 300 diff --git a/tests/test_roost_classification.py b/tests/test_roost_classification.py index deb500d7..86322bf1 100644 --- a/tests/test_roost_classification.py +++ b/tests/test_roost_classification.py @@ -138,5 +138,7 @@ def test_roost_clf(df_matbench_phonons): ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values() + assert len(logits) == ensemble + assert len(targets) == len(test_set) == len(test_idx) assert ens_acc > 0.9 assert ens_roc_auc > 0.9 diff --git a/tests/test_roost_regression.py b/tests/test_roost_regression.py index b24f2311..125b0031 100644 --- a/tests/test_roost_regression.py +++ b/tests/test_roost_regression.py @@ -137,6 +137,7 @@ def test_roost_regression(df_matbench_phonons): mae, rmse, r2 = get_metrics(targets, y_ens, task).values() + assert len(targets) == len(test_set) == len(test_idx) assert r2 > 0.7 assert mae < 150 assert rmse < 300 diff --git a/tests/test_wren_classification.py b/tests/test_wren_classification.py index d6f01f9c..c7db96e0 100644 --- a/tests/test_wren_classification.py +++ b/tests/test_wren_classification.py @@ -146,5 +146,6 @@ def test_wren_clf(df_matbench_phonons_wyckoff): ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values() + assert len(targets) == len(test_set) == len(test_idx) assert ens_acc > 0.85 assert ens_roc_auc > 0.9 diff --git a/tests/test_wren_regression.py b/tests/test_wren_regression.py index c080dd2a..a8ad73db 100644 --- a/tests/test_wren_regression.py +++ b/tests/test_wren_regression.py @@ -145,6 +145,7 @@ def test_wren_regression(df_matbench_phonons_wyckoff): mae, rmse, r2 = get_metrics(targets, y_ens, task).values() + assert len(targets) == len(test_set) == len(test_idx) assert r2 > 0.7 assert mae < 150 assert rmse < 300