From 86504e6b57b26bb2bb362e33c0edc3e49c0760fe Mon Sep 17 00:00:00 2001 From: Jacob Morrison Date: Tue, 22 Jun 2021 16:34:16 -0700 Subject: [PATCH] Making model test case consistently random (#5278) --- CHANGELOG.md | 3 ++- allennlp/common/testing/model_test_case.py | 15 ++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95081b6942a..6c8e8627e5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs - Ensured all `allennlp` submodules are imported with `allennlp.common.plugins.import_plugins()`. - Fixed `IndexOutOfBoundsException` in `MultiOptimizer` when checking if optimizer received any parameters. -- Removed confusing zero mask from VilBERT +- Removed confusing zero mask from VilBERT. +- Ensured `ensure_model_can_train_save_and_load` is consistently random. ### Changed diff --git a/allennlp/common/testing/model_test_case.py b/allennlp/common/testing/model_test_case.py index 48afe3307d8..c58c4f6777c 100644 --- a/allennlp/common/testing/model_test_case.py +++ b/allennlp/common/testing/model_test_case.py @@ -119,11 +119,6 @@ def ensure_model_can_train_save_and_load( Specifies which loss to test. For example, which_loss may be "adversary_loss" for `adversarial_bias_mitigator`. """ - if seed is not None: - random.seed(seed) - numpy.random.seed(seed) - torch.manual_seed(seed) - save_dir = self.TEST_DIR / "save_and_load_test" archive_file = save_dir / "model.tar.gz" model = train_model_from_file(param_file, save_dir, overrides=overrides) @@ -158,12 +153,22 @@ def ensure_model_can_train_save_and_load( data_loader_params["shuffle"] = False data_loader_params2 = Params(copy.deepcopy(data_loader_params.as_dict())) + if seed is not None: + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + print("Reading with original model") data_loader = DataLoader.from_params( params=data_loader_params, reader=reader, data_path=params["validation_data_path"] ) data_loader.index_with(model.vocab) + if seed is not None: + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + print("Reading with loaded model") data_loader2 = DataLoader.from_params( params=data_loader_params2, reader=reader, data_path=params["validation_data_path"]