Skip to content

Commit

Permalink
Fixed conversion of BertForMaskedLM models to the Hugging Face transformers format (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
himanshurawlani authored Sep 24, 2020
1 parent 4fef421 commit 23d3177
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 7 additions & 3 deletions farm/modeling/adaptive_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,10 @@ def get_language(self):
return self.language_model.language

def convert_to_transformers(self):
if len(self.prediction_heads) != 1:
if len(self.prediction_heads) == 2 and self.prediction_heads[0].model_type == "language_modelling":
logger.warning("Currently only the Masked Language Modeling component of the prediction head is converted, "
"not the Next Sentence Prediction or Sentence Order Prediction components")
elif len(self.prediction_heads) != 1:
raise ValueError(f"Currently conversion only works for models with a SINGLE prediction head. "
f"Your model has {len(self.prediction_heads)}")
elif len(self.prediction_heads[0].layer_dims) != 2:
Expand All @@ -524,14 +527,15 @@ def convert_to_transformers(self):
transformers_model = AutoModelWithLMHead.from_config(self.language_model.model.config)
# transfer weights for language model + prediction head
setattr(transformers_model, transformers_model.base_model_prefix, self.language_model.model)
# Adding decoder bias (required for conversion to transformers)
self.prediction_heads[0].decoder.bias = self.prediction_heads[0].bias

ph_state_dict = self.prediction_heads[0].state_dict()
ph_state_dict["transform.dense.weight"] = ph_state_dict.pop("dense.weight")
ph_state_dict["transform.dense.bias"] = ph_state_dict.pop("dense.bias")
ph_state_dict["transform.LayerNorm.weight"] = ph_state_dict.pop("LayerNorm.weight")
ph_state_dict["transform.LayerNorm.bias"] = ph_state_dict.pop("LayerNorm.bias")
transformers_model.cls.predictions.load_state_dict(ph_state_dict)
logger.warning("Currently only the Masked Language Modeling component of the prediction head is converted, "
"not the Next Sentence Prediction or Sentence Order Prediction components")

elif self.prediction_heads[0].model_type == "text_classification":
if self.language_model.model.base_model_prefix == "roberta":
Expand Down
2 changes: 2 additions & 0 deletions farm/modeling/prediction_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,8 @@ def __init__(self, hidden_size, vocab_size, hidden_act="gelu", task_name="lm", *
self.vocab_size = vocab_size
self.loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-1)
self.num_labels = vocab_size # vocab size
# Adding layer_dims (required for conversion to transformers)
self.layer_dims = [hidden_size, vocab_size]
# TODO Check if weight init needed!
# self.apply(self.init_bert_weights)
self.ph_output_type = "per_token"
Expand Down

0 comments on commit 23d3177

Please sign in to comment.