diff --git a/flair/samplers.py b/flair/samplers.py
new file mode 100644
index 0000000000..f070de143e
--- /dev/null
+++ b/flair/samplers.py
@@ -0,0 +1,134 @@
+import logging
+from collections import defaultdict
+
+from torch.utils.data.sampler import Sampler
+import random, torch
+
+from flair.data import FlairDataset
+
+
+log = logging.getLogger("flair")
+
+
+class ImbalancedClassificationDatasetSampler(Sampler):
+    """Use this to upsample rare classes and downsample common classes in your imbalanced classification dataset.
+    """
+
+    def __init__(self, data_source: FlairDataset):
+        """
+        Initialize by passing a classification dataset with labels, i.e. a TextClassificationDataSet.
+        :param data_source: the classification dataset to sample from
+        """
+        super().__init__(data_source)
+
+        self.indices = list(range(len(data_source)))
+        self.num_samples = len(data_source)
+
+        # first determine the distribution of classes in the dataset
+        label_count = defaultdict(int)
+        for sentence in data_source:
+            for label in sentence.get_label_names():
+                label_count[label] += 1
+
+        # weight each sample by the inverse frequency of its first label
+        offset = 0
+        weights = [
+            1.0 / (offset + label_count[data_source[idx].get_label_names()[0]])
+            for idx in self.indices
+        ]
+
+        self.weights = torch.DoubleTensor(weights)
+
+    def __iter__(self):
+        return (
+            self.indices[i]
+            for i in torch.multinomial(self.weights, self.num_samples, replacement=True)
+        )
+
+    def __len__(self):
+        return self.num_samples
+
+
+class ChunkSampler(Sampler):
+    """Splits data into blocks and randomizes their order before sampling. This preserves some of the original
+    data order while still shuffling the data.
+    """
+
+    def __init__(self, data_source, block_size=5, plus_window=5):
+        """Initialize by passing a block_size and a plus_window parameter.
+        :param data_source: dataset to sample from
+        :param block_size: minimum size of each block
+        :param plus_window: randomly adds between 0 and this value to the block size at each epoch
+        """
+        super().__init__(data_source)
+        self.data_source = data_source
+        self.num_samples = len(self.data_source)
+
+        self.block_size = block_size
+        self.plus_window = plus_window
+
+    def __iter__(self):
+        data = list(range(len(self.data_source)))
+
+        blocksize = self.block_size + random.randint(0, self.plus_window)
+
+        log.info(
+            f"Chunk sampling with blocksize = {blocksize} ({self.block_size} + {self.plus_window})"
+        )
+
+        # create blocks of consecutive indices
+        blocks = [data[i : i + blocksize] for i in range(0, len(data), blocksize)]
+        # shuffle the blocks
+        random.shuffle(blocks)
+        # concatenate the shuffled blocks
+        data[:] = [b for bs in blocks for b in bs]
+        return iter(data)
+
+    def __len__(self):
+        return self.num_samples
+
+
+class ExpandingChunkSampler(Sampler):
+    """Splits data into blocks and randomizes their order before sampling. The block size grows with each epoch.
+    This preserves some of the original data order while still shuffling the data.
+    """
+
+    def __init__(self, data_source):
+        """Initialize by passing the dataset to sample from.
+        :param data_source: dataset to sample from
+        """
+        super().__init__(data_source)
+        self.data_source = data_source
+        self.num_samples = len(self.data_source)
+
+        self.block_size = 1
+        self.plus_window = 0
+        self.epoch_count = 0
+
+    def __iter__(self):
+
+        self.epoch_count += 1
+
+        data = list(range(len(self.data_source)))
+        blocksize = self.block_size + random.randint(0, self.plus_window)
+
+        log.info(
+            f"Chunk sampling with blocksize = {blocksize} ({self.block_size} + {self.plus_window})"
+        )
+
+        # create blocks of consecutive indices
+        blocks = [data[i : i + blocksize] for i in range(0, len(data), blocksize)]
+        # shuffle the blocks
+        random.shuffle(blocks)
+        # concatenate the shuffled blocks
+        data[:] = [b for bs in blocks for b in bs]
+        # alternate between growing the block size and the random window for the next epoch
+        if self.epoch_count % 2 == 0:
+            self.block_size += 1
+        else:
+            self.plus_window += 1
+
+        return iter(data)
+
+    def __len__(self):
+        return self.num_samples
diff --git a/tests/test_model_integration.py b/tests/test_model_integration.py
index 7d1fa621f1..a9d0d3f8f0 100644
--- a/tests/test_model_integration.py
+++ b/tests/test_model_integration.py
@@ -15,9 +15,9 @@
     DocumentRNNEmbeddings,
 )
 from flair.models import SequenceTagger, TextClassifier, LanguageModel
+from flair.samplers import ImbalancedClassificationDatasetSampler
 from flair.trainers import ModelTrainer
 from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus
-from flair.training_utils import EvaluationMetric
 from flair.optim import AdamW
 
 
@@ -418,6 +418,40 @@ def test_train_load_use_classifier(results_base_path, tasks_base_path):
     shutil.rmtree(results_base_path)
 
 
+@pytest.mark.integration
+def test_train_classifier_with_sampler(results_base_path, tasks_base_path):
+    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
+    label_dict = corpus.make_label_dictionary()
+
+    word_embedding: WordEmbeddings = WordEmbeddings("turian")
+    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
+        [word_embedding], 32, 1, False, 64, False, False
+    )
+
+    model = TextClassifier(document_embeddings, label_dict, False)
+
+    trainer = ModelTrainer(model, corpus)
+    trainer.train(
+        results_base_path,
+        max_epochs=2,
+        shuffle=False,
+        sampler=ImbalancedClassificationDatasetSampler,
+    )
+
+    sentence = Sentence("Berlin is a really nice city.")
+
+    for s in model.predict(sentence):
+        for l in s.labels:
+            assert l.value is not None
+            assert 0.0 <= l.score <= 1.0
+            assert type(l.score) is float
+
+    loaded_model = TextClassifier.load(results_base_path / "final-model.pt")
+
+    # clean up results directory
+    shutil.rmtree(results_base_path)
+
+
 @pytest.mark.integration
 def test_train_load_use_classifier_with_prob(results_base_path, tasks_base_path):
     corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
diff --git a/train.py b/train.py
index 5916517c60..f9ea985ede 100644
--- a/train.py
+++ b/train.py
@@ -55,7 +55,6 @@
 
 trainer.train(
     "resources/taggers/example-ner",
-    EvaluationMetric.MICRO_F1_SCORE,
     learning_rate=0.1,
     mini_batch_size=32,
     max_epochs=20,
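For reviewers who want to try the new sampler API, here is a minimal usage sketch assembled from the integration test added above; the corpus path, the "glove" embedding, the output directory, and the training parameters are illustrative placeholders, not values taken from this change:

```python
# Minimal usage sketch based on the new integration test in this diff.
# The corpus path, embedding choice, and training parameters are illustrative only.
import flair.datasets
from flair.embeddings import WordEmbeddings, DocumentRNNEmbeddings
from flair.models import TextClassifier
from flair.samplers import ImbalancedClassificationDatasetSampler
from flair.trainers import ModelTrainer

# load a labeled classification corpus (placeholder path)
corpus = flair.datasets.ClassificationCorpus("path/to/classification/corpus")
label_dict = corpus.make_label_dictionary()

# build a simple document-level classifier, mirroring the test setup
document_embeddings = DocumentRNNEmbeddings([WordEmbeddings("glove")])
model = TextClassifier(document_embeddings, label_dict, False)  # third positional flag as in the test

# pass the sampler class to train(); shuffle is disabled here, as in the test
trainer = ModelTrainer(model, corpus)
trainer.train(
    "resources/classifiers/with-sampler",
    max_epochs=10,
    shuffle=False,
    sampler=ImbalancedClassificationDatasetSampler,
)
```

ChunkSampler and ExpandingChunkSampler should plug in the same way, since their constructors also need only the dataset (the chunk parameters have defaults).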