Bug description
Hi,
I am trying to do full fine-tuning of the Mixtral-8x7B model on 8x A100-40GB GPUs, using:
- FSDP full sharding
- activation checkpointing
- bf16-true precision
During activation recomputation, the tensor metadata does not match what was saved in the forward pass, causing training to fail.
I have attached a small repro script and the logs; please advise how this can be fixed.
Since this is a large model, sharding, activation checkpointing, and bf16-true are all needed to fit a decent batch_size and sequence_length during training.
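For reference, the three settings above correspond to the following strategy and Trainer configuration, excerpted from the full repro below (the wrap/checkpointing policy targets each MixtralDecoderLayer):

import pytorch_lightning as pl
import pytorch_lightning.strategies as strategies
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

# Shard and checkpoint activations at the granularity of each Mixtral decoder layer.
policy = {MixtralDecoderLayer}
strategy = strategies.FSDPStrategy(
    auto_wrap_policy=policy,                 # FSDP full sharding
    activation_checkpointing_policy=policy,  # activation checkpointing
)
trainer = pl.Trainer(accelerator='gpu', devices=8, precision='bf16-true', strategy=strategy, max_epochs=2)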
What version are you seeing the problem on?
v2.1
How to reproduce the bug
import functools
import time

import pytorch_lightning as pl
import pytorch_lightning.strategies as strategies
import torch
import torch.utils.data
from transformers import AutoModelForCausalLM, PreTrainedTokenizer, AutoTokenizer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer


class TextDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [
            "Transformers is a media franchise produced by American toy company Hasbro",
            "and Japanese toy company Takara Tomy.",
            "It primarily follows the heroic Autobots",
            "and the villainous Decepticons,",
            "two alien robot factions at war that can transform into other forms,",
            "such as vehicles and animals.",
            "The franchise encompasses toys, animation,",
            "comic books, video games and films.",
            "As of 2011, it generated more than ¥2 trillion ($25 billion) in revenue"
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class LitModel(pl.LightningModule):
    def __init__(self, pretrain_model_name, pad_token_id):
        super().__init__()
        self.pretrain_model_name = pretrain_model_name
        self.pad_token_id = pad_token_id
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return
        if self.trainer is not None:
            sleep = self.trainer.global_rank * 4
            print(f'[rank: {self.trainer.global_rank}] sleeping for {sleep}s to avoid CPU OOM')
            time.sleep(sleep)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.pretrain_model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            pad_token_id=self.pad_token_id,
        )
        if self.trainer is not None and self.trainer.global_rank == 0:
            print("module", self)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters())

    def _step(self, batch, batch_idx):
        res = self.model.forward(**batch, labels=batch['input_ids'])
        return res[0]

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)


def train():
    pl.seed_everything(13)
    torch.set_float32_matmul_precision('medium')
    policy = {MixtralDecoderLayer}
    strategy = strategies.FSDPStrategy(auto_wrap_policy=policy, activation_checkpointing_policy=policy)
    trainer = pl.Trainer(accelerator='gpu', devices=8, precision='bf16-true', strategy=strategy, max_epochs=2)
    pretrain_model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(pretrain_model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'
    model = LitModel(pretrain_model_name, tokenizer.pad_token_id)
    dataset = TextDataset()
    train_ds, val_ds = torch.utils.data.random_split(dataset, [.8, .2])
    collate = functools.partial(tokenizer, return_tensors='pt', padding='longest')
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=2, collate_fn=collate)
    val_dl = torch.utils.data.DataLoader(val_ds, batch_size=2, collate_fn=collate)
    trainer.fit(model, train_dl, val_dl)


if __name__ == '__main__':
    train()
Error messages and logs
Seed set to 13
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Seed set to 13
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/8
[rank: 5] Seed set to 13
[rank: 4] Seed set to 13
[rank: 7] Seed set to 13
[rank: 6] Seed set to 13
[rank: 3] Seed set to 13
[rank: 2] Seed set to 13
[rank: 1] Seed set to 13
[rank: 5] Seed set to 13
Initializing distributed: GLOBAL_RANK: 5, MEMBER: 6/8
[rank: 7] Seed set to 13
Initializing distributed: GLOBAL_RANK: 7, MEMBER: 8/8
[rank: 6] Seed set to 13
Initializing distributed: GLOBAL_RANK: 6, MEMBER: 7/8
[rank: 3] Seed set to 13
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/8
[rank: 2] Seed set to 13
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/8
[rank: 1] Seed set to 13
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/8
[rank: 4] Seed set to 13
Initializing distributed: GLOBAL_RANK: 4, MEMBER: 5/8
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 8 processes
----------------------------------------------------------------------------------------------------
[rank: 5] sleeping for 20s to avoid CPU OOM
[rank: 0] sleeping for 0s to avoid CPU OOM
[rank: 7] sleeping for 28s to avoid CPU OOM
[rank: 4] sleeping for 16s to avoid CPU OOM
[rank: 2] sleeping for 8s to avoid CPU OOM
[rank: 1] sleeping for 4s to avoid CPU OOM
[rank: 6] sleeping for 24s to avoid CPU OOM
[rank: 3] sleeping for 12s to avoid CPU OOM
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:12<00:00, 1.51it/s]
module LitModel(
(model): MixtralForCausalLM(
(model): MixtralModel(
(embed_tokens): Embedding(32000, 4096, padding_idx=2)
(layers): ModuleList(
(0-31): 32 x MixtralDecoderLayer(
(self_attn): MixtralFlashAttention2(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): MixtralRotaryEmbedding()
)
(block_sparse_moe): MixtralSparseMoeBlock(
(gate): Linear(in_features=4096, out_features=8, bias=False)
(experts): ModuleList(
(0-7): 8 x MixtralBLockSparseTop2MLP(
(w1): Linear(in_features=4096, out_features=14336, bias=False)
(w2): Linear(in_features=14336, out_features=4096, bias=False)
(w3): Linear(in_features=4096, out_features=14336, bias=False)
(act_fn): SiLU()
)
)
)
(input_layernorm): MixtralRMSNorm()
(post_attention_layernorm): MixtralRMSNorm()
)
)
(norm): MixtralRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:10<00:00, 1.73it/s]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:12<00:00, 1.46it/s]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:14<00:00, 1.29it/s]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:18<00:00, 1.00it/s]
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:18<00:00, 1.00it/s]
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:18<00:00, 1.04it/s]
Loading checkpoint shards: 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 17/19 [00:13<00:01, 1.73it/s]LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:14<00:00, 1.28it/s]
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
| Name | Type | Params
---------------------------------------------
0 | model | MixtralForCausalLM | 5.8 B
---------------------------------------------
5.8 B Trainable params
0 Non-trainable params
5.8 B Total params
23,351.396 Total estimated model params size (MB)
Sanity Checking DataLoader 0: 0%| | 0/1 [00:00<?, ?it/s]
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/root/workspace/lm-pretrain/training/dummy.py", line 105, in <module>
train()
File "/root/workspace/lm-pretrain/training/dummy.py", line 101, in train
Traceback (most recent call last):
trainer.fit(model, train_dl, val_dl)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
File "/root/workspace/lm-pretrain/training/dummy.py", line 105, in <module>
call._call_and_handle_interrupt(
train()
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
File "/root/workspace/lm-pretrain/training/dummy.py", line 101, in train
return trainer_fn(*args, **kwargs)
trainer.fit(model, train_dl, val_dl)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
call._call_and_handle_interrupt(
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
results = self._run_stage()
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_stage
self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
self.fit_loop.run()
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
results = self._run_stage()
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_stage
self.advance()
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 359, in advance
self.epoch_loop.run(self._data_fetcher)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 136, in run
self.fit_loop.run()
self.advance(data_fetcher)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 240, in advance
self.advance()
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 359, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 187, in run
self.epoch_loop.run(self._data_fetcher)
self._optimizer_step(batch_idx, closure)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 136, in run
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 265, in _optimizer_step
self.advance(data_fetcher)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 240, in advance
call._call_lightning_module_hook(
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 187, in run
output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1291, in optimizer_step
self._optimizer_step(batch_idx, closure)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 265, in _optimizer_step
call._call_lightning_module_hook(
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1291, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 151, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 230, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/fsdp.py", line 145, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 151, in step
return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 117, in optimizer_step
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 230, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py", line 373, in wrapper
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/fsdp.py", line 145, in optimizer_step
return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 117, in optimizer_step
out = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
return optimizer.step(closure=closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py", line 373, in wrapper
ret = func(self, *args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/optim/adam.py", line 143, in step
loss = closure()
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 104, in _wrap_closure
out = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
closure_result = closure()
ret = func(self, *args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 140, in __call__
File "/opt/conda/lib/python3.10/site-packages/torch/optim/adam.py", line 143, in step
self._result = self.closure(*args, **kwargs)
loss = closure()
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 104, in _wrap_closure
return func(*args, **kwargs)
closure_result = closure()
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 135, in closure
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 140, in __call__
self._backward_fn(step_output.closure_loss)
self._result = self.closure(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 236, in backward_fn
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 135, in closure
call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
self._backward_fn(step_output.closure_loss)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 236, in backward_fn
output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 204, in backward
call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 69, in backward
Traceback (most recent call last):
File "/root/workspace/lm-pretrain/training/dummy.py", line 105, in <module>
output = fn(*args, **kwargs)
model.backward(tensor, *args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 204, in backward
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1078, in backward
self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 69, in backward
model.backward(tensor, *args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1078, in backward
train()
File "/root/workspace/lm-pretrain/training/dummy.py", line 101, in train
loss.backward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
trainer.fit(model, train_dl, val_dl)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
torch.autograd.backward(
loss.backward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
call._call_and_handle_interrupt(
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1075, in unpack_hook
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
torch.autograd.backward(
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
return trainer_fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1075, in unpack_hook
frame.check_recomputed_tensors_match(gid)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 850, in check_recomputed_tensors_match
self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
frame.check_recomputed_tensors_match(gid)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 850, in check_recomputed_tensors_match
raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 18:
saved metadata: {'shape': torch.Size([1, 25, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=5)}
recomputed metadata: {'shape': torch.Size([1, 50, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=5)}
tensor at position 19:
saved metadata: {'shape': torch.Size([1, 25, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=5)}
recomputed metadata: {'shape': torch.Size([1, 50, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=5)}
I also tried pretrain_model_name='mistralai/Mistral-7B-Instruct-v0.2'.
It still gives the tensor size mismatch error, irrespective of which precision I use (bf16-mixed or bf16-true).
It seems like activation checkpointing and FSDP are not working well together.
hrushikesh198 changed the title from "bf16-true precision does not work with gradient checkpointing" to "Recomputed tensor size does not match when using activation checkpointing in FSDP strategy" on Jan 11, 2024.
I was able to fix the issue by explicitly calling self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': True}) after loading the model inside the configure_model() method.
The issue reappears if I set {'use_reentrant': False} in the above call.
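For clarity, here is a minimal sketch of that workaround applied to the configure_model() from the repro above (the rank-staggered sleep and rank-0 print are omitted for brevity):

def configure_model(self):
    if self.model is not None:
        return
    self.model = AutoModelForCausalLM.from_pretrained(
        self.pretrain_model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        pad_token_id=self.pad_token_id,
    )
    # Workaround: enable Hugging Face's own gradient checkpointing with the
    # reentrant implementation after the weights are loaded. With
    # {'use_reentrant': False} the CheckpointError above reappears.
    self.model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={'use_reentrant': True}
    )

Presumably this helps because the metadata check that raises the CheckpointError (check_recomputed_tensors_match, reached via unpack_hook in the traceback) lives in the non-reentrant code path of torch.utils.checkpoint, while the reentrant variant replays the forward without comparing saved tensor metadata.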
Did you experience significant performance degradation when setting gradient_checkpointing_kwargs={'use_reentrant': True} with DeepSpeed stage 3? A similar observation is this.