
Recomputed tensor size does not match when using activation checkpointing in FSDP strategy #19267

Closed
hrushikesh198 opened this issue Jan 10, 2024 · 3 comments
Labels: bug, needs triage, ver: 2.1.x

Comments

@hrushikesh198

Bug description

Hi,
I am trying to do full fine-tuning of the Mixtral-8x7B model on 8x A100-40GB GPUs, using

  • FSDP full sharding
  • activation checkpointing
  • bf16-true precision

During activation recomputation, the tensor metadata does not match, which causes training to fail.
I have attached a small repro and the logs; please advise how this can be fixed.

Since this is a large model, sharding, activation checkpointing, and bf16-true are all needed to fit a decent batch_size and sequence_length during training.
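
For reference, the relevant strategy and trainer configuration (excerpted from the full repro below) is:

policy = {MixtralDecoderLayer}
strategy = strategies.FSDPStrategy(auto_wrap_policy=policy, activation_checkpointing_policy=policy)
trainer = pl.Trainer(accelerator='gpu', devices=8, precision='bf16-true', strategy=strategy, max_epochs=2)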

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import functools
import time

import pytorch_lightning as pl
import pytorch_lightning.strategies as strategies
import torch
import torch.utils.data
from transformers import AutoModelForCausalLM, PreTrainedTokenizer, AutoTokenizer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

class TextDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [
            "Transformers is a media franchise produced by American toy company Hasbro",
            "and Japanese toy company Takara Tomy.",
            "It primarily follows the heroic Autobots",
            "and the villainous Decepticons,",
            "two alien robot factions at war that can transform into other forms,",
            "such as vehicles and animals.",
            "The franchise encompasses toys, animation,",
            "comic books, video games and films.",
            "As of 2011, it generated more than ¥2 trillion ($25 billion) in revenue"
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class LitModel(pl.LightningModule):
    def __init__(self, pretrain_model_name, pad_token_id):
        super().__init__()
        self.pretrain_model_name = pretrain_model_name
        self.pad_token_id = pad_token_id
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return

        if self.trainer is not None:
            sleep = self.trainer.global_rank * 4
            print(f'[rank: {self.trainer.global_rank}] sleeping for {sleep}s to avoid CPU OOM')
            time.sleep(sleep)

        self.model = AutoModelForCausalLM.from_pretrained(
            self.pretrain_model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            pad_token_id=self.pad_token_id,
        )

        if self.trainer is not None and self.trainer.global_rank == 0:
            print("module", self)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters())

    def _step(self, batch, batch_idx):
        res = self.model.forward(**batch, labels=batch['input_ids'])
        return res[0]

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)


def train():
    pl.seed_everything(13)
    torch.set_float32_matmul_precision('medium')

    policy = {MixtralDecoderLayer}
    strategy = strategies.FSDPStrategy(auto_wrap_policy=policy, activation_checkpointing_policy=policy)
    trainer = pl.Trainer(accelerator='gpu', devices=8, precision='bf16-true', strategy=strategy, max_epochs=2)

    pretrain_model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(pretrain_model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'
    model = LitModel(pretrain_model_name, tokenizer.pad_token_id)

    dataset = TextDataset()
    train_ds, val_ds = torch.utils.data.random_split(dataset, [.8, .2])
    collate = functools.partial(tokenizer, return_tensors='pt', padding='longest')
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=2, collate_fn=collate)
    val_dl = torch.utils.data.DataLoader(val_ds, batch_size=2, collate_fn=collate)

    trainer.fit(model, train_dl, val_dl)


if __name__ == '__main__':
    train()

Error messages and logs

Seed set to 13
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Seed set to 13
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/8
[rank: 5] Seed set to 13
[rank: 4] Seed set to 13
[rank: 7] Seed set to 13
[rank: 6] Seed set to 13
[rank: 3] Seed set to 13
[rank: 2] Seed set to 13
[rank: 1] Seed set to 13
[rank: 5] Seed set to 13
Initializing distributed: GLOBAL_RANK: 5, MEMBER: 6/8
[rank: 7] Seed set to 13
Initializing distributed: GLOBAL_RANK: 7, MEMBER: 8/8
[rank: 6] Seed set to 13
Initializing distributed: GLOBAL_RANK: 6, MEMBER: 7/8
[rank: 3] Seed set to 13
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/8
[rank: 2] Seed set to 13
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/8
[rank: 1] Seed set to 13
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/8
[rank: 4] Seed set to 13
Initializing distributed: GLOBAL_RANK: 4, MEMBER: 5/8
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 8 processes
----------------------------------------------------------------------------------------------------

[rank: 5] sleeping for 20s to avoid CPU OOM
[rank: 0] sleeping for 0s to avoid CPU OOM
[rank: 7] sleeping for 28s to avoid CPU OOM
[rank: 4] sleeping for 16s to avoid CPU OOM
[rank: 2] sleeping for 8s to avoid CPU OOM
[rank: 1] sleeping for 4s to avoid CPU OOM
[rank: 6] sleeping for 24s to avoid CPU OOM
[rank: 3] sleeping for 12s to avoid CPU OOM
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:12<00:00,  1.51it/s]
module LitModel(
  (model): MixtralForCausalLM(
    (model): MixtralModel(
      (embed_tokens): Embedding(32000, 4096, padding_idx=2)
      (layers): ModuleList(
        (0-31): 32 x MixtralDecoderLayer(
          (self_attn): MixtralFlashAttention2(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): MixtralRotaryEmbedding()
          )
          (block_sparse_moe): MixtralSparseMoeBlock(
            (gate): Linear(in_features=4096, out_features=8, bias=False)
            (experts): ModuleList(
              (0-7): 8 x MixtralBLockSparseTop2MLP(
                (w1): Linear(in_features=4096, out_features=14336, bias=False)
                (w2): Linear(in_features=14336, out_features=4096, bias=False)
                (w3): Linear(in_features=4096, out_features=14336, bias=False)
                (act_fn): SiLU()
              )
            )
          )
          (input_layernorm): MixtralRMSNorm()
          (post_attention_layernorm): MixtralRMSNorm()
        )
      )
      (norm): MixtralRMSNorm()
    )
    (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
  )
)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:10<00:00,  1.73it/s]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:12<00:00,  1.46it/s]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:14<00:00,  1.29it/s]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:18<00:00,  1.00it/s]
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:18<00:00,  1.00it/s]
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:18<00:00,  1.04it/s]
Loading checkpoint shards:  89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋             | 17/19 [00:13<00:01,  1.73it/s]LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:14<00:00,  1.28it/s]
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type               | Params
---------------------------------------------
0 | model | MixtralForCausalLM | 5.8 B
---------------------------------------------
5.8 B     Trainable params
0         Non-trainable params
5.8 B     Total params
23,351.396Total estimated model params size (MB)
Sanity Checking DataLoader 0:   0%|                                                                                                                                     | 0/1 [00:00<?, ?it/s]
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0:   0%|                                                                                                                                                          | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/root/workspace/lm-pretrain/training/dummy.py", line 105, in <module>
    train()
  File "/root/workspace/lm-pretrain/training/dummy.py", line 101, in train
Traceback (most recent call last):
    trainer.fit(model, train_dl, val_dl)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
  File "/root/workspace/lm-pretrain/training/dummy.py", line 105, in <module>
    call._call_and_handle_interrupt(
    train()
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
  File "/root/workspace/lm-pretrain/training/dummy.py", line 101, in train
    return trainer_fn(*args, **kwargs)
    trainer.fit(model, train_dl, val_dl)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
    call._call_and_handle_interrupt(
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    results = self._run_stage()
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_stage
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
    self.fit_loop.run()
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
    results = self._run_stage()
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_stage
    self.advance()
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 359, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 136, in run
    self.fit_loop.run()
    self.advance(data_fetcher)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 240, in advance
    self.advance()
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 359, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 187, in run
    self.epoch_loop.run(self._data_fetcher)
    self._optimizer_step(batch_idx, closure)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 136, in run
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 265, in _optimizer_step
    self.advance(data_fetcher)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 240, in advance
    call._call_lightning_module_hook(
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 187, in run
    output = fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1291, in optimizer_step
    self._optimizer_step(batch_idx, closure)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 265, in _optimizer_step
    call._call_lightning_module_hook(
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1291, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 151, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 230, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/fsdp.py", line 145, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 151, in step
    return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 117, in optimizer_step
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 230, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py", line 373, in wrapper
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/fsdp.py", line 145, in optimizer_step
    return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 117, in optimizer_step
    out = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    return optimizer.step(closure=closure, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py", line 373, in wrapper
    ret = func(self, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/optim/adam.py", line 143, in step
    loss = closure()
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 104, in _wrap_closure
    out = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    closure_result = closure()
    ret = func(self, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 140, in __call__
  File "/opt/conda/lib/python3.10/site-packages/torch/optim/adam.py", line 143, in step
    self._result = self.closure(*args, **kwargs)
    loss = closure()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 104, in _wrap_closure
    return func(*args, **kwargs)
    closure_result = closure()
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 135, in closure
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 140, in __call__
    self._backward_fn(step_output.closure_loss)
    self._result = self.closure(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 236, in backward_fn
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 135, in closure
    call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    self._backward_fn(step_output.closure_loss)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 236, in backward_fn
    output = fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 204, in backward
    call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 69, in backward
Traceback (most recent call last):
  File "/root/workspace/lm-pretrain/training/dummy.py", line 105, in <module>
    output = fn(*args, **kwargs)
    model.backward(tensor, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 204, in backward
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1078, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 69, in backward
    model.backward(tensor, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1078, in backward
    train()
  File "/root/workspace/lm-pretrain/training/dummy.py", line 101, in train
    loss.backward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    trainer.fit(model, train_dl, val_dl)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    torch.autograd.backward(
    loss.backward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    call._call_and_handle_interrupt(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1075, in unpack_hook
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    return trainer_fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1075, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 850, in check_recomputed_tensors_match
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
    frame.check_recomputed_tensors_match(gid)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 850, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 18:
saved metadata: {'shape': torch.Size([1, 25, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=5)}
recomputed metadata: {'shape': torch.Size([1, 50, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=5)}
tensor at position 19:
saved metadata: {'shape': torch.Size([1, 25, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=5)}
recomputed metadata: {'shape': torch.Size([1, 50, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=5)}

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
    • available: True
    • version: 12.1
  • Lightning:
    • lightning-utilities: 0.10.0
    • pytorch-lightning: 2.1.3
    • torch: 2.1.2
    • torchaudio: 2.1.2
    • torchelastic: 0.2.2
    • torchmetrics: 1.2.1
    • torchvision: 0.16.2
  • Packages:
    • absl-py: 2.0.0
    • aiohttp: 3.9.1
    • aiosignal: 1.3.1
    • alembic: 1.13.1
    • annotated-types: 0.6.0
    • archspec: 0.2.1
    • asttokens: 2.0.5
    • astunparse: 1.6.3
    • async-timeout: 4.0.3
    • attrs: 23.1.0
    • backcall: 0.2.0
    • beautifulsoup4: 4.12.2
    • bitsandbytes: 0.42.0
    • boltons: 23.0.0
    • brotli: 1.0.9
    • cachetools: 5.3.2
    • certifi: 2023.11.17
    • cffi: 1.16.0
    • chardet: 4.0.0
    • charset-normalizer: 2.0.4
    • click: 8.1.7
    • colorlog: 6.8.0
    • conda: 23.9.0
    • conda-build: 3.28.1
    • conda-content-trust: 0.2.0
    • conda-index: 0.3.0
    • conda-libmamba-solver: 23.7.0
    • conda-package-handling: 2.2.0
    • conda-package-streaming: 0.9.0
    • cryptography: 41.0.7
    • datasets: 2.16.1
    • decorator: 5.1.1
    • dill: 0.3.7
    • distro: 1.8.0
    • dnspython: 2.4.2
    • einops: 0.7.0
    • exceptiongroup: 1.0.4
    • executing: 0.8.3
    • expecttest: 0.1.6
    • filelock: 3.13.1
    • flash-attn: 2.4.2
    • frozenlist: 1.4.1
    • fsspec: 2023.10.0
    • gmpy2: 2.1.2
    • google-auth: 2.26.1
    • google-auth-oauthlib: 1.2.0
    • greenlet: 3.0.3
    • grpcio: 1.60.0
    • huggingface-hub: 0.20.2
    • hypothesis: 6.92.0
    • idna: 3.4
    • ipython: 8.15.0
    • jedi: 0.18.1
    • jinja2: 3.1.2
    • jsonpatch: 1.32
    • jsonpointer: 2.1
    • jsonschema: 4.19.2
    • jsonschema-specifications: 2023.7.1
    • libarchive-c: 2.9
    • libmambapy: 1.5.3
    • lightning-utilities: 0.10.0
    • mako: 1.3.0
    • markdown: 3.5.2
    • markupsafe: 2.1.1
    • matplotlib-inline: 0.1.6
    • menuinst: 2.0.1
    • mkl-fft: 1.3.8
    • mkl-random: 1.2.4
    • mkl-service: 2.4.0
    • more-itertools: 10.1.0
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • multiprocess: 0.70.15
    • networkx: 3.1
    • ninja: 1.11.1.1
    • numpy: 1.26.2
    • oauthlib: 3.2.2
    • optuna: 3.5.0
    • packaging: 23.1
    • pandas: 2.1.4
    • parso: 0.8.3
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 10.0.1
    • pip: 23.3.1
    • pkginfo: 1.9.6
    • platformdirs: 3.10.0
    • pluggy: 1.0.0
    • prompt-toolkit: 3.0.36
    • protobuf: 4.23.4
    • psutil: 5.9.0
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • pyarrow: 14.0.2
    • pyarrow-hotfix: 0.6
    • pyasn1: 0.5.1
    • pyasn1-modules: 0.3.0
    • pycosat: 0.6.6
    • pycparser: 2.21
    • pydantic: 2.5.3
    • pydantic-core: 2.14.6
    • pygments: 2.15.1
    • pyopenssl: 23.2.0
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • python-etcd: 0.4.5
    • pytorch-lightning: 2.1.3
    • pytz: 2023.3.post1
    • pyyaml: 6.0.1
    • referencing: 0.30.2
    • regex: 2023.12.25
    • requests: 2.31.0
    • requests-oauthlib: 1.3.1
    • rpds-py: 0.10.6
    • rsa: 4.9
    • ruamel.yaml: 0.17.21
    • ruamel.yaml.clib: 0.2.6
    • safetensors: 0.4.1
    • scipy: 1.11.4
    • setuptools: 68.2.2
    • six: 1.16.0
    • sortedcontainers: 2.4.0
    • soupsieve: 2.5
    • sqlalchemy: 2.0.25
    • stack-data: 0.2.0
    • sympy: 1.12
    • tensorboard: 2.15.1
    • tensorboard-data-server: 0.7.2
    • tokenizers: 0.15.0
    • tomli: 2.0.1
    • toolz: 0.12.0
    • torch: 2.1.2
    • torchaudio: 2.1.2
    • torchelastic: 0.2.2
    • torchmetrics: 1.2.1
    • torchvision: 0.16.2
    • tqdm: 4.65.0
    • traitlets: 5.7.1
    • transformers: 4.36.2
    • triton: 2.1.0
    • truststore: 0.8.0
    • types-dataclasses: 0.6.6
    • typing-extensions: 4.7.1
    • tzdata: 2023.4
    • urllib3: 1.26.18
    • wcwidth: 0.2.5
    • werkzeug: 3.0.1
    • wheel: 0.41.2
    • xxhash: 3.4.1
    • yarl: 1.9.4
    • zstandard: 0.19.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.10.13
    • release: 4.19.0-24-cloud-amd64
    • version: #1 SMP Debian 4.19.282-1 (2023-04-29)

More info

No response

hrushikesh198 added the bug and needs triage labels on Jan 10, 2024
@hrushikesh198
Author

I tried using pretrained_model_name='mistralai/Mistral-7B-Instruct-v0.2'.
It still gives the same tensor size mismatch error, irrespective of which precision I use (bf16-mixed or bf16-true).
It seems like activation checkpointing and FSDP are not working well together.

hrushikesh198 changed the title from "bf16-true precision does not work with gradient checkpointing" to "Recomputed tensor size does not match when using activation checkpointing in FSDP strategy" on Jan 11, 2024
@hrushikesh198
Author

hrushikesh198 commented Jan 12, 2024

I was able to fix the issue by explicitly calling
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': True}) after loading the model inside the configure_model() method.
The issue reappears if I set use_reentrant=False in the above call.
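
As a minimal sketch of that workaround (assuming everything else in the repro above stays unchanged and only this call is added after from_pretrained), configure_model() would look like:

def configure_model(self):
    if self.model is not None:
        return

    self.model = AutoModelForCausalLM.from_pretrained(
        self.pretrain_model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        pad_token_id=self.pad_token_id,
    )
    # Workaround: enable Hugging Face's gradient checkpointing with the
    # reentrant implementation after the model is loaded.
    # Passing use_reentrant=False here brings the CheckpointError back.
    self.model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={'use_reentrant': True}
    )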

@Smu-Tan

Smu-Tan commented Jan 14, 2025

I was able to fix the issue by explicitly calling self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': True}) after loading the model inside the configure_model() method. The issue reappears if I set use_reentrant=False in the above call.

Did you experience significant performance degradation when setting gradient_checkpointing_kwargs={'use_reentrant': True} with DeepSpeed stage 3? A similar observation is this.
