
Implement SetFit for ABSA from Intel Labs #6

Merged · 17 commits · Nov 9, 2023

Conversation

@tomaarsen (Owner) commented Nov 8, 2023

Hello!

Pull Request overview

  • Implement SetFit ABSA from Intel Labs into SetFit.
  • Primary new classes:
    • AbsaModel:
      • predict
      • from_pretrained
      • save_pretrained
      • push_to_hub
      • to
      • device
    • AbsaTrainer:
      • train
      • evaluate
      • add_callback
      • pop_callback
      • remove_callback
      • push_to_hub
  • Add device property to SetFitModel.
  • Modernize SetFitModel.from_pretrained with token=... instead of use_auth_token=...
  • Throw a ValueError if the args passed to a Trainer has the wrong type, e.g. a transformers TrainingArguments instead of a setfit one.
  • Allow partial column_mapping, move column mapping behaviour into a Mixin.
  • Add test suite for AbsaModel: ~95% test coverage on new behaviour, only push_to_hub is untested.
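
The new `args` type check can be sketched roughly as follows. This is a minimal sketch with stand-in classes, not the actual SetFit source: the point is that passing e.g. `transformers.TrainingArguments` now fails early with a clear `ValueError` instead of breaking later in training.

```python
# Hedged sketch of the `args` type validation (class names are stand-ins,
# not the real SetFit implementation).

class TrainingArguments:
    """Stand-in for setfit.TrainingArguments."""


class Trainer:
    def __init__(self, model=None, args=None):
        # Reject argument objects of the wrong type up front, e.g. a
        # transformers.TrainingArguments passed by mistake.
        if args is not None and not isinstance(args, TrainingArguments):
            raise ValueError(
                "`args` must be a `setfit.TrainingArguments` instance, not "
                f"{type(args).__name__!r}. Did you pass "
                "`transformers.TrainingArguments` by mistake?"
            )
        self.args = args or TrainingArguments()
        self.model = model
```

With this shape, `Trainer(args=transformers.TrainingArguments(...))` fails immediately at construction time rather than deep inside `train()`.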

Usage

Training (Basic)

from setfit import AbsaModel, AbsaTrainer
from datasets import load_dataset

# You can initialize an AbsaModel using one or two SentenceTransformer models, or two ABSA models
# model = AbsaModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AbsaModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2")

# The training/eval dataset must have `text`, `span`, `label`, and `ordinal` columns
raw_dataset = load_dataset("csv", data_files="example_training_file.csv")
train_dataset = raw_dataset["train"].rename_columns({"sentence": "text", "aspect": "span", "polarity": "label"})

# The minimal Trainer instantiation
trainer = AbsaTrainer(model, train_dataset=train_dataset)
trainer.train()
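
Instead of renaming columns by hand as above, a (possibly partial) column_mapping can be passed to the trainer. The mapping behaviour can be sketched roughly like this (a hypothetical helper, not the actual Mixin code): only the columns the user lists are renamed, and columns that already match the expected names pass through unchanged.

```python
# Hedged sketch of partial column mapping (hypothetical helper, not SetFit's
# actual Mixin). Only the mappings the user provides are applied; columns
# that already carry the expected names are left as-is.

REQUIRED_COLUMNS = {"text", "span", "label", "ordinal"}

def apply_column_mapping(columns, column_mapping):
    """Rename dataset columns via a possibly-partial user mapping."""
    mapped = [column_mapping.get(name, name) for name in columns]
    missing = REQUIRED_COLUMNS - set(mapped)
    if missing:
        raise ValueError(f"Dataset is missing required columns: {sorted(missing)}")
    return mapped
```

For example, a dataset with columns `["sentence", "span", "label", "ordinal"]` only needs `{"sentence": "text"}` as its mapping; the other three columns already match.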

Training (Advanced)

from setfit import AbsaModel, AbsaTrainer, TrainingArguments
from transformers import EarlyStoppingCallback
from datasets import load_dataset

# You can initialize an AbsaModel using one or two SentenceTransformer models, or two ABSA models
# model = AbsaModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AbsaModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2")

# The training/eval dataset must have `text`, `span`, `label`, and `ordinal` columns
raw_dataset = load_dataset("csv", data_files="example_training_file.csv")["train"]
raw_dataset = raw_dataset.rename_columns({"sentence": "text", "aspect": "span", "polarity": "label"})
raw_dataset = raw_dataset.train_test_split(test_size=10)
train_dataset, eval_dataset = raw_dataset["train"], raw_dataset["test"]

# Training arguments for aspect and polarity training
aspect_args = TrainingArguments(
    output_dir="aspect",
    num_epochs=2,
    body_learning_rate=5e-5,
    head_learning_rate=1e-2,
    use_amp=True,
    warmup_proportion=0.2,
    evaluation_strategy="steps",
    eval_steps=20,
    save_steps=20,
    load_best_model_at_end=True,
)
polarity_args = TrainingArguments(
    output_dir="polarity",
    num_epochs=3,
    max_steps=1000,
    body_learning_rate=2e-5,
    head_learning_rate=3e-2,
    evaluation_strategy="steps",
    eval_steps=20,
    save_steps=20,
    load_best_model_at_end=True,
)

trainer = AbsaTrainer(
    model,
    args=aspect_args,
    polarity_args=polarity_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()

metrics = trainer.evaluate()
print(metrics)

Inference

# Predicting is as easy as with SetFit
predictions = model.predict([
    "Best pizza outside of Italy and really tasty.",
    "The variations are great and the prices are absolutely fair.",
    "Unfortunately, you have to expect some waiting time and get a note with a waiting number if it should be very full."
])
print(predictions)
"""
[
    [{'span': 'pizza', 'polarity': 'positive'}],
    [{'span': 'variations', 'polarity': 'positive'}, {'span': 'prices', 'polarity': 'positive'}],
    [{'span': 'waiting time', 'polarity': 'negative'}, {'span': 'note', 'polarity': 'positive'}, {'span': 'number', 'polarity': 'negative'}]
]
"""

Note: The model on display here was trained with a whopping 43 aspects. Not 43 aspects per class, mind you: just 43 aspects across only 24 sentences (!).

Saving/Pushing to the Hub

# You can push to the Hub/save models using one or two repo_ids:
trainer.push_to_hub("tomaarsen/setfit-absa-restaurant-review", private=True)
# trainer.push_to_hub("tomaarsen/setfit-absa-restaurant-review-aspect", "tomaarsen/setfit-absa-restaurant-review-polarity", private=True)
# Or directly on the model:
# model.push_to_hub("tomaarsen/setfit-absa-restaurant-review-aspect", "tomaarsen/setfit-absa-restaurant-review-polarity", private=True)
# model.save_pretrained("absa-model")
# model.save_pretrained("absa-model-aspect", "absa-model-polarity")

TODO

  • Better model cards for saved models.

cc: @rlaperdo


  • Tom Aarsen
