
Commit

Merge branch 'main' of https://github.com/huggingface/setfit into refactor_v2
tomaarsen committed Jan 22, 2023
2 parents 14602ea + 174eb00 commit 94106cc
Showing 8 changed files with 185 additions and 25 deletions.
27 changes: 22 additions & 5 deletions README.md
@@ -158,21 +158,38 @@ Based on our experiments, `SetFitHead` can achieve similar performance as using

To train SetFit models on multilabel datasets, specify the `multi_target_strategy` argument when loading the pretrained model:

#### Example using a classification head from `scikit-learn`:

```python
from setfit import SetFitModel

model = SetFitModel.from_pretrained(
    model_id,
    multi_target_strategy="one-vs-rest",
)
```

This will initialise a multilabel classification head from `sklearn`. The following options are available for `multi_target_strategy` (a `scikit-learn` sketch of each follows the list):

* `one-vs-rest`: uses a `OneVsRestClassifier` head.
* `multi-output`: uses a `MultiOutputClassifier` head.
* `classifier-chain`: uses a `ClassifierChain` head.
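
In plain `scikit-learn` terms, these strategies correspond roughly to the following wrappers around the default `LogisticRegression` head. This is a minimal sketch assuming the `sklearn` defaults, not the exact construction SetFit performs internally:

```python
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier

# Illustrative mapping from `multi_target_strategy` to an sklearn head.
heads = {
    "one-vs-rest": OneVsRestClassifier(LogisticRegression()),
    "multi-output": MultiOutputClassifier(LogisticRegression()),
    "classifier-chain": ClassifierChain(LogisticRegression()),
}
```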

From here, you can instantiate a `SetFitTrainer` using the same example above, and train it as usual.
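For instance, a minimal end-to-end sketch; the toy dataset and `num_iterations` value are illustrative:

```python
from datasets import Dataset
from setfit import SetFitModel, SetFitTrainer

# Toy multilabel dataset: each label is a list of binary indicators.
dataset = Dataset.from_dict({
    "text": ["great and cheap", "great", "cheap", "neither"],
    "label": [[1, 1], [1, 0], [0, 1], [0, 0]],
})

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-albert-small-v2",
    multi_target_strategy="one-vs-rest",
)

trainer = SetFitTrainer(
    model=model,
    train_dataset=dataset,
    num_iterations=20,  # number of text pairs generated per sample for contrastive learning
)
trainer.train()

preds = model.predict(["great and cheap"])  # e.g. array([[1, 1]]): one column per target
```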

#### Example using the differentiable `SetFitHead`:

```python
from setfit import SetFitModel

model = SetFitModel.from_pretrained(
    model_id,
    multi_target_strategy="one-vs-rest",
    use_differentiable_head=True,
    head_params={"out_features": num_classes},
)
```

**Note:** If you use the differentiable `SetFitHead` classifier head, it will automatically use `BCEWithLogitsLoss` for training. Prediction applies a `sigmoid`, after which the per-target probabilities are rounded to 1 or 0. Furthermore, the `"one-vs-rest"` and `"multi-output"` multi-target strategies are equivalent for the differentiable `SetFitHead`.
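
Inference then yields one binary indicator per target rather than a single class. A rough sketch, with illustrative inputs and outputs:

```python
preds = model.predict(["great and cheap", "neither"])
# e.g. tensor([[1, 1], [0, 0]]): each per-target probability thresholded at 0.5

probs = model.predict_proba(["great and cheap"])
# per-target sigmoid probabilities; unlike softmax, rows need not sum to 1
```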

### Zero-shot text classification

1 change: 1 addition & 0 deletions scripts/create_summary_table.py
@@ -79,6 +79,7 @@ def compute_tfew_medians(results_path: str) -> None:

def get_formatted_ds_metrics(path: str, dataset: str, sample_sizes: List[str]) -> Tuple[str, List[str]]:
    formatted_row = []
    metric_name = ""
    exact_metrics, exact_stds = {}, {}

    for sample_size in sample_sizes:
8 changes: 8 additions & 0 deletions scripts/setfit/run_fewshot.py
@@ -16,6 +16,10 @@
from setfit.utils import DEV_DATASET_TO_METRIC, LOSS_NAME_TO_CLASS, TEST_DATASET_TO_METRIC, load_data_splits


sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from create_summary_table import create_summary_table # noqa: E402


# ignore all future warnings
simplefilter(action="ignore", category=FutureWarning)

@@ -169,6 +173,10 @@ def main():
        sort_keys=True,
    )

    # Create a summary_table.csv file that computes means and standard deviations
    # for all of the results in `output_path`.
    create_summary_table(str(output_path))


if __name__ == "__main__":
    main()
14 changes: 7 additions & 7 deletions src/setfit/data.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import pandas as pd
import torch
@@ -222,8 +222,8 @@ class SetFitDataset(TorchDataset):
    Args:
        x (`List[str]`):
            A list of input data as texts that will be fed into `SetFitModel`.
        y (`Union[List[int], List[List[int]]]`):
            A list of input data's labels. Can be a nested list for multi-label classification.
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer from `SetFitModel`'s body.
        max_length (`int`, defaults to `32`):
@@ -234,7 +234,7 @@ def __init__(
    def __init__(
        self,
        x: List[str],
        y: Union[List[int], List[List[int]]],
        tokenizer: "PreTrainedTokenizerBase",
        max_length: int = 32,
    ) -> None:
@@ -248,7 +248,7 @@ def __init__(
    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int) -> Tuple[TokenizerOutput, Union[int, List[int]]]:
        feature = self.tokenizer(
            self.x[idx],
            max_length=self.max_length,
@@ -277,6 +277,6 @@ def collate_fn(batch):

# convert to tensors
features = {k: torch.Tensor(v).int() for k, v in features.items()}
labels = torch.Tensor(labels).long()

labels = torch.Tensor(labels)
labels = labels.long() if len(labels.size()) == 1 else labels.float()
return features, labels
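
The dtype switch above exists because the two loss functions expect different label types: `CrossEntropyLoss` wants `long` class indices, while `BCEWithLogitsLoss` wants `float` indicator vectors. A standalone sketch of the rule; the helper name is ours, for illustration only:

```python
import torch

def to_label_tensor(labels):
    # 1-D labels -> long class indices (for CrossEntropyLoss);
    # 2-D multilabel indicators -> float targets (for BCEWithLogitsLoss).
    t = torch.Tensor(labels)
    return t.long() if len(t.size()) == 1 else t.float()

print(to_label_tensor([0, 2, 1]).dtype)         # torch.int64
print(to_label_tensor([[1, 0], [0, 1]]).dtype)  # torch.float32
```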
46 changes: 33 additions & 13 deletions src/setfit/modeling.py
@@ -113,13 +113,17 @@ class SetFitHead(models.Dense):
        out_features (`int`, defaults to `2`):
            The number of targets. If `out_features` is set to 1 for binary classification, it will be changed to 2, since binary classification is treated as 2-class classification.
        temperature (`float`, defaults to `1.0`):
            A scaling factor for the logits. Higher values make the model less confident and lower values make it more confident.
        eps (`float`, defaults to `1e-5`):
            A value for numerical stability when scaling logits.
        bias (`bool`, *optional*, defaults to `True`):
            Whether to add bias to the head.
        device (`torch.device`, str, *optional*):
            The device the model will be sent to. If `None`, will check whether GPU is available.
        multitarget (`bool`, defaults to `False`):
            Enable multi-target classification by making `out_features` binary predictions instead of a single multinomial prediction.
    """

    def __init__(
@@ -130,6 +134,7 @@ def __init__(
        eps: float = 1e-5,
        bias: bool = True,
        device: Optional[Union[torch.device, str]] = None,
        multitarget: bool = False,
    ) -> None:
        super(models.Dense, self).__init__()  # init on models.Dense's parent: nn.Module

Expand All @@ -150,6 +155,7 @@ def __init__(
self.eps = eps
self.bias = bias
self._device = device or "cuda" if torch.cuda.is_available() else "cpu"
self.multitarget = multitarget

self.to(self._device)
self.apply(self._init_weight)
@@ -170,7 +176,8 @@ def forward(
                make sure to store embeddings under the key: 'sentence_embedding'
                and the outputs will be under the key: 'prediction'.
            temperature (`float`, *optional*):
                A scaling factor for the logits. Higher values make the model less confident and lower values make it more confident.
                Will override the temperature given during initialization.

        Returns:
            [`Dict[str, torch.Tensor]` or `Tuple[torch.Tensor]`]
@@ -180,12 +187,13 @@
        if isinstance(features, dict):
            assert "sentence_embedding" in features
            is_features_dict = True

        x = features["sentence_embedding"] if is_features_dict else features
        logits = self.linear(x)
        logits = logits / (temperature + self.eps)
        if self.multitarget:  # multiple targets per item
            probs = torch.sigmoid(logits)
        else:  # one target per item
            probs = nn.functional.softmax(logits, dim=-1)

        if is_features_dict:
            features.update(
                {
@@ -205,11 +213,13 @@ def predict_proba(self, x_test: torch.Tensor) -> torch.Tensor:
    def predict(self, x_test: torch.Tensor) -> torch.Tensor:
        probs = self.predict_proba(x_test)

        if self.multitarget:
            return torch.where(probs >= 0.5, 1, 0)
        return torch.argmax(probs, dim=-1)

    def get_loss_fn(self) -> nn.Module:
        if self.multitarget:  # if sigmoid output
            return torch.nn.BCEWithLogitsLoss()
        return torch.nn.CrossEntropyLoss()

    @property
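
Taken together, `multitarget` switches the head between two inference modes. A toy illustration of the difference; the probability values are made up:

```python
import torch

probs = torch.tensor([[0.9, 0.4, 0.7]])

# multitarget=True: independent per-target probabilities, thresholded at 0.5
print(torch.where(probs >= 0.5, 1, 0))  # tensor([[1, 0, 1]])

# multitarget=False: probabilities compete via softmax; argmax picks one class
print(torch.argmax(probs, dim=-1))  # tensor([0])
```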
@@ -270,7 +280,7 @@ def has_differentiable_head(self) -> bool:
    def fit(
        self,
        x_train: List[str],
        y_train: Union[List[int], List[List[int]]],
        classifier_num_epochs: int,
        classifier_batch_size: Optional[int] = None,
        classifier_learning_rate: Optional[Tuple[float, float]] = (None, None),
@@ -322,7 +332,7 @@ def fit(
    def _prepare_dataloader(
        self,
        x_train: List[str],
        y_train: Union[List[int], List[List[int]]],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        shuffle: bool = True,
@@ -413,7 +423,7 @@ def predict(self, inputs: List[str], as_numpy: bool = False) -> Union[torch.Tens
        outputs = self.model_head.predict(embeddings)

        if as_numpy and self.has_differentiable_head:
            outputs = outputs.detach().cpu().numpy()
        elif not as_numpy and not self.has_differentiable_head:
            outputs = torch.from_numpy(outputs)

@@ -424,7 +434,7 @@ def predict_proba(self, inputs: List[str], as_numpy: bool = False) -> Union[torc
        outputs = self.model_head.predict_proba(embeddings)

        if as_numpy and self.has_differentiable_head:
            outputs = outputs.detach().cpu().numpy()
        elif not as_numpy and not self.has_differentiable_head:
            outputs = torch.from_numpy(outputs)

@@ -487,7 +497,7 @@ def _from_pretrained(
        normalize_embeddings: bool = False,
        **model_kwargs,
    ) -> "SetFitModel":
        model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=use_auth_token)
        target_device = model_body._target_device
        model_body.to(target_device)  # put `model_body` on the target device

@@ -526,12 +536,22 @@ def _from_pretrained(
        else:
            head_params = model_kwargs.get("head_params", {})
            if use_differentiable_head:
                if multi_target_strategy is None:
                    use_multitarget = False
                else:
                    if multi_target_strategy in ["one-vs-rest", "multi-output"]:
                        use_multitarget = True
                    else:
                        raise ValueError(
                            f"multi_target_strategy '{multi_target_strategy}' is not supported for differentiable head"
                        )
                # Base `model_head` parameters
                # - get the sentence embedding dimension from the `model_body`
                # - follow the `model_body`, put `model_head` on the target device
                base_head_params = {
                    "in_features": model_body.get_sentence_embedding_dimension(),
                    "device": target_device,
                    "multitarget": use_multitarget,
                }
                model_head = SetFitHead(**{**head_params, **base_head_params})
            else:
53 changes: 53 additions & 0 deletions tests/test_deprecated_trainer.py
@@ -287,6 +287,59 @@ def compute_metrics(y_pred, y_test):
)


@pytest.mark.skip(
    reason=(
        "The `trainer.freeze()` before `trainer.train()` now freezes the body as well as the head, "
        "which means the backwards call from `trainer.train()` will fail."
    )
)
class SetFitTrainerMultilabelDifferentiableTest(TestCase):
    def setUp(self):
        self.model = SetFitModel.from_pretrained(
            "sentence-transformers/paraphrase-albert-small-v2",
            multi_target_strategy="one-vs-rest",
            use_differentiable_head=True,
            head_params={"out_features": 2},
        )
        self.num_iterations = 1

    def test_trainer_multilabel_support_callable_as_metric(self):
        dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]})

        multilabel_f1_metric = evaluate.load("f1", "multilabel")
        multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel")

        def compute_metrics(y_pred, y_test):
            return {
                "f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"],
                "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
            }

        trainer = SetFitTrainer(
            model=self.model,
            train_dataset=dataset,
            eval_dataset=dataset,
            metric=compute_metrics,
            num_iterations=self.num_iterations,
            column_mapping={"text_new": "text", "label_new": "label"},
        )

        trainer.freeze()
        trainer.train()

        trainer.unfreeze(keep_body_frozen=False)
        trainer.train(5)
        metrics = trainer.evaluate()

        self.assertEqual(
            {
                "f1": 1.0,
                "accuracy": 1.0,
            },
            metrics,
        )


@require_optuna
class TrainerHyperParameterOptunaIntegrationTest(TestCase):
    def setUp(self):
18 changes: 18 additions & 0 deletions tests/test_modeling.py
@@ -212,6 +212,24 @@ def test_setfit_from_pretrained_local_model_with_head(tmp_path):
    assert isinstance(model, SetFitModel)


def test_setfithead_multitarget_from_pretrained():
    model = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-albert-small-v2",
        use_differentiable_head=True,
        multi_target_strategy="one-vs-rest",
        head_params={"out_features": 5},
    )
    assert isinstance(model.model_head, SetFitHead)
    assert model.model_head.multitarget
    assert isinstance(model.model_head.get_loss_fn(), torch.nn.BCEWithLogitsLoss)

    y_pred = model.predict("Test text")
    assert len(y_pred) == 5

    y_pred_probs = model.predict_proba("Test text", as_numpy=True)
    assert not np.isclose(y_pred_probs.sum(), 1)  # Should not sum to one


def test_to_logistic_head():
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
    devices = (
43 changes: 43 additions & 0 deletions tests/test_trainer.py
@@ -260,6 +260,49 @@ def compute_metrics(y_pred, y_test):
)


class SetFitTrainerMultilabelDifferentiableTest(TestCase):
    def setUp(self):
        self.model = SetFitModel.from_pretrained(
            "sentence-transformers/paraphrase-albert-small-v2",
            multi_target_strategy="one-vs-rest",
            use_differentiable_head=True,
            head_params={"out_features": 2},
        )
        self.args = TrainingArguments(num_iterations=1)

    def test_trainer_multilabel_support_callable_as_metric(self):
        dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]})

        multilabel_f1_metric = evaluate.load("f1", "multilabel")
        multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel")

        def compute_metrics(y_pred, y_test):
            return {
                "f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"],
                "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
            }

        trainer = Trainer(
            model=self.model,
            args=self.args,
            train_dataset=dataset,
            eval_dataset=dataset,
            metric=compute_metrics,
            column_mapping={"text_new": "text", "label_new": "label"},
        )

        trainer.train()
        metrics = trainer.evaluate()

        self.assertEqual(
            {
                "f1": 1.0,
                "accuracy": 1.0,
            },
            metrics,
        )


@require_optuna
class TrainerHyperParameterOptunaIntegrationTest(TestCase):
    def setUp(self):
