
Commit

Merge branch 'main' of https://github.com/huggingface/setfit into refactor_v2
tomaarsen committed Feb 6, 2023
2 parents 7d4ad00 + 9b7f74e commit aab2377
Showing 3 changed files with 48 additions and 13 deletions.
26 changes: 14 additions & 12 deletions src/setfit/data.py
@@ -107,7 +107,10 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i
 
 
 def create_fewshot_splits(
-    dataset: Dataset, sample_sizes: List[int], add_data_augmentation: bool = False, dataset_name: Optional[str] = None
+    dataset: Dataset,
+    sample_sizes: List[int],
+    add_data_augmentation: bool = False,
+    dataset_name: Optional[str] = None,
 ) -> DatasetDict:
     """Creates training splits from the dataset with an equal number of samples per class (when possible)."""
     splits_ds = DatasetDict()
@@ -254,25 +257,24 @@ def __getitem__(self, idx: int) -> Tuple[TokenizerOutput, Union[int, List[int]]]
             max_length=self.max_length,
             padding="max_length",
             truncation=True,
-            return_attention_mask=True,
-            return_token_type_ids=True,
+            return_attention_mask="attention_mask" in self.tokenizer.model_input_names,
+            return_token_type_ids="token_type_ids" in self.tokenizer.model_input_names,
         )
         label = self.y[idx]
 
         return feature, label
 
-    @staticmethod
-    def collate_fn(batch):
-        features = {
-            "input_ids": [],
-            "attention_mask": [],
-            "token_type_ids": [],
-        }
+    def collate_fn(self, batch):
+
+        features = {input_name: [] for input_name in self.tokenizer.model_input_names}
+
         labels = []
         for feature, label in batch:
             features["input_ids"].append(feature["input_ids"])
-            features["attention_mask"].append(feature["attention_mask"])
-            features["token_type_ids"].append(feature["token_type_ids"])
+            if "attention_mask" in features:
+                features["attention_mask"].append(feature["attention_mask"])
+            if "token_type_ids" in features:
+                features["token_type_ids"].append(feature["token_type_ids"])
             labels.append(label)
 
         # convert to tensors
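The data.py change above stops hard-coding the encoding keys and instead mirrors whatever the tokenizer lists in model_input_names: not every tokenizer produces token_type_ids, so building the feature dict from model_input_names keeps collation aligned with the inputs the model actually expects. A minimal sketch of that idea, not part of the commit itself (the two checkpoint names are only illustrative):

# Sketch: compare which inputs two tokenizers declare.
from transformers import AutoTokenizer

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
distilbert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# BERT-style tokenizers typically list ['input_ids', 'token_type_ids', 'attention_mask'],
# while DistilBERT-style tokenizers omit 'token_type_ids'.
print(bert_tokenizer.model_input_names)
print(distilbert_tokenizer.model_input_names)

# Keying the batch dict off model_input_names only creates entries the model expects.
features = {name: [] for name in distilbert_tokenizer.model_input_names}
print(features)  # {'input_ids': [], 'attention_mask': []}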
6 changes: 5 additions & 1 deletion src/setfit/modeling.py
@@ -360,7 +360,11 @@ def _prepare_dataloader(
             max_length=max_length,
         )
         dataloader = DataLoader(
-            dataset, batch_size=batch_size, collate_fn=SetFitDataset.collate_fn, shuffle=shuffle, pin_memory=True
+            dataset,
+            batch_size=batch_size,
+            collate_fn=dataset.collate_fn,
+            shuffle=shuffle,
+            pin_memory=True,
         )
 
         return dataloader
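Because collate_fn now reads self.tokenizer, it is an instance method, so the DataLoader in modeling.py is handed the bound dataset.collate_fn rather than the unbound SetFitDataset.collate_fn. A minimal usage sketch under that reading (checkpoint name and toy data are placeholders):

# Sketch: wire a SetFitDataset's bound collate_fn into a DataLoader.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from setfit.data import SetFitDataset

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
dataset = SetFitDataset(["a great film", "a dull film"], [1, 0], tokenizer)

# The bound method can consult tokenizer.model_input_names while assembling each batch.
dataloader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn, shuffle=True, pin_memory=True)
features, labels = next(iter(dataloader))
print(sorted(features.keys()))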
29 changes: 29 additions & 0 deletions tests/test_data.py
@@ -4,10 +4,13 @@
 import pandas as pd
 import pytest
 from datasets import Dataset, load_dataset
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer
 
 from setfit.data import (
     SAMPLE_SIZES,
     SEEDS,
+    SetFitDataset,
     add_templated_examples,
     create_fewshot_splits,
     create_fewshot_splits_multilabel,
@@ -197,3 +200,29 @@ def test_get_augmented_samples(dataset: str):
 def test_get_augmented_samples_negative():
     with pytest.raises(ValueError):
         get_augmented_samples(None)
+
+
+@pytest.mark.parametrize(
+    "tokenizer_name",
+    ["sentence-transformers/paraphrase-albert-small-v2", "sentence-transformers/distiluse-base-multilingual-cased-v1"],
+)
+def test_correct_model_inputs(tokenizer_name):
+    # Arbitrary testing data
+    x = list(string.ascii_lowercase)
+    y = list(range(len(x)))
+
+    # Relatively standard DataLoader setup using a SetFitDataset
+    # for training a differentiable classification head
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    dataset = SetFitDataset(x, y, tokenizer)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=2,
+        collate_fn=dataset.collate_fn,
+        shuffle=True,
+        pin_memory=True,
+    )
+
+    # Verify that the x_batch contains exactly those keys that the model requires
+    x_batch, _ = next(iter(dataloader))
+    assert set(x_batch.keys()) == set(tokenizer.model_input_names)
