Simplify boilerplate for monoT5 and monoBERT #83
Conversation
Great, thanks for implementing this!
Can you also change the README accordingly?
Hey @rodrigonogueira4 - do you prefer this implementation, or the alternative of folding the boilerplate code into the constructors of the existing models? e.g., https://github.com/castorini/pygaggle/blob/master/pygaggle/rerank/transformer.py#L25
Yes, it is actually better to rename
@rodrigonogueira4 @lintool which do you prefer?

Folding into constructors:

```python
class T5Reranker(Reranker):
    def __init__(self,
                 model: T5ForConditionalGeneration = None,
                 tokenizer: QueryDocumentBatchTokenizer = None):
        if not model:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = T5ForConditionalGeneration.from_pretrained('castorini/monot5-base-msmarco').to(device).eval()
        self.model = model
        if not tokenizer:
            tokenizer = T5BatchTokenizer(AutoTokenizer.from_pretrained('t5-base'), batch_size=8)
        self.tokenizer = tokenizer
        self.device = next(self.model.parameters(), None).device


class SequenceClassificationTransformerReranker(Reranker):
    def __init__(self,
                 model: PreTrainedModel = None,
                 tokenizer: PreTrainedTokenizer = None):
        if not model:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = AutoModelForSequenceClassification.from_pretrained('castorini/monobert-large-msmarco').to(device).eval()
        self.model = model
        if not tokenizer:
            tokenizer = AutoTokenizer.from_pretrained('bert-large-uncased')
        self.tokenizer = tokenizer
        self.device = next(self.model.parameters()).device
```

Making subclasses:

```python
class MonoT5(T5Reranker):
    def __init__(self,
                 model: T5ForConditionalGeneration = None,
                 tokenizer: QueryDocumentBatchTokenizer = None):
        if not model:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = T5ForConditionalGeneration.from_pretrained('castorini/monot5-base-msmarco').to(device).eval()
        if not tokenizer:
            tokenizer = T5BatchTokenizer(AutoTokenizer.from_pretrained('t5-base'), batch_size=8)
        super().__init__(model, tokenizer)


class MonoBERT(SequenceClassificationTransformerReranker):
    def __init__(self,
                 model: PreTrainedModel = None,
                 tokenizer: PreTrainedTokenizer = None):
        if not model:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = AutoModelForSequenceClassification.from_pretrained('castorini/monobert-large-msmarco').to(device).eval()
        if not tokenizer:
            tokenizer = AutoTokenizer.from_pretrained('bert-large-uncased')
        super().__init__(model, tokenizer)
```

Option 2 is in case
My vote is for Option 1, but renaming the classes. @ronakice should chime in also... Also, once we build these abstractions, we should propagate them to the replications as well, e.g.:
I agree with @lintool, lowercase "m" as it is consistent with our previous work!
But re: Option 1 vs. Option 2? I.e., is there something special about our current abstractions that we should keep?
The main things I can think of with using lowercase class names are:
- linters will complain that monoT5/monoBERT don't follow Python's PascalCase convention for class names
- it's less clear to the general dev that these are classes

I'd avoid it unless the consistency w/ the paper is particularly important and worth ignoring Python's conventions.

As for the current abstractions, I was mainly concerned about
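To make the linter point concrete, here is a minimal illustration (not from the PR) of how a lowercase-prefixed class name trips a common naming check such as pylint's invalid-name (C0103); the exact message wording is approximate:

```python
# Illustration only: pylint's default naming check expects PascalCase class names,
# so a definition like this is flagged, roughly as:
#   C0103: Class name "monoT5" doesn't conform to PascalCase naming style (invalid-name)
class monoT5:
    ...


# The PascalCase spelling passes the default naming check.
class MonoT5:
    ...
```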
Hey @yuxuan-ji, sorry, I got a bit carried away with some other work. Firstly, yes, I think Option 1 is better. You make a fair point about linters complaining about monoBERT/monoT5, as well as the lack of clarity to the general dev. So I concede, I think it is better to go with the PascalCase names (MonoT5/MonoBERT). As to the usage of
Okay, we've converged: Option 1. FWIW, huggingface made the model names fugly, favoring conformance to conventions - see Bert and MBart. @yuxuan-ji please execute.
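For concreteness, here is a minimal sketch of what the converged Option 1 looks like from the caller's side once the classes are renamed. The class name, keyword arguments, and import path follow the snippets in this thread and are assumptions, not the merged code:

```python
# Assumed call patterns for the converged Option 1 (sketch, not the merged code).
import torch
from transformers import T5ForConditionalGeneration

from pygaggle.rerank.transformer import MonoT5  # assumed import path

# Zero-boilerplate construction: falls back to castorini/monot5-base-msmarco
# (on GPU if available) and the default t5-base batch tokenizer.
reranker = MonoT5()

# Explicit injection still works, e.g. for a custom fine-tuned checkpoint.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = (T5ForConditionalGeneration
         .from_pretrained('castorini/monot5-base-msmarco')
         .to(device)
         .eval())
reranker = MonoT5(model=model)  # tokenizer falls back to its default
```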
LGTM! Merging! Thanks @yuxuan-ji for swiftly finishing this :)
closes #80
usage is like so:
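The usage snippet itself did not survive extraction; below is a hedged reconstruction modeled on the repo's "simple reranking example" linked in the next line. The Query/Text constructors, the example passages, and the rerank() signature are assumptions based on that README, not copied from this PR's diff:

```python
# Reconstructed usage sketch (assumed API, mirroring the README's reranking example).
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import MonoT5

reranker = MonoT5()  # loads castorini/monot5-base-msmarco by default

query = Query('who proposed the geocentric theory')
passages = [
    ('doc1', 'For Earth-centered it was Ptolemy who proposed the geocentric theory.'),
    ('doc2', 'Copernicus proposed a heliocentric model of the solar system.'),
]
texts = [Text(text, metadata={'docid': docid}) for docid, text in passages]

# rerank() is assumed to return the texts with scores attached.
reranked = reranker.rerank(query, texts)
for result in sorted(reranked, key=lambda t: t.score, reverse=True):
    print(f"{result.metadata['docid']}: {result.score:.5f}")
```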
tested outputs are the same as https://github.com/castorini/pygaggle#a-simple-reranking-example
I was a bit confused about the goal of these two functions; see my comment here: #80 (comment)