diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6ddc5efabcff..2b0a6a6c69c8 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2048,7 +2048,10 @@ def _get_no_split_modules(self, device_map: str):
         return list(_no_split_modules)
 
     def resize_token_embeddings(
-        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+        self,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        mean_resizing: bool = True,
     ) -> nn.Embedding:
         """
         Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -2068,11 +2071,19 @@ def resize_token_embeddings(
                 `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                 details about this, or help on choosing the correct value for resizing, refer to this guide:
                 https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has the old embeddings'
+                mean and covariance, or from a normal distribution with a mean of zero and a standard deviation of `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities will not be affected by the added embeddings: initializing the new embeddings
+                with the old embeddings' mean reduces the KL divergence between the next-token probabilities before and after adding them.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
 
         Return:
             `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
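+
+        Example (an illustrative sketch; `gpt2` stands in for any checkpoint whose embeddings you want to resize):
+
+        ```python
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        tokenizer.add_tokens(["<custom_token_1>", "<custom_token_2>"])
+        # With `mean_resizing=True` (the default), the two new rows are sampled from a
+        # normal distribution fitted to the mean and covariance of the old embedding matrix.
+        model.resize_token_embeddings(len(tokenizer), mean_resizing=True)
+        ```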
""" - model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) if new_num_tokens is None and pad_to_multiple_of is None: return model_embeds @@ -2095,9 +2106,11 @@ def resize_token_embeddings( return model_embeds - def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True): old_embeddings = self.get_input_embeddings() - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) + new_embeddings = self._get_resized_embeddings( + old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing + ) if hasattr(old_embeddings, "_hf_hook"): hook = old_embeddings._hf_hook add_hook_to_module(new_embeddings, hook) @@ -2120,9 +2133,9 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: old_lm_head = self.get_output_embeddings() if isinstance(old_lm_head, torch.nn.Embedding): - new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens) + new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing) else: - new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens) + new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing) if hasattr(old_lm_head, "_hf_hook"): hook = old_lm_head._hf_hook add_hook_to_module(new_lm_head, hook) @@ -2137,6 +2150,7 @@ def _get_resized_embeddings( old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, ) -> nn.Embedding: """ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly @@ -2159,6 +2173,14 @@ def _get_resized_embeddings( `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + mean_resizing (`bool`): + Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and + covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`. + + Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models, + where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the + old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings. + Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html Return: @@ -2217,8 +2239,32 @@ def _get_resized_embeddings( dtype=old_embeddings.weight.dtype, ) - # initialize all new embeddings (in particular added tokens) - self._init_weights(new_embeddings) + if new_num_tokens > old_num_tokens and not mean_resizing: + # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`. 
+            self._init_weights(new_embeddings)
+
+        elif new_num_tokens > old_num_tokens and mean_resizing:
+            # initialize new embeddings (in particular added tokens). The new embeddings will be initialized
+            # from a multivariate normal distribution that has the old embeddings' mean and covariance,
+            # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+            logger.warning_once(
+                "The new embeddings will be initialized from a multivariate normal distribution that has the old embeddings' mean and covariance, "
+                "as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
+                "To disable this, use `mean_resizing=False`."
+            )
+
+            added_num_tokens = new_num_tokens - old_num_tokens
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
+                    self._init_added_embeddings_weights_with_mean(
+                        old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+                    )
+            else:
+                self._init_added_embeddings_weights_with_mean(
+                    old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+                )
 
         # Copy token embeddings from the previous weights
 
@@ -2258,7 +2304,11 @@ def _get_resized_embeddings(
 
         return old_embeddings
 
     def _get_resized_lm_head(
-        self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
+        self,
+        old_lm_head: nn.Linear,
+        new_num_tokens: Optional[int] = None,
+        transposed: Optional[bool] = False,
+        mean_resizing: bool = True,
     ) -> nn.Linear:
         """
         Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
@@ -2275,6 +2325,14 @@ def _get_resized_lm_head(
                 `torch.nn.Linear` module of the model without doing anything.
             transposed (`bool`, *optional*, defaults to `False`):
                 Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, vocab_size` else `vocab_size, lm_head_dim`.
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has the old embeddings'
+                mean and covariance, or from a normal distribution with a mean of zero and a standard deviation of `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities will not be affected by the added embeddings: initializing the new embeddings
+                with the old embeddings' mean reduces the KL divergence between the next-token probabilities before and after adding them.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
 
         Return:
             `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
@@ -2321,8 +2379,40 @@ def _get_resized_lm_head(
             dtype=old_lm_head.weight.dtype,
         )
 
-        # initialize new lm head (in particular added tokens)
-        self._init_weights(new_lm_head)
+        if new_num_tokens > old_num_tokens and not mean_resizing:
+            # initialize new lm_head weights (in particular added tokens) with a mean of 0 and a std of `config.initializer_range`.
+            self._init_weights(new_lm_head)
+
+        elif new_num_tokens > old_num_tokens and mean_resizing:
+            # initialize new lm_head weights (in particular added tokens). The new lm_head weights
+            # will be initialized from a multivariate normal distribution that has the old embeddings' mean and covariance,
+            # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+            logger.warning_once(
+                "The new lm_head weights will be initialized from a multivariate normal distribution that has the old embeddings' mean and covariance, "
+                "as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
+                "To disable this, use `mean_resizing=False`."
+            )
+
+            added_num_tokens = new_num_tokens - old_num_tokens
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                params = [old_lm_head.weight]
+                if has_new_lm_head_bias:
+                    params += [old_lm_head.bias]
+                with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
+                    self._init_added_lm_head_weights_with_mean(
+                        old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
+                    )
+                    if has_new_lm_head_bias:
+                        self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
+
+            else:
+                self._init_added_lm_head_weights_with_mean(
+                    old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
+                )
+                if has_new_lm_head_bias:
+                    self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
 
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
 
@@ -2341,6 +2431,52 @@ def _get_resized_lm_head(
 
         return new_lm_head
 
+    def _init_added_embeddings_weights_with_mean(
+        self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+    ):
+        old_embeddings_weight = old_embeddings.weight.data.to(torch.float32)
+        mean_embeddings = torch.mean(old_embeddings_weight, axis=0)
+        old_centered_embeddings = old_embeddings_weight - mean_embeddings
+        covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
+        if old_embedding_dim >= old_num_tokens:
+            # The covariance matrix must be positive definite. For edge cases when `vocab_size` is
+            # smaller than `hidden_size`, the covariance matrix won't be positive definite, so we
+            # must add a small multiple of the identity matrix to make it positive definite.
+            covariance = covariance + torch.eye(old_embedding_dim, device=old_embeddings.weight.device) * 1e-3
+        distribution = torch.distributions.multivariate_normal.MultivariateNormal(
+            mean_embeddings, covariance_matrix=1e-5 * covariance
+        )
+        new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
+            sample_shape=(added_num_tokens,)
+        ).to(old_embeddings.weight.dtype)
+
+    def _init_added_lm_head_weights_with_mean(
+        self,
+        old_lm_head,
+        new_lm_head,
+        old_lm_head_dim,
+        old_num_tokens,
+        added_num_tokens,
+        transposed=False,
+    ):
+        if transposed:
+            # Transpose to the desired shape for the function.
+            new_lm_head.weight.data = new_lm_head.weight.data.T
+
+        # The same initialization logic as for the embeddings.
+        self._init_added_embeddings_weights_with_mean(
+            old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens
+        )
+
+        if transposed:
+            # Transpose again to the correct shape.
+            new_lm_head.weight.data = new_lm_head.weight.data.T
+
+    def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens):
+        bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32)
+        bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32)
+        new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=bias_std * 1e-5)
+
     def _copy_lm_head_original_to_resized(
         self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
     ):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 705631a460d8..66b4e25a4526 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -25,6 +25,7 @@
 import time
 import warnings
 from collections import defaultdict
+from contextlib import contextmanager
 from typing import Dict, List, Tuple
 
 import numpy as np
@@ -45,6 +46,12 @@
     logging,
     set_seed,
 )
+from transformers.integrations import HfDeepSpeedConfig
+from transformers.integrations.deepspeed import (
+    is_deepspeed_available,
+    is_deepspeed_zero3_enabled,
+    unset_hf_deepspeed_config,
+)
 from transformers.models.auto import get_values
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -75,6 +82,7 @@
     is_pt_tf_cross_test,
     require_accelerate,
     require_bitsandbytes,
+    require_deepspeed,
     require_flash_attn,
     require_non_xpu,
     require_read_token,
@@ -134,6 +142,9 @@
 if is_torch_fx_available():
     from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace
 
+if is_deepspeed_available():
+    import deepspeed
+
 
 def _config_zero_init(config):
     configs_no_init = copy.deepcopy(config)
@@ -171,6 +182,15 @@ def _mock_all_init_weights(self):
     self.tie_weights()
 
 
+@contextmanager
+def _deepspeed_zero3(ds_config):
+    dschf = HfDeepSpeedConfig(ds_config)
+    try:
+        yield dschf
+    finally:
+        unset_hf_deepspeed_config()
+
+
 @require_torch
 class ModelTesterMixin:
     model_tester = None
@@ -1797,8 +1817,13 @@ def test_resize_tokens_embeddings(self):
 
         for model_class in self.all_model_classes:
             config = copy.deepcopy(original_config)
-            model = model_class(config)
-            model.to(torch_device)
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.Init():
+                    model = model_class(config)
+            else:
+                model = model_class(config)
+                model.to(torch_device)
+
             model_embed_pre_resize = model.get_input_embeddings()
             type_model_embed_pre_resize = type(model_embed_pre_resize)
 
@@ -1813,15 +1838,26 @@ def test_resize_tokens_embeddings(self):
             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size + 10)
             new_model_vocab_size = model.config.get_text_config().vocab_size
-            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
             # Check to make sure the type of embeddings returned post resizing is same as type of input
             type_model_embed_post_resize = type(model_embed)
             self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
+            # Check that added embeddings mean is close to the old embeddings mean
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
+                    old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
+                    new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
+            else:
+                old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
+                new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
+            torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
+
 
             # Check that the model can still do a forward pass successfully (every parameter should be resized)
-            model(**self._prepare_for_class(inputs_dict, model_class))
+            if not is_deepspeed_zero3_enabled():
+                # A distributed launcher is needed for the forward pass when DeepSpeed is enabled
+                model(**self._prepare_for_class(inputs_dict, model_class))
 
             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size - 15)
@@ -1835,9 +1871,11 @@ def test_resize_tokens_embeddings(self):
             inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
 
             # make sure that decoder_input_ids are resized as well
-            if "decoder_input_ids" in inputs_dict:
-                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
-            model(**self._prepare_for_class(inputs_dict, model_class))
+            if not is_deepspeed_zero3_enabled():
+                # A distributed launcher is needed for the forward pass when DeepSpeed is enabled
+                if "decoder_input_ids" in inputs_dict:
+                    inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+                model(**self._prepare_for_class(inputs_dict, model_class))
 
             # Check that adding and removing tokens has not modified the first part of the embedding matrix.
             models_equal = True
@@ -1847,9 +1885,13 @@ def test_resize_tokens_embeddings(self):
 
             self.assertTrue(models_equal)
 
-            config = copy.deepcopy(original_config)
-            model = model_class(config)
-            model.to(torch_device)
+            del model
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.Init():
+                    model = model_class(config)
+            else:
+                model = model_class(config)
+                model.to(torch_device)
 
             model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
@@ -1877,6 +1919,63 @@ def test_resize_tokens_embeddings(self):
             ):
                 model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
 
+            # Test when `vocab_size` is smaller than `hidden_size`.
+            del model
+            config.vocab_size = 4
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.Init():
+                    model = model_class(config)
+            else:
+                model = model_class(config)
+                model.to(torch_device)
+
+            model_vocab_size = config.get_text_config().vocab_size
+            # Retrieve the embeddings and clone them
+            model_embed = model.resize_token_embeddings(model_vocab_size)
+            cloned_embeddings = model_embed.weight.clone()
+
+            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
+            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
+            new_model_vocab_size = model.config.get_text_config().vocab_size
+            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
+            # Check that it actually resizes the embeddings matrix
+            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
+            # Check to make sure the type of embeddings returned post resizing is same as type of input
+            type_model_embed_post_resize = type(model_embed)
+            self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
+            # Check that added embeddings mean is close to the old embeddings mean
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
+                    old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
+                    new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
+            else:
+                old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
+                new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
+            torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
+
+    @require_deepspeed
+    @require_torch_gpu
+    def test_resize_tokens_embeddings_with_deepspeed(self):
+        ds_config = {
+            "zero_optimization": {
+                "stage": 3,
+                "offload_param": {"device": "cpu", "pin_memory": True},
+            },
+        }
+        with _deepspeed_zero3(ds_config):
+            self.test_resize_tokens_embeddings()
+
+    @require_deepspeed
+    @require_torch_multi_gpu
+    def test_resize_tokens_embeddings_with_deepspeed_multi_gpu(self):
+        ds_config = {
+            "zero_optimization": {
+                "stage": 3,
+            },
+        }
+        with _deepspeed_zero3(ds_config):
+            self.test_resize_tokens_embeddings()
+
     def test_resize_embeddings_untied(self):
         if not self.test_resize_embeddings:
             self.skipTest(reason="test_resize_embeddings is set to `False`")
@@ -1890,7 +1989,11 @@ def test_resize_embeddings_untied(self):
 
         for model_class in self.all_model_classes:
             config = copy.deepcopy(original_config)
-            model = model_class(config).to(torch_device)
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.Init():
+                    model = model_class(config)
+            else:
+                model = model_class(config).to(torch_device)
 
             # if no output embeddings -> leave test
             if model.get_output_embeddings() is None:
@@ -1907,7 +2010,33 @@ def test_resize_embeddings_untied(self):
             if output_embeds.bias is not None:
                 self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
             # Check that the model can still do a forward pass successfully (every parameter should be resized)
-            model(**self._prepare_for_class(inputs_dict, model_class))
+            if not is_deepspeed_zero3_enabled():
+                # A distributed launcher is needed for the forward pass when DeepSpeed is enabled
+                model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Test multivariate resizing.
+            model.resize_token_embeddings(model_vocab_size + 10)
+            output_embeds = model.get_output_embeddings()
+            # Check that added embeddings mean is close to the old embeddings mean
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.GatheredParameters(output_embeds.weight, modifier_rank=None):
+                    old_embeddings_mean = torch.mean(output_embeds.weight.data[:-10, :], axis=0)
+                    new_embeddings_mean = torch.mean(output_embeds.weight.data[-10:, :], axis=0)
+            else:
+                old_embeddings_mean = torch.mean(output_embeds.weight.data[:-10, :], axis=0)
+                new_embeddings_mean = torch.mean(output_embeds.weight.data[-10:, :], axis=0)
+            torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
+            # Check that the added bias mean is close to the old bias mean.
+            if output_embeds.bias is not None:
+                if is_deepspeed_zero3_enabled():
+                    with deepspeed.zero.GatheredParameters(output_embeds.bias, modifier_rank=None):
+                        old_bias_mean = torch.mean(output_embeds.bias.data[:-10], axis=0)
+                        new_bias_mean = torch.mean(output_embeds.bias.data[-10:], axis=0)
+                else:
+                    old_bias_mean = torch.mean(output_embeds.bias.data[:-10], axis=0)
+                    new_bias_mean = torch.mean(output_embeds.bias.data[-10:], axis=0)
+
+                torch.testing.assert_close(old_bias_mean, new_bias_mean, atol=1e-5, rtol=1e-2)
 
             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model.resize_token_embeddings(model_vocab_size - 15)
@@ -1925,7 +2054,32 @@ def test_resize_embeddings_untied(self):
             if "decoder_input_ids" in inputs_dict:
                 inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
             # Check that the model can still do a forward pass successfully (every parameter should be resized)
-            model(**self._prepare_for_class(inputs_dict, model_class))
+            if not is_deepspeed_zero3_enabled():
+                # A distributed launcher is needed for the forward pass when DeepSpeed is enabled
+                model(**self._prepare_for_class(inputs_dict, model_class))
+
+    @require_deepspeed
+    @require_torch_gpu
+    def test_resize_embeddings_untied_with_deepspeed(self):
+        ds_config = {
+            "zero_optimization": {
+                "stage": 3,
+                "offload_param": {"device": "cpu", "pin_memory": True},
+            },
+        }
+        with _deepspeed_zero3(ds_config):
+            self.test_resize_embeddings_untied()
+
+    @require_deepspeed
+    @require_torch_multi_gpu
+    def test_resize_embeddings_untied_with_deepspeed_multi_gpu(self):
+        ds_config = {
+            "zero_optimization": {
+                "stage": 3,
+            },
+        }
+        with _deepspeed_zero3(ds_config):
+            self.test_resize_embeddings_untied()
 
     def test_model_get_set_embeddings(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
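A minimal standalone sketch of the initialization the diff implements (the helper name `sample_new_rows` and the toy shapes below are illustrative, not part of the PR): new rows are drawn from `N(mean, 1e-5 * covariance)`, where the mean and covariance are those of the existing embedding matrix, with a small identity term added when the covariance would not be positive definite.

```python
import torch


def sample_new_rows(old_weight: torch.Tensor, added_num_tokens: int) -> torch.Tensor:
    """Sample rows for added tokens from N(mean, 1e-5 * covariance) of the old embedding matrix."""
    old = old_weight.to(torch.float32)
    num_tokens, dim = old.shape
    mean = old.mean(dim=0)
    centered = old - mean
    covariance = centered.T @ centered / num_tokens
    if dim >= num_tokens:
        # Rank-deficient covariance (vocab smaller than hidden size): regularize
        # with a small multiple of the identity so it is positive definite.
        covariance = covariance + 1e-3 * torch.eye(dim, device=old.device)
    distribution = torch.distributions.multivariate_normal.MultivariateNormal(
        mean, covariance_matrix=1e-5 * covariance
    )
    return distribution.sample((added_num_tokens,)).to(old_weight.dtype)


old_embeddings = torch.randn(1000, 64)  # toy "old" embedding matrix
new_rows = sample_new_rows(old_embeddings, added_num_tokens=10)
print(new_rows.shape)  # torch.Size([10, 64])
# The added rows cluster tightly around the old mean, which is what keeps the
# next-token distribution of a causal LM nearly unchanged after resizing.
torch.testing.assert_close(new_rows.mean(dim=0), old_embeddings.mean(dim=0), atol=1e-2, rtol=0.0)
```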