Commit 2f04c94

Merge pull request #908 from zalandoresearch/samplers

GH-678: Data Samplers

yosipk authored Jul 19, 2019
2 parents 8b50e2e + 36a5638

Showing 3 changed files with 169 additions and 2 deletions.
134 changes: 134 additions & 0 deletions flair/samplers.py
@@ -0,0 +1,134 @@
import logging
import random
from collections import defaultdict

import torch
from torch.utils.data.sampler import Sampler

from flair.data import FlairDataset


log = logging.getLogger("flair")


class ImbalancedClassificationDatasetSampler(Sampler):
"""Use this to upsample rare classes and downsample common classes in your unbalanced classification dataset.
"""

def __init__(self, data_source: FlairDataset):
"""
Initialize by passing a classification dataset with labels, i.e. either TextClassificationDataSet or
:param data_source:
"""
super().__init__(data_source)

self.indices = list(range(len(data_source)))
self.num_samples = len(data_source)

# first determine the distribution of classes in the dataset
label_count = defaultdict(int)
for sentence in data_source:
for label in sentence.get_label_names():
label_count[label] += 1

        # weight each sample by the inverse frequency of its first label, so that
        # rare classes are drawn more often and common classes less often
        offset = 0
weights = [
1.0 / (offset + label_count[data_source[idx].get_label_names()[0]])
for idx in self.indices
]

self.weights = torch.DoubleTensor(weights)

def __iter__(self):
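        # draw num_samples indices with replacement, with probability proportional to the sample weights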
return (
self.indices[i]
for i in torch.multinomial(self.weights, self.num_samples, replacement=True)
)

def __len__(self):
return self.num_samples


class ChunkSampler(Sampler):
"""Splits data into blocks and randomizes them before sampling. This causes some order of the data to be preserved,
while still shuffling the data.
"""

def __init__(self, data_source, block_size=5, plus_window=5):
"""Initialize by passing a block_size and a plus_window parameter.
:param data_source: dataset to sample from
:param block_size: minimum size of each block
:param plus_window: randomly adds between 0 and this value to block size at each epoch
"""
super().__init__(data_source)
self.data_source = data_source
self.num_samples = len(self.data_source)

self.block_size = block_size
self.plus_window = plus_window

def __iter__(self):
data = list(range(len(self.data_source)))

blocksize = self.block_size + random.randint(0, self.plus_window)

log.info(
f"Chunk sampling with blocksize = {blocksize} ({self.block_size} + {self.plus_window})"
)

# Create blocks
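        # e.g. 10 indices with blocksize 4 yield blocks [0..3], [4..7], [8, 9]; the last block may be shorter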
blocks = [data[i : i + blocksize] for i in range(0, len(data), blocksize)]
# shuffle the blocks
random.shuffle(blocks)
# concatenate the shuffled blocks
data[:] = [b for bs in blocks for b in bs]
return iter(data)

def __len__(self):
return self.num_samples


class ExpandingChunkSampler(Sampler):
"""Splits data into blocks and randomizes them before sampling. Block size grows with each epoch.
This causes some order of the data to be preserved, while still shuffling the data.
"""

    def __init__(self, data_source):
        """Initialize by passing the dataset to sample from. The block size starts at 1 and grows each epoch.
        :param data_source: dataset to sample from
        """
super().__init__(data_source)
self.data_source = data_source
self.num_samples = len(self.data_source)

self.block_size = 1
self.plus_window = 0
self.epoch_count = 0

    def __iter__(self):
        self.epoch_count += 1

data = list(range(len(self.data_source)))
blocksize = self.block_size + random.randint(0, self.plus_window)

log.info(
f"Chunk sampling with blocksize = {blocksize} ({self.block_size} + {self.plus_window})"
)

# Create blocks
blocks = [data[i : i + blocksize] for i in range(0, len(data), blocksize)]
# shuffle the blocks
random.shuffle(blocks)
# concatenate the shuffled blocks
data[:] = [b for bs in blocks for b in bs]

        # alternate between growing the minimum block size and the random window,
        # so that blocks get steadily larger from epoch to epoch
        if self.epoch_count % 2 == 0:
            self.block_size += 1
        else:
            self.plus_window += 1

return iter(data)

def __len__(self):
return self.num_samples
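
For reference, here is a minimal sketch of how one of these samplers could be wired into a plain PyTorch DataLoader outside of the trainer. The corpus path is hypothetical; within Flair itself, the integration test below simply passes the sampler class via the trainer's sampler argument.

from torch.utils.data import DataLoader

import flair.datasets
from flair.samplers import ImbalancedClassificationDatasetSampler

# hypothetical path; any classification corpus with labeled sentences works
corpus = flair.datasets.ClassificationCorpus("resources/tasks/imdb")

# the sampler replaces shuffling, so the DataLoader's shuffle stays at its default of False
sampler = ImbalancedClassificationDatasetSampler(corpus.train)
loader = DataLoader(corpus.train, batch_size=32, sampler=sampler, collate_fn=list)

for batch in loader:
    pass  # each batch is a list of Sentence objects, with rare labels upsampled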
36 changes: 35 additions & 1 deletion tests/test_model_integration.py
@@ -15,9 +15,9 @@
DocumentRNNEmbeddings,
)
from flair.models import SequenceTagger, TextClassifier, LanguageModel
from flair.samplers import ImbalancedClassificationDatasetSampler
from flair.trainers import ModelTrainer
from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus
from flair.training_utils import EvaluationMetric
from flair.optim import AdamW


@@ -418,6 +418,40 @@ def test_train_load_use_classifier(results_base_path, tasks_base_path):
shutil.rmtree(results_base_path)


@pytest.mark.integration
def test_train_classifier_with_sampler(results_base_path, tasks_base_path):
corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
label_dict = corpus.make_label_dictionary()

word_embedding: WordEmbeddings = WordEmbeddings("turian")
document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
[word_embedding], 32, 1, False, 64, False, False
)

model = TextClassifier(document_embeddings, label_dict, False)

trainer = ModelTrainer(model, corpus)
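    # note: the sampler is passed as a class, not an instance; the trainer is
    # expected to instantiate it on the training split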
trainer.train(
results_base_path,
max_epochs=2,
shuffle=False,
sampler=ImbalancedClassificationDatasetSampler,
)

sentence = Sentence("Berlin is a really nice city.")

for s in model.predict(sentence):
for l in s.labels:
assert l.value is not None
assert 0.0 <= l.score <= 1.0
assert type(l.score) is float

loaded_model = TextClassifier.load(results_base_path / "final-model.pt")

# clean up results directory
shutil.rmtree(results_base_path)


@pytest.mark.integration
def test_train_load_use_classifier_with_prob(results_base_path, tasks_base_path):
corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
1 change: 0 additions & 1 deletion train.py
@@ -55,7 +55,6 @@

trainer.train(
"resources/taggers/example-ner",
EvaluationMetric.MICRO_F1_SCORE,
learning_rate=0.1,
mini_batch_size=32,
max_epochs=20,