Runtime error when attempting to use data distributed parallel #19

Closed
Phirefly9 opened this issue Jan 29, 2020 · 28 comments

@Phirefly9

Thank you for putting in the time to do this. I have a bunch of ideas for it.

I crudely ported your example training script to the pytorch-lightning library, and when I attempted to use distributed data parallel it crashed. The problem may be down in the revtorch library, but I'm handing the script off to you so you can play with it, take a look, and decide where the issue is.

You can reproduce the crash by supplying the --distributed flag to the script with any number of GPUs.
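
(For example, with the script saved at the path shown in the traceback below; the GPU count here is arbitrary:)

python example/train_lightning.py --gpus 2 --distributed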

Epoch 1:   0%|                                                                                                                                                                         | 0/1451 [00:00<?, ?batch/s]Traceback (most recent call last):
  File "example/train_lightning.py", line 166, in <module>
    main()
  File "example/train_lightning.py", line 161, in main
    trainer.fit(model)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 687, in fit
    mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
  File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 331, in ddp_train
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 829, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 332, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 386, in run_training_epoch
    output = self.run_training_batch(batch, batch_idx)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 506, in run_training_batch
    loss = optimizer_closure()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 489, in optimizer_closure
    model_ref.backward(self.use_amp, closure_loss, optimizer)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/core/hooks.py", line 154, in backward
    loss.backward()
  File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/function.py", line 77, in apply
    return self._forward_cls.backward(self, *args)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 161, in backward
    y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 89, in backward_pass
    gy1.backward(dy2)
  File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Expected to mark a variable ready only once. This error is caused by use of a module parameter outside the `forward` function. The return value of the `forward` function is inspected by the distributed data parallel wrapper to figure out if any of the module's parameters went unused. If this is the case, it knows they won't receive gradients in a backward pass. If any of those parameters are then used outside `forward`, this error condition is triggered. You can disable unused parameter detection by passing the keyword argument `find_unused_parameters=False` to `torch.nn.parallel.DistributedDataParallel`.

Script:

from reformer_pytorch import ReformerLM

import tqdm
import gzip
import numpy as np
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import Trainer

import os

import torch
from torch import nn
from torchvision import transforms

import argparse

import pytorch_lightning as pl

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100

SEQ_LEN = 4096

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq[0:-1], full_seq[1:]

    def __len__(self):
        return self.data.size(0) // self.seq_len

class ReformerTrainer(pl.LightningModule):

    def __init__(self, batch_size=4, distributed_mode=False):
        super(ReformerTrainer, self).__init__()
        self.batch_size = batch_size
        self.distributed_mode = distributed_mode
        # instantiate model
        self.model = ReformerLM(
            emb = 512,
            depth = 6,
            max_seq_len = SEQ_LEN,
            num_tokens = 256,
            heads = 8,
            bucket_size = 64,
            n_hashes = 4,
            ff_chunks = 10,
            lsh_dropout = 0.1,
            weight_tie = True,
            causal = True,
            use_full_attn = False # set this to true for comparison with full attention
        )

    def forward(self, x):
        pred = self.model(x).transpose(1, 2)
        return pred

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y, reduction='mean')
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}
    
    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
        
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}
    
    def test_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED
        dataset = TextSamplerDataset(data_train, SEQ_LEN)
        if self.distributed_mode:
            dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            dataloader = DataLoader(dataset, sampler=dist_sampler, batch_size=self.batch_size)
        else:
            dataloader = DataLoader(dataset, batch_size=self.batch_size)
        return dataloader

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        dataset = TextSamplerDataset(data_val, SEQ_LEN)
        if self.distributed_mode:
            dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            dataloader = DataLoader(dataset, sampler=dist_sampler, batch_size=self.batch_size)
        else:
            dataloader = DataLoader(dataset, batch_size=self.batch_size)
        return dataloader

    @pl.data_loader
    def test_dataloader(self):
        dataset = TextSamplerDataset(data_val, SEQ_LEN)
        if self.distributed_mode:
            dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            dataloader = DataLoader(dataset, sampler=dist_sampler, batch_size=self.batch_size)
        else:
            dataloader = DataLoader(dataset, batch_size=self.batch_size)
        return dataloader

def main():
    
    parser = argparse.ArgumentParser("reformer-lightning example")
    parser.add_argument("--gpus", default=1, help="gpus to use")
    parser.add_argument("-d", "--distributed", default=False, action="store_true",
                        help="activates distributed using data distributed parallel")
    parser.add_argument("-b", "--batch_size", type=int, default=4, help="batch_size")
    args = parser.parse_args()

    model = ReformerTrainer(args.batch_size, args.distributed)

    # most basic trainer, uses good defaults
    if args.distributed:
        trainer = Trainer(gpus=args.gpus, distributed_backend='ddp', accumulate_grad_batches=GRADIENT_ACCUMULATE_EVERY)
    else:
        trainer = Trainer(gpus=args.gpus, distributed_backend='dp', accumulate_grad_batches=GRADIENT_ACCUMULATE_EVERY)
    trainer.fit(model)
    trainer.test()


if __name__ == "__main__":
    main()
@lucidrains (Owner)

@Phirefly9 yes, I believe I ran into a related RevTorch error yesterday as well. Could you report this to the RevTorch issues so the author can chew on it?

@lucidrains (Owner)

@Phirefly9 oh, never mind, I think they are unrelated.

@Phirefly9 (Author)

I'll try to create a minimal example without lightning that just hits revtorch, and then open an issue there, assuming that is the problem area.

@lucidrains (Owner)

Yea, it's related to how RevTorch manually handles the backward pass

@lucidrains (Owner)

@Phirefly9 I will try to integrate https://github.com/silvandeleemput/memcnn today and see if we can solve these issues that way. Worst case, I become an expert with custom backprop and roll my own lol

@Phirefly9 (Author)

Well, I come back confused. My experiment with the example revtorch code actually didn't produce the same error, so I've written a different distributed training script for reformer, without lightning, that is able to train. This one uses NVIDIA Apex so I could test half precision at the same time.

Optimization levels other than 'O0' produce the following error (shown here for 'O1'), so half precision is currently not working:

File "train_apex_ddp.py", line 133, in <module>
    loss = get_batch_loss(model, next(train_loader))
  File "train_apex_ddp.py", line 126, in get_batch_loss
    pred = model(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 427, in forward
    x = self.reformer(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 414, in forward
    x = self.layers(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 192, in forward
    x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 145, in forward
    x = block(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 49, in forward
    y1 = x1 + self.f_block(x2)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 63, in forward
    return self.fn(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 333, in forward
    outputs = process_inputs_chunk(self.lsh_attn, qk, v, chunks=self.attn_chunks)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 24, in process_inputs_chunk
    outputs = [fn(*input_pair) for input_pair in zip(*chunked_inputs)]
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 24, in <listcomp>
    outputs = [fn(*input_pair) for input_pair in zip(*chunked_inputs)]
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 165, in forward
    buckets = self.hash_vectors(n_buckets, qk)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 132, in hash_vectors
    rotated_vecs = torch.einsum('btf,bfhi->bhti', dropped_vecs, random_rotations)
  File "/opt/conda/lib/python3.6/site-packages/torch/functional.py", line 242, in einsum
    return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #2 'mat2' in call to _th_bmm

I recommend using the NVIDIA PyTorch container because it has Apex installed already. Launch the script with:
python -m torch.distributed.launch --nproc_per_node=4 train_apex_ddp.py -b 4

from reformer_pytorch import ReformerLM

import os
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import argparse

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("This code requires APEX")

parser = argparse.ArgumentParser()
parser.add_argument("-b", "--batch_size", type=int, default=4)
parser.add_argument("--local_rank", default=0, type=int, help="don't set this")
args = parser.parse_args()

args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

args.gpu = 0
args.world_size = 1

if args.distributed:
    args.gpu = args.local_rank
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl',
                                            init_method='env://')
    args.world_size = torch.distributed.get_world_size()

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100

SEQ_LEN = 4096

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

# instantiate model

model = ReformerLM(
    emb = 512,
    depth = 6,
    max_seq_len = SEQ_LEN,
    num_tokens = 256,
    heads = 8,
    bucket_size = 64,
    n_hashes = 4,
    ff_chunks = 10,
    lsh_dropout = 0.1,
    weight_tie = True,
    causal = True,
    use_full_attn = False # set this to true for comparison with full attention
)

model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq[0:-1].cuda(), full_seq[1:].cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_set = TextSamplerDataset(data_train, SEQ_LEN)
if args.distributed:
    dist_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    train_loader = cycle(DataLoader(train_set, sampler=dist_sampler, batch_size=args.batch_size))
else:
    train_loader = cycle(DataLoader(train_set, shuffle=True, batch_size=args.batch_size))

dev_set = TextSamplerDataset(data_val, SEQ_LEN)
if args.distributed:
    dist_sampler = torch.utils.data.distributed.DistributedSampler(dev_set)
    val_loader = cycle(DataLoader(dev_set, sampler=dist_sampler, batch_size=args.batch_size))
else:
    val_loader = cycle(DataLoader(dev_set, shuffle=True, batch_size=args.batch_size))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

model, optim = amp.initialize(model, optim, opt_level='O0')

if args.distributed:
    net = torch.nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[args.local_rank],
                                                    output_device=args.local_rank)

# training

def get_batch_loss(model, data):
    x, y = data
    pred = model(x)
    return F.cross_entropy(pred.transpose(1, 2), y, reduction='mean')

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = get_batch_loss(model, next(train_loader))
        with amp.scale_loss(loss, optim) as scaled_loss:
            scaled_loss.backward()

    if args.local_rank == 0:
        print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i != 0 and i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = get_batch_loss(model, next(val_loader))
            if args.local_rank == 0:
                print(f'validation loss: {loss.item()}')

@lucidrains (Owner)

@Phirefly9 oh, that's a completely new error, and I understand why it doesn't work. I can put in a fix for that soon! (The random rotation matrix used for LSH needs to be cast to half precision as well.)
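
(For context: a minimal sketch of the kind of fix being described, under the assumption that the random rotations used by hash_vectors were created in float32 while amp had cast the inputs to float16. The names vecs, n_hashes and n_buckets follow the traceback above; the actual commit may differ.)

import torch

def rotate_vectors(vecs, n_hashes, n_buckets):
    # Hypothetical helper illustrating the dtype fix: create the random
    # rotation matrix on the same device and in the same dtype as the
    # vectors being hashed, so the einsum no longer mixes Half and Float
    # once amp has cast the activations to float16.
    b, t, f = vecs.shape
    random_rotations = torch.randn(
        b, f, n_hashes, n_buckets // 2,
        device=vecs.device, dtype=vecs.dtype)
    return torch.einsum('btf,bfhi->bhti', vecs, random_rotations)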

@lucidrains (Owner)

@Phirefly9 79974b4 can you try again?

@Phirefly9 (Author)

git rev-parse HEAD
79974b46867e051d15cd9bebe491e4faa9584034

Seems the same?

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
  0%|                                                                                                                                                                                   | 0/100000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_apex_ddp.py", line 133, in <module>
    loss = get_batch_loss(model, next(train_loader))
  File "train_apex_ddp.py", line 126, in get_batch_loss
    pred = model(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 483, in forward
    x = self.reformer(x, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 466, in forward
    x = self.layers(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 192, in forward
    x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 145, in forward
    x = block(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 49, in forward
    y1 = x1 + self.f_block(x2)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 67, in forward
    return self.fn(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 92, in forward
    return self.fn(x, *self.args, **self.kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 402, in forward
    outputs = process_inputs_chunk(partial_attn_fn, qk, v, chunks=self.attn_chunks)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 26, in process_inputs_chunk
    outputs = [fn(*input_pair) for input_pair in zip(*chunked_inputs)]
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 26, in <listcomp>
    outputs = [fn(*input_pair) for input_pair in zip(*chunked_inputs)]
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 337, in forward
    dot = torch.einsum('bie,bje->bij', q, qk)
  File "/opt/conda/lib/python3.6/site-packages/torch/functional.py", line 242, in einsum
    return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #2 'mat2' in call to _th_bmm
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.6/site-packages/torch/distributed/launch.py", line 253, in <module>
    main()
  File "/opt/conda/lib/python3.6/site-packages/torch/distributed/launch.py", line 249, in main
    cmd=cmd)
subprocess.CalledProcessError: Command '['/opt/conda/bin/python', '-u', 'train_apex_ddp.py', '--local_rank=0', '-b', '4']' returned non-zero exit status 1.

@lucidrains (Owner)

@Phirefly9 just made another commit, can you upgrade and try again?

@Phirefly9 (Author)

I was able to train using the O1 optimization level, which is the standard one Apex typically recommends. I upped it to O2 and got:

Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.

Defaults for this optimization level are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
  0%|                                                                                                                                                                                   | 0/100000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_apex_ddp.py", line 133, in <module>
    loss = get_batch_loss(model, next(train_loader))
  File "train_apex_ddp.py", line 126, in get_batch_loss
    pred = model(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
    **applier(kwargs, input_caster))
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 487, in forward
    x = self.reformer(x, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 470, in forward
    x = self.layers(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 192, in forward
    x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 145, in forward
    x = block(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 49, in forward
    y1 = x1 + self.f_block(x2)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 71, in forward
    return self.fn(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 96, in forward
    return self.fn(x, *self.args, **self.kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ifs_act3_home/clong/code/reformer-pytorch/reformer_pytorch/reformer_pytorch.py", line 391, in forward
    x = torch.cat((x, mem, keys), dim=1)
RuntimeError: Expected object of scalar type Half but got scalar type Float for sequence element 2 in sequence argument at position #1 'tensors'

@lucidrains (Owner)

Apex is not very seamless lol, ok I'll look into it

@lucidrains (Owner)

oh, my error lol, ok fixed. can you try again?
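
(For context: the O2 failure is the same class of mixed-precision problem as before. By the time of the torch.cat in the traceback, x has been cast to fp16 while mem and keys are still fp32, so the concatenation mixes dtypes. A minimal sketch of the kind of cast that resolves it, with names taken from the traceback; the actual fix in the repo may differ.)

import torch

def cat_with_memory(x, mem, keys):
    # Hypothetical helper: cast the auxiliary memory/key tensors to the
    # dtype of the main activations before concatenating, so fp16 and
    # fp32 tensors are never mixed under amp.
    return torch.cat((x, mem.type_as(x), keys.type_as(x)), dim=1)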

@Phirefly9 (Author)

O2 and O3 both run through on that commit. Nice work!

I would be fine closing this issue on that. I don't know what the deal with pytorch-lightning is, but I think we are both in agreement that it is probably down in revtorch, so I can play with lightning some more and see if I can isolate the issue in revtorch using it.

@lucidrains (Owner)

Robin from RevTorch got back to me saying he has little time. If you can figure out what's wrong with his implementation, let me know. I'm going to investigate memcnn in the meantime, although that repository will require a PR as well to split on the right dimension, even if it does work.

@lucidrains (Owner)

To be honest, custom backprop scares me lol

@lucidrains (Owner)

@Phirefly9 RobinBruegger/RevTorch#8 may fix the issue, but I'm not entirely sure. Have you tried setting the find_unused_parameters flag as recommended in the error?
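
(For reference, the flag from the original error message is an argument of the DDP wrapper itself. A minimal sketch with plain PyTorch, reusing model and args.local_rank from the Apex script above; with pytorch-lightning the wrapping happens inside the trainer, which is what makes the flag harder to reach.)

import torch

# Wrap the model yourself and pass the flag suggested by the runtime
# error; this turns off the unused-parameter bookkeeping that the
# reversible blocks' manual backward pass appears to trip over.
net = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    find_unused_parameters=False)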

@Phirefly9 (Author)

I'm sure it will work if I add that flag. I'm thinking it's just a lightning bug: I've trained revtorch's example using lightning and it worked in distributed mode, and I've been looking all over the code and don't see anything. At this point I think it's a bug with lightning. If I can't find the issue after some more searching I'll open a ticket with them.

@lucidrains (Owner)

@Phirefly9 Robin and I fixed an issue with RevTorch to allow for multiple backward passes. Do you think you could try the above again and see if it incidentally fixed your issue?

@Phirefly9 (Author)

It did not, unfortunately. I've opened an issue on pytorch-lightning and hope to hear from them soon.

@fcampagne

I would also like to train Reformer from this repo with DistributedDataParallel. Is the current workaround to use DistributedDataParallel from Apex as a drop-in replacement for the pytorch implementation, or is it sufficient to call amp.initialize(model, optim, opt_level='O0') and proceed with the pytorch implementation of DistributedDataParallel as shown in the code above?

@zbloss (Contributor) commented Feb 27, 2020

We may be able to use DistributedDataParallel, but I am currently trying to utilize Microsoft's new DeepSpeed library for distributed training.

@Phirefly9 (Author)

@fcampagne I would recommend DistributedDataParallel; you will find it's faster in most cases.

DeepSpeed is probably the new standard though. It integrates Apex and DistributedDataParallel as well as other improvements. The other benefit is that it is usually only about a one-line change from a single-GPU pytorch script, as sketched below.
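
(A minimal sketch of what that swap usually looks like, reusing model, get_batch_loss and train_loader from the Apex script above and assuming a DeepSpeed JSON config is passed on the command line; the names here are illustrative, not taken from this repo.)

import deepspeed

# deepspeed.initialize wraps the model and optimizer in one call; args is
# expected to carry --local_rank and the path to a DeepSpeed config.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters())

# The engine then replaces the amp/DDP boilerplate in the training loop.
loss = get_batch_loss(model_engine, next(train_loader))
model_engine.backward(loss)
model_engine.step()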

@fcampagne

I answered my own question and found that it is necessary to use DDP from Apex (i.e., from apex.parallel import DistributedDataParallel as DDP) instead of the pytorch implementation.
Calling amp.initialize(model, optim, opt_level='O0') was not necessary for me to prevent the exception (Expected to mark a variable ready only once). I will try adding it to see if performance improves, though.
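
(For anyone following along, the drop-in swap referenced here is just wrapping the CUDA model with the Apex class already imported in the script above; a minimal sketch.)

from apex.parallel import DistributedDataParallel as DDP

# Apex's DDP infers the device from the model's parameters, so unlike
# torch.nn.parallel.DistributedDataParallel it takes no device_ids argument.
model = DDP(model)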

@fcampagne

I looked at DeepSpeed as well; it looks good, but it is stuck on pytorch 1.2 as far as their supported dependencies go. We're on 1.4 already.

@lucidrains (Owner)

@Phirefly9 @zbloss @justindujardin @fcampagne Guys! I got DeepSpeed working with Reformer after the latest Reversible Net changes! It's blazing fast! (using it in place of DataParallel locally)

@lucidrains (Owner)

I'm not sure about distributed, but the parallelism DeepSpeed provides even on my two GPUs at home is worlds faster. You can follow the example at https://github.com/lucidrains/reformer-pytorch/tree/master/examples/enwik8_deepspeed

@lucidrains (Owner)

Closing because of independent replication of DeepSpeed training in another issue.
