From d73e5e6bad39a31fe23e77349efb77f81b379682 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Tue, 29 Dec 2020 11:09:36 -0800
Subject: [PATCH 1/5] --model_parallel hasn't been implemented for most models

---
 src/transformers/trainer.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 6756e591656a..8914026331c8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -241,6 +241,14 @@ def __init__(
         if model is None and model_init is not None:
             model = self.call_model_init()
 
+        if self.args.model_parallel:
+            # XXX: ideally this register should be maintained elsewhere so that the trainer could just do
+            # if model.model_parallel_is_supported()
+            mp_supported = ["gpt2", "t5"]
+            assert (
+                model.config.model_type in mp_supported
+            ), f"{model.config.model_type} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
+
         # Model parallel
         if model is not None and not self.args.model_parallel:
             model = model.to(args.device)

From 1553f61dbcea773c6ef18308ea941fc6db503fbc Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Tue, 29 Dec 2020 11:21:34 -0800
Subject: [PATCH 2/5] make the help clear as well

---
 src/transformers/training_args.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 9d78ce41fe34..8ac8eb88a026 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -207,8 +207,8 @@ class TrainingArguments:
               :obj:`"eval_loss"`.
             - :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
         model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If there is more than one device, whether to use model parallelism to distribute the model's modules across
-            devices or not.
+            If the model supports model parallelism and there is more than one device, whether to use model parallelism
+            to distribute the model's modules across devices or not.
         ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`):
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping

From 3859e7726a716b1cf22b8c84f59bdcc28fd301ad Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 4 Jan 2021 13:03:54 -0800
Subject: [PATCH 3/5] implement is_parallelizable; use it

---
 src/transformers/modeling_utils.py            |  7 +++++++
 src/transformers/models/gpt2/modeling_gpt2.py |  1 +
 src/transformers/models/t5/modeling_t5.py     |  1 +
 src/transformers/trainer.py                   | 11 ++++-------
 4 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index fba7aa89cbeb..33ec28e3f56c 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **_is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
     """
     config_class = None
     base_model_prefix = ""
@@ -417,6 +418,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
     # trained, but which are deterministic)
     _keys_to_ignore_on_save = None
 
+    _is_parallelizable = False
+
+    @property
+    def is_parallelizable(self) -> bool:
+        return self._is_parallelizable
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index bb8046c0e2f0..d75a5f1f040d 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -337,6 +337,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     config_class = GPT2Config
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"
+    _is_parallelizable = True
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 0ce2be3c62ac..2ce3c80cad9e 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -683,6 +683,7 @@ class T5PreTrainedModel(PreTrainedModel):
     config_class = T5Config
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"
+    _is_parallelizable = True
 
     @property
     def dummy_inputs(self):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 47f1f58baa13..c15d5e10a860 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -242,13 +242,10 @@ def __init__(
         if model is None and model_init is not None:
             model = self.call_model_init()
 
-        if self.args.model_parallel:
-            # XXX: ideally this register should be maintained elsewhere so that the trainer could just do
-            # if model.model_parallel_is_supported()
-            mp_supported = ["gpt2", "t5"]
-            assert (
-                model.config.model_type in mp_supported
-            ), f"{model.config.model_type} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
+        if not model.is_parallelizable:
+            raise ValueError(
+                f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
+            )
 
         # Model parallel
         if model is not None and not self.args.model_parallel:

From 6f3799c25d8cd728c3322a01d361223d293dcd49 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 4 Jan 2021 13:24:50 -0800
Subject: [PATCH 4/5] oops

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c15d5e10a860..a5ecba584d0f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -242,7 +242,7 @@ def __init__(
         if model is None and model_init is not None:
             model = self.call_model_init()
 
-        if not model.is_parallelizable:
+        if self.args.model_parallel and not model.is_parallelizable:
             raise ValueError(
                 f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
             )

From d4f60ea1b075315bd5b313c56ede9878721d7564 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 4 Jan 2021 13:28:02 -0800
Subject: [PATCH 5/5] remove property

---
 src/transformers/modeling_utils.py            | 8 ++------
 src/transformers/models/gpt2/modeling_gpt2.py | 2 +-
 src/transformers/models/t5/modeling_t5.py     | 2 +-
 3 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 33ec28e3f56c..d0fc1ad0f4b2 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -404,7 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **_is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
+        - **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
     """
     config_class = None
     base_model_prefix = ""
@@ -418,11 +418,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
     # trained, but which are deterministic)
     _keys_to_ignore_on_save = None
 
-    _is_parallelizable = False
-
-    @property
-    def is_parallelizable(self) -> bool:
-        return self._is_parallelizable
+    is_parallelizable = False
 
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index d75a5f1f040d..867a02d361fb 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -337,7 +337,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     config_class = GPT2Config
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"
-    _is_parallelizable = True
+    is_parallelizable = True
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 2ce3c80cad9e..00d9ca30eccd 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -683,7 +683,7 @@ class T5PreTrainedModel(PreTrainedModel):
     config_class = T5Config
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"
-    _is_parallelizable = True
+    is_parallelizable = True
 
     @property
     def dummy_inputs(self):
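Note on the net effect of the series (an illustrative sketch, not part of the patches): after 5/5, `is_parallelizable` is a plain class attribute on `PreTrainedModel`, architectures that implement model parallelism (GPT2, T5) override it to `True`, and the `Trainer` guard only rejects the run when `--model_parallel` was actually requested. The standalone Python sketch below mirrors that flag/guard pattern; the `Fake*` class names and the `check_model_parallel` helper are hypothetical and exist only for illustration.

# Illustrative sketch (not transformers code) of the pattern the series converges on:
# a base class carries a class-level flag, parallelizable subclasses flip it to True,
# and the trainer-side check fires only when model parallelism was requested.

class FakePreTrainedModel:
    is_parallelizable = False  # default: model parallelism not implemented


class FakeGPT2Model(FakePreTrainedModel):
    is_parallelizable = True   # a real model would also implement parallelize()/deparallelize()


class FakeBertModel(FakePreTrainedModel):
    pass                       # inherits is_parallelizable = False


def check_model_parallel(model: FakePreTrainedModel, model_parallel: bool) -> None:
    # mirrors the guard added to Trainer.__init__ in patches 3/5 and 4/5
    if model_parallel and not model.is_parallelizable:
        raise ValueError(
            f"{model.__class__.__name__} implementation currently doesn't support model parallelism, "
            "therefore --model_parallel cl arg cannot be used"
        )


check_model_parallel(FakeGPT2Model(), model_parallel=True)    # passes
check_model_parallel(FakeBertModel(), model_parallel=False)   # passes: flag not requested
try:
    check_model_parallel(FakeBertModel(), model_parallel=True)
except ValueError as err:
    print(err)                 # rejected: this model never declares is_parallelizable = True

Dropping the property in 5/5 keeps the override a one-line class assignment and removes the `_is_parallelizable` indirection, which is why the docstring entry is renamed in the same patch.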