Refactor to introduce Trainer & TrainingArguments, add SetFit ABSA #265

Merged (97 commits, Nov 10, 2023)
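For orientation, a minimal sketch of how the Trainer and TrainingArguments introduced by this PR are meant to be used; the argument names follow the final state of the PR, and the dataset and hyperparameter choices below are purely illustrative:

from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

# Few-shot setup: sample 8 labeled examples per class from a full training split.
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Hyperparameters move off the trainer and onto a dedicated TrainingArguments object.
args = TrainingArguments(
    batch_size=16,
    num_epochs=1,
    sampling_strategy="oversampling",
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=dataset["test"],
    metric="accuracy",
)
trainer.train()
print(trainer.evaluate())  # e.g. {"accuracy": ...}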
Commits (97 in total; the file changes shown below are from the first 12 commits)
1acdd5c
Implement Trainer & TrainingArguments w. tests
tomaarsen Jan 11, 2023
89f4435
Readded support for hyperparameter tuning
tomaarsen Jan 11, 2023
5f2a6b3
Remove unused imports and reformat
tomaarsen Jan 11, 2023
622f33b
Preserve desired behaviour despite deprecation of keep_body_frozen pa…
tomaarsen Jan 11, 2023
ff59154
Ensure that DeprecationWarnings are displayed
tomaarsen Jan 11, 2023
3b4ef58
Set Trainer.freeze and Trainer.unfreeze methods normally
tomaarsen Jan 11, 2023
fd68274
Add TrainingArgument tests for num_epochs, batch_sizes, lr
tomaarsen Jan 11, 2023
14602ea
Convert trainer.train arguments into a softer deprecation
tomaarsen Jan 11, 2023
94106cc
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Jan 22, 2023
a39e772
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit; br…
tomaarsen Jan 23, 2023
9fc55a6
Use body/head_learning_rate instead of classifier/embedding_learning_…
tomaarsen Jan 23, 2023
7d4ad00
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Jan 23, 2023
aab2377
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 6, 2023
dee70b1
Reformat according to the newest black version
tomaarsen Feb 6, 2023
fb6547d
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 6, 2023
abbbb03
Remove "classifier" from var names in SetFitHead
tomaarsen Feb 6, 2023
12d326e
Update DeprecationWarnings to include timeline
tomaarsen Feb 6, 2023
70c0295
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 6, 2023
fc246cc
Convert training_argument imports to relative imports
tomaarsen Feb 6, 2023
57aa54f
Make conditional explicit
tomaarsen Feb 6, 2023
7ebdf93
Make conditional explicit
tomaarsen Feb 6, 2023
4695293
Use assertEqual rather than assert
tomaarsen Feb 6, 2023
4c6d0fd
Remove training_arguments from test func names
tomaarsen Feb 6, 2023
5937ec2
Replace loss_class on Trainer with loss on TrainArgs
tomaarsen Feb 6, 2023
f1e3de9
Removed dead class argument
tomaarsen Feb 6, 2023
6051095
Move SupConLoss to losses.py
tomaarsen Feb 6, 2023
bddd46a
Add deprecation to Trainer.(un)freeze
tomaarsen Feb 7, 2023
fa8a077
Prevent warning from always triggering
tomaarsen Feb 7, 2023
85a3684
Export TrainingArguments in __init__
tomaarsen Feb 7, 2023
ca625a2
Update & add important missing docstrings
tomaarsen Feb 7, 2023
868d7b7
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 7, 2023
68e9094
Use standard dataclass initialization for SetFitModel
tomaarsen Feb 8, 2023
19a6fc8
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 15, 2023
0b2efa1
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 15, 2023
ca87c42
Remove duplicate space in DeprecationWarning
tomaarsen Feb 16, 2023
cc5282f
No longer require labeled data for DistillationTrainer
tomaarsen Mar 3, 2023
c6f5782
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Mar 3, 2023
36cbbfe
Update docs for v1.0.0
tomaarsen Mar 6, 2023
deb57ff
Remove references of SetFitTrainer
tomaarsen Mar 6, 2023
46922d5
Update expected test output
tomaarsen Mar 6, 2023
f43d5b2
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Apr 19, 2023
b0f9f58
Remove unused pipeline
tomaarsen Apr 19, 2023
339f332
Execute deprecations
tomaarsen Apr 19, 2023
9e0bf78
Stop importing now-removed function
tomaarsen Apr 19, 2023
ecabbcf
Initial setup for logging & callbacks
tomaarsen Jul 6, 2023
6e6720b
Move sentence-transformer training into trainer.py
tomaarsen Jul 6, 2023
826eb53
Add checkpointing, support EarlyStoppingCallback
tomaarsen Jul 28, 2023
019a971
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Jul 29, 2023
1930973
Run formatting
tomaarsen Jul 29, 2023
e4f3f76
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Jul 29, 2023
0f66109
Merge pull request #4 from tomaarsen/feat/logging_callbacks
tomaarsen Jul 29, 2023
a87cdc0
Add additional trainer tests
tomaarsen Jul 29, 2023
d418759
Use isinstance, required by flake8 release from 1hr ago
tomaarsen Jul 29, 2023
08892f6
sampler for refactor WIP
danstan5 Sep 14, 2023
0a2b664
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Oct 17, 2023
429de0f
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
173f084
Run formatters
tomaarsen Oct 17, 2023
c23959a
Remove tests from modeling.py
tomaarsen Oct 17, 2023
0fa3870
Add missing type hint
tomaarsen Oct 17, 2023
3969f38
Adjust test to still pass if W&B/Tensorboard are installed
tomaarsen Oct 17, 2023
567f1c9
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
851f0bb
The log/eval/save steps should be saved on the state instead
tomaarsen Oct 17, 2023
67ddedc
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
d37ee09
sampler logic fix "unique" strategy
danstan5 Oct 19, 2023
0ef8837
add sampler tests (not complete)
danstan5 Oct 19, 2023
131aa26
add sampling_strategy into TrainingArguments
danstan5 Oct 19, 2023
c6c6228
Merge branch 'refactor-sampling' of https://github.com/danstan5/setfi…
danstan5 Oct 19, 2023
7431005
num_iterations removed from TrainingArguments
danstan5 Oct 19, 2023
3bd2acc
run_fewshot compatible with <v.1.0.0
danstan5 Oct 20, 2023
3d07e6c
Run make style
tomaarsen Oct 25, 2023
978daee
Use "no" as the default evaluation_strategy
tomaarsen Oct 25, 2023
2802a3f
Move num_iterations back to TrainingArguments
tomaarsen Oct 25, 2023
391f991
Fix broken trainer tests due to new default sampling
tomaarsen Oct 25, 2023
f8b7253
Use the Contrastive Dataset for Distillation
tomaarsen Oct 25, 2023
38e9607
Set the default logging steps at 50
tomaarsen Oct 25, 2023
4ead15d
Add max_steps argument to TrainingArguments
tomaarsen Oct 25, 2023
eb70336
Change max_steps conditional
tomaarsen Oct 25, 2023
3478799
Merge pull request #5 from danstan5/refactor-sampling
tomaarsen Oct 27, 2023
d9c4a05
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Nov 9, 2023
5b39f06
Seeds are now correctly applied for reproducibility
tomaarsen Nov 9, 2023
7c3feed
Don't scale gradients during evaluation
tomaarsen Nov 9, 2023
cdc8979
Use evaluation_strategy="steps" if eval_steps is set
tomaarsen Nov 9, 2023
e040167
Run formatting
tomaarsen Nov 9, 2023
d2f2489
Implement SetFit for ABSA from Intel Labs (#6)
tomaarsen Nov 9, 2023
5c4569d
Import optuna under TYPE_CHECKING
tomaarsen Nov 9, 2023
ceeb725
Remove unused import, reformat
tomaarsen Nov 9, 2023
5c669b5
Add MANIFEST.in with model_card_template
tomaarsen Nov 9, 2023
8e201e5
Don't require transformers TrainingArgs in tests
tomaarsen Nov 9, 2023
6ae5045
Update URLs in setup.py
tomaarsen Nov 9, 2023
ecaabb4
Increase min hf_hub version to 0.12.0 for SoftTemporaryDirectory
tomaarsen Nov 9, 2023
4e79397
Include MANIFEST.in data via `include_package_data=True`
tomaarsen Nov 9, 2023
65aff32
Use kwargs instead of args in super call
tomaarsen Nov 9, 2023
eeeac55
Use v0.13.0 as min. version as huggingface/huggingface_hub#1315
tomaarsen Nov 9, 2023
3214f1b
Use en_core_web_sm for tests
tomaarsen Nov 10, 2023
2b78bb0
Remove incorrect spacy_model from AspectModel/PolarityModel
tomaarsen Nov 10, 2023
b68f655
Rerun formatting
tomaarsen Nov 10, 2023
d85f0d9
Run CI on pre branch & workflow dispatch
tomaarsen Nov 10, 2023
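Commit d2f2489 above brings in the ABSA work from Intel Labs. A rough sketch of the expected usage, assuming the AbsaModel entry point and the spaCy-based aspect extraction referenced in the commit messages; the checkpoint names and example text are placeholders:

from setfit import AbsaModel

# ABSA chains two SetFit models: one classifies candidate aspect spans proposed by a
# spaCy pipeline (e.g. en_core_web_sm), the other classifies each span's polarity.
model = AbsaModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",  # aspect extraction backbone
    "sentence-transformers/paraphrase-mpnet-base-v2",  # polarity classification backbone
    spacy_model="en_core_web_sm",
)

preds = model.predict([
    "The food was delicious but the service was painfully slow.",
])
# Expected output: per sentence, a list of {"span": ..., "polarity": ...} dicts.
print(preds)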
10 changes: 8 additions & 2 deletions src/setfit/__init__.py
@@ -1,6 +1,12 @@
 __version__ = "0.6.0.dev0"
 
+import warnings
+
 from .data import add_templated_examples, sample_dataset
 from .modeling import SetFitHead, SetFitModel
-from .trainer import SetFitTrainer
-from .trainer_distillation import DistillationSetFitTrainer
+from .trainer import SetFitTrainer, Trainer
+from .trainer_distillation import DistillationSetFitTrainer, DistillationTrainer
+
+
+# Ensure that DeprecationWarnings are always shown
+warnings.filterwarnings("default", category=DeprecationWarning)
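The new filter matters because Python suppresses DeprecationWarning by default outside of __main__; the "default" action prints each matching warning once per call site. A standalone illustration, where old_api is a hypothetical stand-in for a deprecated SetFit entry point:

import warnings

# Mirror the filter added in src/setfit/__init__.py above.
warnings.filterwarnings("default", category=DeprecationWarning)

def old_api() -> None:
    # Hypothetical deprecated entry point, standing in for e.g. SetFitTrainer.
    warnings.warn(
        "`old_api` is deprecated and will be removed in a future release.",
        DeprecationWarning,
        stacklevel=2,
    )

for _ in range(3):
    old_api()  # printed only once: repeats from the same call site are suppressed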
93 changes: 60 additions & 33 deletions src/setfit/modeling.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -14,14 +15,14 @@
 import numpy as np
 import requests
 import torch
-import torch.nn as nn
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
 from sentence_transformers import InputExample, SentenceTransformer, models
 from sklearn.linear_model import LogisticRegression
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
+from torch import nn
 from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
+from tqdm.auto import tqdm, trange
 
 from . import logging
 from .data import SetFitDataset
@@ -216,7 +217,7 @@ def predict(self, x_test: torch.Tensor) -> torch.Tensor:
             return torch.where(probs >= 0.5, 1, 0)
         return torch.argmax(probs, dim=-1)
 
-    def get_loss_fn(self):
+    def get_loss_fn(self) -> nn.Module:
         if self.multitarget:  # if sigmoid output
             return torch.nn.BCEWithLogitsLoss()
         return torch.nn.CrossEntropyLoss()
@@ -242,9 +243,9 @@ def get_config_dict(self) -> Dict[str, Optional[Union[int, float, bool]]]:
     @staticmethod
     def _init_weight(module):
         if isinstance(module, nn.Linear):
-            torch.nn.init.xavier_uniform_(module.weight)
+            nn.init.xavier_uniform_(module.weight)
             if module.bias is not None:
-                torch.nn.init.constant_(module.bias, 1e-2)
+                nn.init.constant_(module.bias, 1e-2)
 
     def __repr__(self):
         return "SetFitHead({})".format(self.get_config_dict())
@@ -280,25 +281,29 @@ def fit(
         self,
         x_train: List[str],
         y_train: Union[List[int], List[List[int]]],
-        num_epochs: int,
-        batch_size: Optional[int] = None,
-        learning_rate: Optional[float] = None,
-        body_learning_rate: Optional[float] = None,
+        classifier_num_epochs: int,
+        classifier_batch_size: Optional[int] = None,
+        body_classifier_learning_rate: Optional[float] = None,
+        head_learning_rate: Optional[float] = None,
         l2_weight: Optional[float] = None,
         max_length: Optional[int] = None,
-        show_progress_bar: Optional[bool] = None,
+        show_progress_bar: bool = True,
+        end_to_end: bool = False,
+        **kwargs,
     ) -> None:
         if self.has_differentiable_head:  # train with pyTorch
             device = self.model_body.device
             self.model_body.train()
             self.model_head.train()
+            if not end_to_end:
+                self.freeze("body")
 
-            dataloader = self._prepare_dataloader(x_train, y_train, batch_size, max_length)
+            dataloader = self._prepare_dataloader(x_train, y_train, classifier_batch_size, max_length)
             criterion = self.model_head.get_loss_fn()
-            optimizer = self._prepare_optimizer(learning_rate, body_learning_rate, l2_weight)
+            optimizer = self._prepare_optimizer(head_learning_rate, body_classifier_learning_rate, l2_weight)
             scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
-            for epoch_idx in tqdm(range(num_epochs), desc="Epoch", disable=not show_progress_bar):
-                for batch in dataloader:
+            for epoch_idx in trange(classifier_num_epochs, desc="Epoch", disable=not show_progress_bar):
+                for batch in tqdm(dataloader, desc="Iteration", disable=not show_progress_bar, leave=False):
                     features, labels = batch
                     optimizer.zero_grad()
 
@@ -308,15 +313,18 @@ def fit(
 
                     outputs = self.model_body(features)
                     if self.normalize_embeddings:
-                        outputs = torch.nn.functional.normalize(outputs, p=2, dim=1)
+                        outputs = nn.functional.normalize(outputs, p=2, dim=1)
                     outputs = self.model_head(outputs)
                     logits = outputs["logits"]
 
-                    loss = criterion(logits, labels)
+                    loss: torch.Tensor = criterion(logits, labels)
                     loss.backward()
                     optimizer.step()
 
                 scheduler.step()
+
+            if not end_to_end:
+                self.unfreeze("body")
         else:  # train with sklearn
             embeddings = self.model_body.encode(x_train, normalize_embeddings=self.normalize_embeddings)
             self.model_head.fit(embeddings, y_train)
@@ -359,16 +367,20 @@ def _prepare_dataloader(
 
     def _prepare_optimizer(
         self,
-        learning_rate: float,
-        body_learning_rate: Optional[float],
+        head_learning_rate: float,
+        body_classifier_learning_rate: Optional[float],
         l2_weight: float,
     ) -> torch.optim.Optimizer:
-        body_learning_rate = body_learning_rate or learning_rate
+        body_classifier_learning_rate = body_classifier_learning_rate or head_learning_rate
         l2_weight = l2_weight or self.l2_weight
         optimizer = torch.optim.AdamW(
             [
-                {"params": self.model_body.parameters(), "lr": body_learning_rate, "weight_decay": l2_weight},
-                {"params": self.model_head.parameters(), "lr": learning_rate, "weight_decay": l2_weight},
+                {
+                    "params": self.model_body.parameters(),
+                    "lr": body_classifier_learning_rate,
+                    "weight_decay": l2_weight,
+                },
+                {"params": self.model_head.parameters(), "lr": head_learning_rate, "weight_decay": l2_weight},
             ],
         )
 
@@ -378,25 +390,40 @@ def freeze(self, component: Optional[Literal["body", "head"]] = None) -> None:
         if component is None or component == "body":
             self._freeze_or_not(self.model_body, to_freeze=True)
 
-        if component is None or component == "head":
+        if (component is None or component == "head") and self.has_differentiable_head:
             self._freeze_or_not(self.model_head, to_freeze=True)
 
-    def unfreeze(self, component: Optional[Literal["body", "head"]] = None) -> None:
+    def unfreeze(
+        self, component: Optional[Literal["body", "head"]] = None, keep_body_frozen: Optional[bool] = None
+    ) -> None:
+        if keep_body_frozen is not None:
+            warnings.warn(
+                '`keep_body_frozen` is deprecated. Please either pass "head", "body" or no arguments to unfreeze both.',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            # If the body must stay frozen, only unfreeze the head. Eventually, this entire if-branch
+            # can be removed.
+            if keep_body_frozen and not component:
+                component = "head"
+
         if component is None or component == "body":
             self._freeze_or_not(self.model_body, to_freeze=False)
 
-        if component is None or component == "head":
+        if (component is None or component == "head") and self.has_differentiable_head:
             self._freeze_or_not(self.model_head, to_freeze=False)
 
-    def _freeze_or_not(self, model: torch.nn.Module, to_freeze: bool) -> None:
+    def _freeze_or_not(self, model: nn.Module, to_freeze: bool) -> None:
         for param in model.parameters():
             param.requires_grad = not to_freeze
 
-    def predict(self, x_test: List[str], as_numpy: bool = False) -> Union[torch.Tensor, "ndarray"]:
-        embeddings = self.model_body.encode(
-            x_test, normalize_embeddings=self.normalize_embeddings, convert_to_tensor=self.has_differentiable_head
+    def encode(self, inputs: List[str]) -> Union[torch.Tensor, "ndarray"]:
+        return self.model_body.encode(
+            inputs, normalize_embeddings=self.normalize_embeddings, convert_to_tensor=self.has_differentiable_head
         )
 
+    def predict(self, inputs: List[str], as_numpy: bool = False) -> Union[torch.Tensor, "ndarray"]:
+        embeddings = self.encode(inputs)
         outputs = self.model_head.predict(embeddings)
 
         if as_numpy and self.has_differentiable_head:
@@ -406,11 +433,8 @@ def predict(self, x_test: List[str], as_numpy: bool = False) -> Union[torch.Tensor, "ndarray"]:
 
         return outputs
 
-    def predict_proba(self, x_test: List[str], as_numpy: bool = False) -> Union[torch.Tensor, "ndarray"]:
-        embeddings = self.model_body.encode(
-            x_test, normalize_embeddings=self.normalize_embeddings, convert_to_tensor=self.has_differentiable_head
-        )
-
+    def predict_proba(self, inputs: List[str], as_numpy: bool = False) -> Union[torch.Tensor, "ndarray"]:
+        embeddings = self.encode(inputs)
         outputs = self.model_head.predict_proba(embeddings)
 
         if as_numpy and self.has_differentiable_head:
@@ -429,6 +453,9 @@ def to(self, device: Union[str, torch.device]) -> "SetFitModel":
         Returns:
             SetFitModel: Returns the original model, but now on the desired device.
         """
+        # Note that we must also set _target_device, or any SentenceTransformer.fit() call will reset
+        # the body location
+        self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
         self.model_body = self.model_body.to(device)
 
         if self.has_differentiable_head:
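With predict and predict_proba now routed through the shared encode helper, inference follows the same path for both the scikit-learn head and the differentiable torch head. A short sketch, assuming the pre-existing use_differentiable_head and head_params loading options:

from setfit import SetFitModel

# Load a body plus a torch classification head instead of the default
# scikit-learn LogisticRegression head.
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": 2},
)

sentences = ["an uplifting little film", "a tedious, joyless slog"]

embeddings = model.encode(sentences)          # new shared helper from this diff
labels = model.predict(sentences)             # argmax over the head's logits
probabilities = model.predict_proba(sentences)
print(labels, probabilities.shape)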