diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 9fb0d7475fb9..d6e74d6f89ee 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -112,6 +112,10 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]: array = array.transpose() + if name == ["wte"]: + # if vocab is padded, then trim off the padding embeddings + array = array[: config.vocab_size] + try: assert ( pointer.shape == array.shape