Dynamic features in datasets based on model input names (#288)
* Dynamic feature names based on model input names

* Test that SetFitDataset produces the right model inputs

Co-authored-by: Tom Aarsen <[email protected]>
AleksanderObuchowski and tomaarsen authored Jan 25, 2023
1 parent 0cb8ffd commit 9b7f74e
Showing 3 changed files with 64 additions and 17 deletions.
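
The core of the change: SetFitDataset previously hardcoded attention_mask and token_type_ids as feature names, but not every tokenizer produces the same model inputs, so the feature names are now taken from the tokenizer's model_input_names. A minimal sketch of why that matters, using two standard Hugging Face checkpoints purely as an illustration (they are not part of this commit):

    from transformers import AutoTokenizer

    # BERT-style tokenizers typically report token type IDs as a model input ...
    bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    print(bert_tokenizer.model_input_names)
    # e.g. ['input_ids', 'token_type_ids', 'attention_mask']

    # ... while DistilBERT-style tokenizers do not use them at all.
    distilbert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print(distilbert_tokenizer.model_input_names)
    # e.g. ['input_ids', 'attention_mask']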
26 changes: 14 additions & 12 deletions src/setfit/data.py
@@ -107,7 +107,10 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i
 
 
 def create_fewshot_splits(
-    dataset: Dataset, sample_sizes: List[int], add_data_augmentation: bool = False, dataset_name: Optional[str] = None
+    dataset: Dataset,
+    sample_sizes: List[int],
+    add_data_augmentation: bool = False,
+    dataset_name: Optional[str] = None,
 ) -> DatasetDict:
     """Creates training splits from the dataset with an equal number of samples per class (when possible)."""
     splits_ds = DatasetDict()
@@ -254,25 +257,24 @@ def __getitem__(self, idx: int) -> Tuple[TokenizerOutput, Union[int, List[int]]]
             max_length=self.max_length,
             padding="max_length",
             truncation=True,
-            return_attention_mask=True,
-            return_token_type_ids=True,
+            return_attention_mask="attention_mask" in self.tokenizer.model_input_names,
+            return_token_type_ids="token_type_ids" in self.tokenizer.model_input_names,
         )
         label = self.y[idx]
 
         return feature, label
 
-    @staticmethod
-    def collate_fn(batch):
-        features = {
-            "input_ids": [],
-            "attention_mask": [],
-            "token_type_ids": [],
-        }
+    def collate_fn(self, batch):
+
+        features = {input_name: [] for input_name in self.tokenizer.model_input_names}
+
         labels = []
         for feature, label in batch:
             features["input_ids"].append(feature["input_ids"])
-            features["attention_mask"].append(feature["attention_mask"])
-            features["token_type_ids"].append(feature["token_type_ids"])
+            if "attention_mask" in features:
+                features["attention_mask"].append(feature["attention_mask"])
+            if "token_type_ids" in features:
+                features["token_type_ids"].append(feature["token_type_ids"])
             labels.append(label)
 
         # convert to tensors
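
With this change, the batch dictionary is keyed only by the names the tokenizer reports in model_input_names, so a model whose tokenizer never produces token_type_ids never sees that key. A rough standalone sketch of the resulting collation logic, simplified relative to the actual method (the helper below is illustrative, not part of the commit):

    import torch

    def collate(batch, model_input_names):
        # One list per input name the model actually expects.
        features = {name: [] for name in model_input_names}
        labels = []
        for feature, label in batch:
            for name in model_input_names:
                features[name].append(feature[name])
            labels.append(label)
        # Stack everything into tensors for the training loop.
        features = {name: torch.tensor(values) for name, values in features.items()}
        return features, torch.tensor(labels)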
26 changes: 21 additions & 5 deletions src/setfit/modeling.py
@@ -352,7 +352,11 @@ def _prepare_dataloader(
             max_length=max_length,
         )
         dataloader = DataLoader(
-            dataset, batch_size=batch_size, collate_fn=SetFitDataset.collate_fn, shuffle=shuffle, pin_memory=True
+            dataset,
+            batch_size=batch_size,
+            collate_fn=dataset.collate_fn,
+            shuffle=shuffle,
+            pin_memory=True,
         )
 
         return dataloader
@@ -367,8 +371,16 @@ def _prepare_optimizer(
         l2_weight = l2_weight or self.l2_weight
         optimizer = torch.optim.AdamW(
             [
-                {"params": self.model_body.parameters(), "lr": body_learning_rate, "weight_decay": l2_weight},
-                {"params": self.model_head.parameters(), "lr": learning_rate, "weight_decay": l2_weight},
+                {
+                    "params": self.model_body.parameters(),
+                    "lr": body_learning_rate,
+                    "weight_decay": l2_weight,
+                },
+                {
+                    "params": self.model_head.parameters(),
+                    "lr": learning_rate,
+                    "weight_decay": l2_weight,
+                },
             ],
         )
 
@@ -394,7 +406,9 @@ def _freeze_or_not(self, model: torch.nn.Module, to_freeze: bool) -> None:
 
     def predict(self, x_test: List[str], as_numpy: bool = False) -> Union[torch.Tensor, "ndarray"]:
         embeddings = self.model_body.encode(
-            x_test, normalize_embeddings=self.normalize_embeddings, convert_to_tensor=self.has_differentiable_head
+            x_test,
+            normalize_embeddings=self.normalize_embeddings,
+            convert_to_tensor=self.has_differentiable_head,
         )
 
         outputs = self.model_head.predict(embeddings)
@@ -408,7 +422,9 @@ def predict(self, x_test: List[str], as_numpy: bool = False) -> Union[torch.Tens
 
     def predict_proba(self, x_test: List[str], as_numpy: bool = False) -> Union[torch.Tensor, "ndarray"]:
         embeddings = self.model_body.encode(
-            x_test, normalize_embeddings=self.normalize_embeddings, convert_to_tensor=self.has_differentiable_head
+            x_test,
+            normalize_embeddings=self.normalize_embeddings,
+            convert_to_tensor=self.has_differentiable_head,
         )
 
         outputs = self.model_head.predict_proba(embeddings)
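
Because collate_fn is now an instance method, _prepare_dataloader passes the bound dataset.collate_fn rather than the former static SetFitDataset.collate_fn; the bound method carries the dataset's tokenizer, which the collator needs for model_input_names. A minimal usage sketch, assuming dataset is a SetFitDataset instance as in the test below:

    from torch.utils.data import DataLoader

    # PyTorch calls collate_fn(batch); the unbound SetFitDataset.collate_fn would now
    # receive the batch as `self`, so the bound method must be passed instead.
    dataloader = DataLoader(dataset, batch_size=16, collate_fn=dataset.collate_fn)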
29 changes: 29 additions & 0 deletions tests/test_data.py
@@ -4,10 +4,13 @@
 import pandas as pd
 import pytest
 from datasets import Dataset, load_dataset
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer
 
 from setfit.data import (
     SAMPLE_SIZES,
     SEEDS,
+    SetFitDataset,
     add_templated_examples,
     create_fewshot_splits,
     create_fewshot_splits_multilabel,
@@ -197,3 +200,29 @@ def test_get_augmented_samples(dataset: str):
 def test_get_augmented_samples_negative():
     with pytest.raises(ValueError):
         get_augmented_samples(None)
+
+
+@pytest.mark.parametrize(
+    "tokenizer_name",
+    ["sentence-transformers/paraphrase-albert-small-v2", "sentence-transformers/distiluse-base-multilingual-cased-v1"],
+)
+def test_correct_model_inputs(tokenizer_name):
+    # Arbitrary testing data
+    x = list(string.ascii_lowercase)
+    y = list(range(len(x)))
+
+    # Relatively Standard DataLoader setup using a SetFitDataset
+    # for training a differentiable classification head
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    dataset = SetFitDataset(x, y, tokenizer)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=2,
+        collate_fn=dataset.collate_fn,
+        shuffle=True,
+        pin_memory=True,
+    )
+
+    # Verify that the x_batch contains exactly those keys that the model requires
+    x_batch, _ = next(iter(dataloader))
+    assert set(x_batch.keys()) == set(tokenizer.model_input_names)
