Skip to content

Commit

Permalink
update transformers==2.10.0 to transformers==4.0.0rc1 (#118)
Browse the repository at this point in the history
  • Loading branch information
MXueguang authored Nov 24, 2020
1 parent 68b421e commit b235fae
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
5 changes: 3 additions & 2 deletions pygaggle/model/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ def greedy_decode(model: PreTrainedModel,
decode_ids = torch.full((input_ids.size(0), 1),
model.config.decoder_start_token_id,
dtype=torch.long).to(input_ids.device)
past = model.get_encoder()(input_ids, attention_mask=attention_mask)
encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
next_token_logits = None
for _ in range(length):
model_inputs = model.prepare_inputs_for_generation(
decode_ids,
past=past,
encoder_outputs=encoder_outputs,
past=None,
attention_mask=attention_mask,
use_cache=True)
outputs = model(**model_inputs) # (batch_size, cur_len, vocab_size)
Expand Down
6 changes: 4 additions & 2 deletions pygaggle/model/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ class T5BatchTokenizer(AppendEosTokenizerMixin, QueryDocumentBatchTokenizer):
def __init__(self, *args, **kwargs):
kwargs['pattern'] = 'Query: {query} Document: {document} Relevant:'
kwargs['return_attention_mask'] = True
kwargs['pad_to_max_length'] = True
kwargs['padding'] = 'max_length'
kwargs["truncation"] = True
kwargs['return_tensors'] = 'pt'
kwargs['max_length'] = 512
super().__init__(*args, **kwargs)
Expand All @@ -119,7 +120,8 @@ def __init__(self, *args, **kwargs):
class SimpleBatchTokenizer(BatchTokenizer):
def __init__(self, *args, **kwargs):
kwargs['return_attention_mask'] = True
kwargs['pad_to_max_length'] = True
kwargs['padding'] = 'max_length'
kwargs['truncation'] = True
super().__init__(*args, **kwargs)


Expand Down
7 changes: 5 additions & 2 deletions pygaggle/rerank/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
ret = self.tokenizer.encode_plus(query.text,
text.text,
max_length=512,
truncation=True,
return_token_type_ids=True,
return_tensors='pt')
input_ids = ret['input_ids'].to(self.device)
tt_ids = ret['token_type_ids'].to(self.device)
output, = self.model(input_ids, token_type_ids=tt_ids)
output, = self.model(input_ids, token_type_ids=tt_ids, return_dict=False)
if output.size(1) > 1:
text.score = torch.nn.functional.log_softmax(
output, 1)[0, -1].item()
Expand All @@ -167,12 +168,14 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
ret = self.tokenizer.encode_plus(query.text,
text.text,
max_length=512,
truncation=True,
return_tensors='pt',
return_token_type_ids=True)
input_ids = ret['input_ids'].to(self.device)
tt_ids = ret['token_type_ids'].to(self.device)
start_scores, end_scores = self.model(input_ids,
token_type_ids=tt_ids)
token_type_ids=tt_ids,
return_dict=False)
start_scores = start_scores[0]
end_scores = end_scores[0]
start_scores[(1 - tt_ids[0]).bool()] = -5000
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ scipy>=1.4
spacy==2.2.4
tensorboard>=2.1.0
tensorflow>=2.2.0rc1
tokenizers==0.7
tokenizers==0.9.4
tqdm==4.45.0
transformers==2.10.0
transformers==4.0.0rc1

0 comments on commit b235fae

Please sign in to comment.