memory use continuously increasing #486

Open
grez72 opened this issue Jan 28, 2019 · 15 comments
Labels: enhancement (New feature or request)

grez72 commented Jan 28, 2019

I'm working from the tutorials for integrating DALI with PyTorch, aiming to train models on ImageNet. But I think I'm running into the "memory leak" / "continuously growing memory" issues mentioned in #344 and #278, although none of the suggestions in those issues solved my problem.

I'm using NVIDIA DALI 0.6.1, with Ubuntu 16.04, CUDA 10.0, cuDNN 7.4.1, and PyTorch v1.0.0.

I'm using a hybrid pipeline and the DALIGenericIterator from the PyTorch plugin.

from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator, DALIClassificationIterator
import nvidia.dali.ops as ops
import nvidia.dali.types as types

def ram_use():
    import os
    import psutil
    pid = os.getpid()
    py = psutil.Process(pid)
    memoryUse = py.memory_info()[0] / 2. ** 30  # resident set size (RSS) in GB
    return memoryUse

class ImageNetPipeline(Pipeline):
    def __init__(self, image_dir, batch_size, num_threads, device_id, exec_async=True):
        super(ImageNetPipeline, self).__init__(batch_size, num_threads, device_id, seed = 12, exec_async=exec_async)
        self.input = ops.FileReader(file_root = image_dir, random_shuffle = True, initial_fill = 21)
        self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB)
        self.resize = ops.Resize(device = "gpu", resize_shorter=224)
        self.centerCrop = ops.Crop(device = "gpu", crop=(224,224))
        self.norm = ops.NormalizePermute(device = "gpu",
                                            height = 224,
                                            width = 224,
                                            mean = [x*255 for x in [0.485, 0.456, 0.406]],
                                            std = [x*255 for x in [0.229, 0.224, 0.225]])
    
    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        images = self.resize(images)
        images = self.centerCrop(images)
        images = self.norm(images)
        
        # images are on the GPU
        return (images, labels)

N = 2 # number of GPUs
BATCH_SIZE = 128  # 128, batch size per GPU
ITERATIONS = 32
NUM_THREADS = 8

train_dir = "/data/local_hdd/ImageSets/imagenet/ILSRC2012/train"

pipes = [ImageNetPipeline(image_dir=train_dir, batch_size=BATCH_SIZE, num_threads=NUM_THREADS, device_id=device_id) for device_id in range(N)]
pipes[0].build()
train_iter = DALIGenericIterator(pipes, ['data', 'label'], pipes[0].epoch_size().popitem()[1])

When I iterate through the dataset (not model training, just iterating), things go blazingly fast (5000 images/s), but only up until about 90% of the dataset has been loaded (so close!), at which point things slow to a near standstill. During that time, my RAM usage steadily increases by 6-7 GB (e.g., from about 5 GB to about 12.5 GB). I'm not sure why things stall at 12.5 GB (the machine has 128 GB of RAM), but this is consistent across many attempted runs.

batch_no = 0
count = 0
for data in train_iter:
    for batch in data:
        batch_no += 1
        count += batch['data'].shape[0]
    if batch_no % 100 == 0:
        print(batch_no, count, ram_use())

[screenshot of the loop output showing RAM use increasing with iteration count]

I made my own copy of DALIGenericIterator to track down the source of the issue. It seems that calling p._share_outputs() increases the memory use. If I "simulate" iterations without this function call (by calling p._share_outputs() once during the first batch, storing the outputs, and just working with the same outputs on each iteration), then the memory doesn't grow.

Is it expected that memory use would grow on each iteration/call to p._share_outputs()?

Is it possible that p._release_outputs() is not releasing memory?

Since _share_outputs, _release_outputs are core functions, I wasn't sure how to further debug this issue.

Many thanks in advance for your help.

JanuszL (Contributor) commented Jan 28, 2019

Hi,
This slowdown looks strange. Maybe the disk cache is being thrashed at the end of the epoch, and you are blocked by disk IO.
Regarding memory consumption, I can clearly see you saturate around 12 GB - that is fine; I guess you can get roughly double that once you add the validation pipeline.
Regarding the iterator: _share_outputs returns the ready buffer and launches the next iteration of the DALI pipeline, so if you skip _share_outputs, DALI is not doing any work. The job of _release_outputs is to recycle the buffer returned by _share_outputs. If you skip _release_outputs, you will hang in _share_outputs waiting for free buffers inside the pipeline.
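For reference, the cycle the PyTorch iterator drives looks roughly like this (a minimal sketch using the internal method names mentioned in this thread; they are not a stable public API and may differ between DALI versions, so treat it as illustrative only):

# Sketch of the buffer cycle inside DALIGenericIterator (DALI 0.6.x internals,
# method names as discussed in this thread).
# The real iterator also prefetches a couple of batches at construction time;
# that part is omitted here for brevity.
pipe = ImageNetPipeline(image_dir=train_dir, batch_size=BATCH_SIZE,
                        num_threads=NUM_THREADS, device_id=0)
pipe.build()

for i in range(100):
    outputs = pipe._share_outputs()   # wait for a ready batch; also launches the next iteration
    # ... copy `outputs` into framework tensors here ...
    pipe._release_outputs()           # hand the buffers back so the pipeline can recycle them
    if i % 10 == 0:
        print(i, ram_use())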

grez72 (Author) commented Jan 28, 2019

Hi,

Thanks for your quick response. Can you say more about the disk cache being thrashed at the end of the epoch? I've never successfully reached the end of an epoch with the loader (it stalls at ~90%, when memory consumption saturates). Does this mean I need to figure out how to clear the disk cache before the end of the epoch? Any suggestions on how I can do that?

Thanks!

JanuszL (Contributor) commented Jan 29, 2019

I mean that the OS keeps data from the HDD in RAM file caches. I guess in your case the data set may not fully fit into your RAM, so at the end files are accessed not from the RAM cache but from the HDD directly.
And by stall, do you mean that it hangs, or that it is just very slow? Have you tried running the training on a smaller data set?

grez72 (Author) commented Jan 30, 2019

Ah, ok. I'm not 100% sure how to check RAM file cache usage. Do you know how to do that from within Python so I can verify? I'm surprised the entire dataset has to fit into RAM - can't memory be released after each iteration/batch?

When the iteration slows, it doesn't completely stop (but it's very, very slow; instead of finishing in 15 min, it would finish hours later). I have been able to iterate through a smaller data set without any issues.

If you have any tips on how to monitor RAM cache usage within Python (or via the Linux terminal), that would be greatly appreciated. In the meantime, I'm making a copy of the ImageNet training set with reduced file sizes. I will test whether I can iterate through this smaller (in GB) training set and post a comment here.

Thanks!

JanuszL (Contributor) commented Jan 31, 2019

Hi,
I was just thinking about how your OS works, etc.

> I'm surprised the entire dataset has to fit into RAM - can't memory be released after each iteration/batch?

What I mean is that if your data is on a regular HDD, access to it is rather slow compared to an SSD, so the OS tries to cache its disk accesses to make them faster. I just wonder whether, once that cache is full, the OS stops serving your data from RAM and reads from the HDD directly, and that may be the source of the slowdown - but it is just my guess.
Could you check what your disk IO, CPU, and GPU utilization look like when the training stalls, using, for example, https://unix.stackexchange.com/questions/55212/how-can-i-monitor-disk-io, nvidia-smi, and top/htop? I don't see any reason why you would get this perf drop while plenty of RAM is still free.
Can you also try commenting out the training part of your script and just iterating over the data set using the DALI pipeline?

JanuszL added the labels bug (Something isn't working) and question (Further information is requested) on Jan 31, 2019
grez72 (Author) commented Jan 31, 2019

Thanks again for all of your help with this.

First, I resized the images in the ImageNet ILSRC2012 training set to 256x256x3 (resizing the shortest edge to 256 px, preserving aspect ratio, then center cropping to 256x256). Now the loader iterates at over 21,000 images/s, going through the entire training set in about 60 s (no model training, just iterating through images). So, from a practical perspective, the problem is solved.
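
In case it is useful to anyone, the offline resizing step was roughly the following (a quick sketch using Pillow; the function name and paths here are placeholders, not the exact script I ran):

from PIL import Image

def resize_and_center_crop(src_path, dst_path, size=256):
    # resize the shorter edge to `size` px, preserving aspect ratio,
    # then center crop to size x size (matches the preprocessing described above)
    img = Image.open(src_path).convert("RGB")
    w, h = img.size
    if w < h:
        new_w, new_h = size, int(round(h * size / w))
    else:
        new_w, new_h = int(round(w * size / h)), size
    img = img.resize((new_w, new_h), Image.BILINEAR)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    img.crop((left, top, left + size, top + size)).save(dst_path, quality=95)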

Nevertheless, I was curious about the slowdown I was seeing with the original images, so I followed the Stack Exchange link and monitored disk IO using the command "sar -u 1 2". For this test, I iterated through the training set with no model training (just looping through images).

sar output before the slowdown:

Linux 4.4.0-135-generic (nolan) 01/31/2019 x86_64 (16 CPU)
04:46:08 PM CPU %user %nice %system %iowait %steal %idle
04:46:09 PM all 4.52 0.00 5.16 0.00 0.00 90.32
04:46:10 PM all 4.73 0.00 4.79 0.00 0.00 90.49
Average: all 4.62 0.00 4.97 0.00 0.00 90.40

sar output after the slowdown:

Linux 4.4.0-135-generic (nolan) 01/31/2019 x86_64 (16 CPU)
04:45:00 PM CPU %user %nice %system %iowait %steal %idle
04:45:01 PM all 6.69 0.00 3.75 11.38 0.00 78.19
04:45:02 PM all 7.98 0.00 3.52 11.18 0.00 77.32
Average: all 7.33 0.00 3.63 11.28 0.00 77.76

The %iowait jumped from 0.00% to over 11%, which seems to confirm your guess? The drive is a 4 TB HDD with only a 64 MB cache. I might try replacing the drive with one that has a larger cache (256 MB) to see if that improves things.

I don't know enough about how the HDD caching system depends on the actual size of the files, but as noted above, everything runs spectacularly fast after I reduced the file sizes.

Thanks for your time and effort helping address this issue (sorry it likely turned out to be a hardware issue!).

JanuszL (Contributor) commented Feb 1, 2019

Hi,
Glad we could help. You can also check https://www.linuxatemyram.com/play.html to see how the OS caches your disk accesses. The 12 GB you observed was only the memory allocated by the Python process; the OS itself caches disk accesses on top of that. I don't know whether swapping the HDD for one with a bigger internal cache would give you a significant boost - I would rather target one with a lower random IO access time.
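
If you want to check the page cache from Python rather than the shell, psutil (already used earlier in this thread) exposes it on Linux; a minimal sketch:

import psutil

def page_cache_gb():
    # `cached` is the file/page cache the OS keeps for disk data (Linux-specific field)
    return psutil.virtual_memory().cached / 2. ** 30

print("OS page cache (GB):", page_cache_gb())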

zeakey commented Mar 8, 2019

Hi @JanuszL, I seem to be running into the same error.
I was training ImageNet with ResNet18 and the program became extremely slow (500+ sec per batch) after several epochs of training.

Meanwhile the whole system got stuck and responded very slowly.

My environment:

  • DALI:
    0.6.1

  • OS
    Linux Satan 4.15.0-46-generic #49-Ubuntu SMP Wed Feb 6 09:33:07 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux

  • PyTorch
    1.0.1.post2

  • nvidia-smi:

Sat Mar  9 01:38:15 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 390.12                 Driver Version: 390.12                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  TITAN Xp            Off  | 00000000:02:00.0  On |                  N/A |
| 51%   81C    P2   108W / 250W |   2723MiB / 12195MiB |     87%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN Xp            Off  | 00000000:03:00.0 Off |                  N/A |
| 52%   82C    P2    99W / 250W |   2455MiB / 12196MiB |     86%      Default |
+-------------------------------+----------------------+----------------------+
|   2  TITAN Xp            Off  | 00000000:82:00.0 Off |                  N/A |
| 52%   82C    P2   102W / 250W |   2323MiB / 12196MiB |     87%      Default |
+-------------------------------+----------------------+----------------------+
|   3  TITAN Xp            Off  | 00000000:83:00.0 Off |                  N/A |
| 56%   84C    P2   207W / 250W |   8074MiB / 12196MiB |     99%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      1456      G   /usr/lib/xorg/Xorg                           108MiB |
|    0      2568      G   compiz                                        47MiB |
|    0     39561      C   python3                                     2553MiB |
|    1     39561      C   python3                                     2441MiB |
|    2     39561      C   python3                                     2309MiB |
|    3      8039      C   python3                                     5705MiB |
|    3     39561      C   python3                                     2345MiB |
+-----------------------------------------------------------------------------+
  • nvcc --version:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2017 NVIDIA Corporation
Built on Fri_Nov__3_21:07:56_CDT_2017
Cuda compilation tools, release 9.1, V9.1.85

Part of my code is:

# DALI data loader
NUM_GPUS = torch.cuda.device_count()
class HybridTrainPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, num_gpus, dali_cpu=False):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        self.input = ops.MXNetReader(path = [data_dir+"/rec/train.rec"], 
        index_path=[data_dir+"/rec/train.idx"], random_shuffle = True, shard_id = device_id, num_shards = num_gpus)
        
        #self.input = ops.FileReader(file_root=data_dir, shard_id=0, num_shards=4, random_shuffle=True)
        # let the user decide which pipeline works best for the ResNet version they run
        
        if dali_cpu:
            dali_device = "cpu"
            self.decode = ops.HostDecoder(device=dali_device, output_type=types.RGB)
        else:
            dali_device = "gpu"
            # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
            # without additional reallocations
            self.decode = ops.nvJPEGDecoder(device="mixed", output_type=types.RGB)

        self.rrc = ops.RandomResizedCrop(device=dali_device, size =(crop, crop), interp_type=types.INTERP_CUBIC, random_area=[0.2, 1])
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                            std=[0.229 * 255,0.224 * 255,0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)
        print('DALI "{0}" variant'.format(dali_device))

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.rrc(images)
        output = self.cmnp(images.gpu(), mirror=rng)
        return [output, self.labels]

class HybridValPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, num_gpus, dali_cpu=False):
        super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        self.input = ops.MXNetReader(path = [data_dir+"/rec/val.rec"], index_path=[data_dir+"/rec/val.idx"],
                                     random_shuffle = False, shard_id = device_id, num_shards = num_gpus)
        
        #self.input = ops.FileReader(file_root=data_dir, shard_id=0, num_shards=4, random_shuffle=False)
        
        if dali_cpu:
            dali_device = "cpu"
            self.decode = ops.HostDecoder(device=dali_device, output_type=types.RGB)
        else:
            dali_device = "gpu"
            # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
            # without additional reallocations
            self.decode = ops.nvJPEGDecoder(device="mixed", output_type=types.RGB)
        #self.decode = ops.nvJPEGDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(device=dali_device, resize_shorter=size, interp_type=types.INTERP_CUBIC)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                            std=[0.229 * 255,0.224 * 255,0.225 * 255])

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]

# train loader
pipes = [HybridTrainPipe(batch_size=int(CONFIGS["DATA"]["BS"]/NUM_GPUS), num_threads=2, device_id=device_id, data_dir=CONFIGS["DATA"]["DIR"], crop=224, num_gpus=NUM_GPUS, dali_cpu=False) for device_id in range(NUM_GPUS)]
pipes[0].build()
train_loader = plugin_pytorch.DALIClassificationIterator(pipes, size=int(pipes[0].epoch_size("Reader")))

# val loader
pipes = [HybridValPipe(batch_size=int(100/NUM_GPUS), num_threads=2, device_id=device_id, data_dir=CONFIGS["DATA"]["DIR"], crop=224, size=256, num_gpus=NUM_GPUS, dali_cpu=False) for device_id in range(NUM_GPUS)]
pipes[0].build()
val_loader = plugin_pytorch.DALIClassificationIterator(pipes, size=int(pipes[0].epoch_size("Reader")))

Updated:

[screenshot: system monitor showing RAM fully consumed]

The RAM has been eaten up after 5 epochs of iteration!

I hope this information helps you guys localize the bug and make DALI better and stronger.

suruoxi commented May 9, 2019

I'm hitting the same error, running on the GPU.

I use a pipeline that is almost the same as the official example code for PyTorch. GPU card 0 is used for the DALI pipeline, while GPU cards 1-7 are used for training. I train ResNet18 on the ImageNet dataset with batch size 1792 (256*7). The GPU memory used by card 0 increases continuously until an "out of memory" error. The memory usage increases by about 35 MB per epoch.

class HybridTrainPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, dali_cpu=False):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        self.input = ops.FileReader(file_root=data_dir, shard_id=0, num_shards=1, random_shuffle=True)
        # let the user decide which pipeline works best for the ResNet version they run
        if dali_cpu:
            dali_device = "cpu"
            self.decode = ops.HostDecoderRandomCrop(device=dali_device, output_type=types.RGB,
                                                    random_aspect_ratio=[0.8, 1.25],
                                                    random_area=[0.1, 1.0],
                                                    num_attempts=100)
        else:
            dali_device = "gpu"
            # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
            # without additional reallocations
            self.decode = ops.nvJPEGDecoderRandomCrop(device="mixed", output_type=types.RGB, device_memory_padding=211025920, host_memory_padding=140544512,
                                                      random_aspect_ratio=[0.8, 1.25],
                                                      random_area=[0.1, 1.0],
                                                      num_attempts=100)
        self.res = ops.Resize(device=dali_device, resize_x=crop, resize_y=crop, interp_type=types.INTERP_TRIANGULAR)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                            std=[0.229 * 255,0.224 * 255,0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)
        print('DALI "{0}" variant'.format(dali_device))

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        #output = self.cmnp(images.gpu(), mirror=rng)
        output = self.cmnp(images.gpu(), mirror=rng)
        return [output, self.labels]

def ImageNet(batch_sz, num_workers=16):
    world_size = 1
    rootdir = '/home/futian.zp/data/imagenet/'
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    pipe = HybridTrainPipe(batch_size=batch_sz, num_threads=num_workers, device_id=0,data_dir=rootdir + 'train', crop=224, dali_cpu=False)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader")/world_size))
    train_loader.num_classes = 1000

    return train_loader

JanuszL (Contributor) commented May 10, 2019

Hi,
@suruoxi - I wouldn't recommend a 1+7 configuration, as the workload and memory consumption are not equally distributed across the GPUs. The recommended way of using DALI is to have one DALI instance per GPU, so memory is consumed equally.
I ran some simple tests with https://github.com/NVIDIA/DALI/blob/master/dali/test/python/test_RN50_data_pipeline.py (python test_RN50_data_pipeline.py -g 1 -b 1792 --epochs 10), and after 8 epochs with the RN50 data pipeline on raw ImageNet I got ~8 GB of GPU memory usage.
There is also some dependency between memory consumption and nvJPEG. In the base version of nvJPEG, a set of internal buffers is created for every worker, so if you set the number of workers to a very high value - like 30 - your memory consumption can jump. You can try the split_stages parameter, which changes a bit how the decoder operates. Long story short, with it the set of intermediate buffers is fixed: it depends only on your batch size and is independent of the number of workers.
The following run, python test_RN50_data_pipeline.py -g 1 -b 1792 --epochs 10 --workers 30 --decoder_type split -j 30, gave me around 7.5 GB after 10 epochs.
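
As a rough illustration (just a sketch - check the nvJPEGDecoder documentation for your DALI version, since this argument is not available in every release), switching the decoder to the split-stage path would look roughly like:

        # hypothetical pipeline snippet: enable split-stage nvJPEG decoding
        self.decode = ops.nvJPEGDecoder(device="mixed",
                                        output_type=types.RGB,
                                        split_stages=True)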

JinyangGuo commented

Hi,
@JanuszL - In your code each GPU has its own pipeline instance to process the images, but it looks like it does not include the training code. How do I construct an iterator that PyTorch can use in this case?
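
For reference, would something along the lines of the first post in this thread work? A rough sketch (config values are placeholders):

# one pipeline per GPU, all wrapped in a single iterator
# (pattern adapted from the first post in this thread; BATCH_SIZE, DATA_DIR
# and NUM_GPUS are placeholders)
pipes = [HybridTrainPipe(batch_size=BATCH_SIZE // NUM_GPUS, num_threads=3,
                         device_id=device_id, data_dir=DATA_DIR, crop=224,
                         num_gpus=NUM_GPUS, dali_cpu=False)
         for device_id in range(NUM_GPUS)]
pipes[0].build()
train_loader = DALIClassificationIterator(pipes, size=int(pipes[0].epoch_size("Reader")))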

JanuszL (Contributor) commented Oct 7, 2019

un-knight commented

I'm hitting the same problem: the GPU memory usage increases step by step over time, and eventually one of my processes crashes with an out-of-memory error under data-parallel distributed training.
My dataset is ImageNet-1k and my models are MobileNetV1 and ResNet50.

Here are some error logs:

Traceback (most recent call last):
  File "distributed_training.py", line 127, in <module>
    main()
  File "distributed_training.py", line 92, in main
    train(train_loader, model, criterion, optimizer, epoch, tb_writer, args, lr_scheduler)
  File "/workspace/image-classification/libs/clf_train.py", line 47, in train
    for batch_idx, data in enumerate(train_loader):
  File "/opt/anaconda/lib/python3.7/site-packages/nvidia/dali/plugin/pytorch.py", line 150, in __next__
    outputs.append(p.share_outputs())
  File "/opt/anaconda/lib/python3.7/site-packages/nvidia/dali/pipeline.py", line 402, in share_outputs
    return self._pipe.ShareOutputs()
RuntimeError: Critical error in pipeline: Error in thread 4: [/opt/dali/dali/pipeline/operators/decoder/nvjpeg/decoupled_api/nvjpeg_decoder_decoupled_api.h:351] NVJPEG error "5" : NVJPEG_STATUS_ALLOCATOR_FAILURE n04418357/n04418357_26036.JPEG
Stacktrace (7 entries):
[frame 0]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(+0xb410e) [0x7fa5388b010e]
[frame 1]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(+0x13b72d) [0x7fa53893772d]
[frame 2]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(+0x13c4ac) [0x7fa5389384ac]
[frame 3]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(dali::ThreadPool::ThreadMain(int, int, bool)+0x1b9) [0x7fa5389d94d9]
[frame 4]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(+0xe205d0) [0x7fa53961c5d0]
[frame 5]: /lib/x86_64-linux-gnu/libpthread.so.0(+0x76db) [0x7fa5935d06db]
[frame 6]: /lib/x86_64-linux-gnu/libc.so.6(clone+0x3f) [0x7fa5932f988f]

Current pipeline object is no longer valid.
An error occurred in nvJPEG worker thread:
Error in thread 4: [/opt/dali/dali/pipeline/operators/decoder/nvjpeg/decoupled_api/nvjpeg_decoder_decoupled_api.h:358] NVJPEG error "6" : NVJPEG_STATUS_EXECUTION_FAILED n03877845/n03877845_719.JPEG
Stacktrace (7 entries):
[frame 0]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(+0xb410e) [0x7fa5388b010e]
[frame 1]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(+0x13b4ac) [0x7fa5389374ac]
[frame 2]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(+0x13c4ac) [0x7fa5389384ac]
[frame 3]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(dali::ThreadPool::ThreadMain(int, int, bool)+0x1b9) [0x7fa5389d94d9]
[frame 4]: /opt/anaconda/lib/python3.7/site-packages/nvidia/dali/libdali.so(+0xe205d0) [0x7fa53961c5d0]
[frame 5]: /lib/x86_64-linux-gnu/libpthread.so.0(+0x76db) [0x7fa5935d06db]
[frame 6]: /lib/x86_64-linux-gnu/libc.so.6(clone+0x3f) [0x7fa5932f988f]

Traceback (most recent call last):
  File "/opt/anaconda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/anaconda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/anaconda/lib/python3.7/site-packages/torch/distributed/launch.py", line 246, in <module>
    main()
  File "/opt/anaconda/lib/python3.7/site-packages/torch/distributed/launch.py", line 242, in main
    cmd=cmd)
subprocess.CalledProcessError: Command '['/opt/anaconda/bin/python', '-u', 'distributed_training.py', '--local_rank=7']' returned non-zero exit status 1.

I think an out-of-memory condition may have caused the NVJPEG_STATUS_ALLOCATOR_FAILURE. It is odd that memory usage keeps increasing as training progresses - doesn't DALI free the data queue memory after each epoch?

JanuszL (Contributor) commented Oct 11, 2019

@un-knight - DALI doesn't free memory after each step/epoch, because allocation on the GPU is very time-consuming. What DALI does is lazy reallocation, only when the currently available memory is not sufficient.
In your case, I would recommend limiting the number of DALI threads per GPU (usually 3 is enough) and trying a smaller batch size.
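
Concretely, something like this (a sketch reusing the HybridTrainPipe class posted earlier in this thread; the values and path are just examples):

# lower the per-GPU thread count and batch size to reduce GPU memory pressure
pipe = HybridTrainPipe(batch_size=128,      # example smaller batch size per GPU
                       num_threads=3,       # usually enough per GPU
                       device_id=0,
                       data_dir=DATA_DIR,   # placeholder path
                       crop=224,
                       dali_cpu=False)
pipe.build()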

un-knight commented

> @un-knight - DALI doesn't free memory after each step/epoch, because allocation on the GPU is very time-consuming. What DALI does is lazy reallocation, only when the currently available memory is not sufficient.
> In your case, I would recommend limiting the number of DALI threads per GPU (usually 3 is enough) and trying a smaller batch size.

Yep, I have tried a smaller num_threads to avoid the out-of-memory error. Thanks for your reply!

JanuszL removed the question (Further information is requested) label on Jan 21, 2020
JanuszL added the enhancement (New feature or request) label and removed the bug (Something isn't working) label on Apr 16, 2021