From 32ef533be8c29ae4a10b91087a23445e1473476f Mon Sep 17 00:00:00 2001 From: Arjun Subramonian Date: Tue, 18 May 2021 23:32:19 -0700 Subject: [PATCH 1/2] added shuffle disable option in BucketBatchSampler --- CHANGELOG.md | 1 + allennlp/data/samplers/bucket_batch_sampler.py | 11 ++++++++++- tests/data/samplers/bucket_batch_sampler_test.py | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f8a9b13503..79457e13f8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `SpanExtractorWithSpanWidthEmbedding`, putting specific span embedding computations into the `_embed_spans` method and leaving the common code in `SpanExtractorWithSpanWidthEmbedding` to unify the arguments, and modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` accordingly. Now, `SelfAttentiveSpanExtractor` can also embed span widths. - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences. - Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`. +- Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling. ### Fixed diff --git a/allennlp/data/samplers/bucket_batch_sampler.py b/allennlp/data/samplers/bucket_batch_sampler.py index d65a676f14c..38d01746eac 100644 --- a/allennlp/data/samplers/bucket_batch_sampler.py +++ b/allennlp/data/samplers/bucket_batch_sampler.py @@ -57,6 +57,10 @@ class BucketBatchSampler(BatchSampler): If `True`, the sampler will drop the last batch if its size would be less than batch_size`. + shuffle : `bool`, (default = `True`) + If `False`, the sampler won't shuffle the batches. padding_noise will be ignored and set + to `0.0`. + """ def __init__( @@ -65,11 +69,15 @@ def __init__( sorting_keys: List[str] = None, padding_noise: float = 0.1, drop_last: bool = False, + shuffle: bool = True, ): self.sorting_keys = sorting_keys self.padding_noise = padding_noise self.batch_size = batch_size self.drop_last = drop_last + self.shuffle = shuffle + if not shuffle: + self.padding_noise = 0.0 def _argsort_by_padding( self, instances: Iterable[Instance] @@ -113,7 +121,8 @@ def get_batch_indices(self, instances: Sequence[Instance]) -> Iterable[List[int] if self.drop_last and len(batch_indices) < self.batch_size: continue batches.append(batch_indices) - random.shuffle(batches) + if self.shuffle: + random.shuffle(batches) for batch in batches: yield batch diff --git a/tests/data/samplers/bucket_batch_sampler_test.py b/tests/data/samplers/bucket_batch_sampler_test.py index 3a972facdc2..450c825cc3c 100644 --- a/tests/data/samplers/bucket_batch_sampler_test.py +++ b/tests/data/samplers/bucket_batch_sampler_test.py @@ -24,6 +24,20 @@ def test_create_batches_groups_correctly(self): expected_groups.remove(group) assert expected_groups == [] + def test_disable_shuffle(self): + sampler = BucketBatchSampler(batch_size=2, sorting_keys=["text"], shuffle=False) + + grouped_instances = [] + for indices in sampler.get_batch_indices(self.instances): + grouped_instances.append([self.instances[idx] for idx in indices]) + expected_groups = [ + [self.instances[4], self.instances[2]], + [self.instances[0], self.instances[1]], + [self.instances[3]], + ] + for idx, group in enumerate(grouped_instances): + assert group == expected_groups[idx] + def test_guess_sorting_key_picks_the_longest_key(self): sampler = BucketBatchSampler(batch_size=2, padding_noise=0) instances = [] From 1cb7fd274629f3dc11ae23b528baaef810242b00 Mon Sep 17 00:00:00 2001 From: ArjunSubramonian Date: Wed, 19 May 2021 10:10:38 -0700 Subject: [PATCH 2/2] Update allennlp/data/samplers/bucket_batch_sampler.py Co-authored-by: Pete --- allennlp/data/samplers/bucket_batch_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/data/samplers/bucket_batch_sampler.py b/allennlp/data/samplers/bucket_batch_sampler.py index 38d01746eac..e4aa125741f 100644 --- a/allennlp/data/samplers/bucket_batch_sampler.py +++ b/allennlp/data/samplers/bucket_batch_sampler.py @@ -58,7 +58,7 @@ class BucketBatchSampler(BatchSampler): its size would be less than batch_size`. shuffle : `bool`, (default = `True`) - If `False`, the sampler won't shuffle the batches. padding_noise will be ignored and set + If `False`, the sampler won't shuffle the batches. `padding_noise` will be ignored and set to `0.0`. """