Improve Cython Lifetime Management by Adding References in `DeviceBuffer` (#661)

As discussed with @shwina, @harrism, and @kkraus14, this PR adds 2 properties to `DeviceBuffer` to allow for automatic reference counting of the `MemoryResource` and `Stream` objects used at allocation time. This prevents any `MemoryResource` from being destroyed while a `DeviceBuffer` that needs the MR for deallocation is still alive.
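For context, here is a minimal sketch of the failure mode this guards against. It mirrors the new `test_mr_devicebuffer_lifetime` test added below and assumes only the public RMM Python API:

```python
import gc

import rmm

# Make a pool MR current and allocate from it.
rmm.mr.set_current_device_resource(
    rmm.mr.PoolMemoryResource(rmm.mr.get_current_device_resource())
)
buf = rmm.DeviceBuffer(size=10)

# Swap in a different MR. The only reference left to the pool is the one
# `buf` now holds internally.
rmm.mr.set_current_device_resource(rmm.mr.CudaMemoryResource())
gc.collect()

# Freeing `buf` calls back into the pool MR. Previously nothing kept the
# pool alive at this point, which could crash; with the saved reference
# the deallocation is safe.
del buf
```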

There are a few outstanding issues I could use input on:

1. The test `test_rmm_device_buffer` is failing on the line `sys.getsizeof(b) == b.size`. I need input on the best way forward (see the sketch after this list).
   1. This test fails because `DeviceBuffer` is now tracked by the garbage collector. `sys.getsizeof` automatically adds the GC header overhead to the value returned by `__sizeof__` (see [here](https://github.com/python/cpython/blob/master/Python/sysmodule.c#L1701)), which makes it impossible for the check to keep working the way it did before.
   1. The only options I can think of are:
      1. Remove this check from the test or alter the "correct" value.
      1. Add `@cython.no_gc`, which is very risky.
1. The current PR implementation includes CUDA stream object reference counting but treats all `Stream` objects the same. @harrism mentioned that only streams owned by RMM should be tracked this way, but I am not sure whether that's necessary or how to distinguish them at this point.
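To illustrate the `sys.getsizeof` issue in item 1: once a type participates in GC, `sys.getsizeof` reports `__sizeof__()` plus the per-object GC header, so the old equality can no longer hold. A plain-Python sketch, independent of RMM:

```python
import sys

class Tracked:
    # An ordinary class; CPython tracks its instances for cyclic GC.
    pass

t = Tracked()
raw = t.__sizeof__()      # size the object reports for itself
total = sys.getsizeof(t)  # raw size plus the GC header
assert total > raw        # the difference is the GC tracking overhead
```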

Other than the above items, all tests are passing, and I ran this through the cuML test suite without any issues. Thanks for your help.

Authors:
  - Michael Demoret (@mdemoret-nv)

Approvers:
  - Keith Kraus (@kkraus14)

URL: #661
mdemoret-nv authored Feb 25, 2021
1 parent dc5889b commit 230369d
Showing 5 changed files with 58 additions and 11 deletions.
10 changes: 10 additions & 0 deletions python/rmm/_lib/device_buffer.pxd
@@ -17,6 +17,7 @@ from libc.stdint cimport uintptr_t
 
 from rmm._lib.cuda_stream_view cimport cuda_stream_view
 from rmm._cuda.stream cimport Stream
+from rmm._lib.memory_resource cimport MemoryResource
 
 
 cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
@@ -38,6 +39,15 @@ cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
 cdef class DeviceBuffer:
     cdef unique_ptr[device_buffer] c_obj
 
+    # Holds a reference to the MemoryResource used for allocation. Ensures the
+    # MR does not get destroyed before this DeviceBuffer. `mr` is needed for
+    # deallocation
+    cdef MemoryResource mr
+
+    # Holds a reference to the stream used by the underlying `device_buffer`.
+    # Ensures the stream does not get destroyed before this DeviceBuffer
+    cdef Stream stream
+
     @staticmethod
     cdef DeviceBuffer c_from_unique_ptr(unique_ptr[device_buffer] ptr)
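The two new attributes apply a standard lifetime-pinning pattern: the consuming object holds a Python reference to its resource so the resource cannot be destroyed first. A plain-Python sketch of the idea, with hypothetical `Resource`/`Consumer` names:

```python
class Resource:
    def __del__(self):
        print("resource destroyed")


class Consumer:
    def __init__(self, resource):
        # Holding this reference pins `resource` for our whole lifetime,
        # just as `DeviceBuffer.mr` and `DeviceBuffer.stream` do.
        self._resource = resource


current = Resource()
c = Consumer(current)
current = Resource()  # rebinding drops only the *external* reference
# The first Resource is still alive through c._resource; it can be
# destroyed only after `c` itself goes away.
del c
```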
7 changes: 7 additions & 0 deletions python/rmm/_lib/device_buffer.pyx
@@ -22,7 +22,9 @@ from libcpp.utility cimport move
 
 from rmm._cuda.gpu cimport cudaError, cudaError_t
 from rmm._cuda.stream cimport Stream
+
 from rmm._cuda.stream import DEFAULT_STREAM
+
 from rmm._lib.lib cimport (
     cudaMemcpyAsync,
     cudaMemcpyDeviceToDevice,
@@ -32,6 +34,7 @@ from rmm._lib.lib cimport (
     cudaStream_t,
     cudaStreamSynchronize,
 )
+from rmm._lib.memory_resource cimport get_current_device_resource
 
 
 cdef class DeviceBuffer:
@@ -81,6 +84,10 @@ cdef class DeviceBuffer:
         if stream.c_is_default():
             stream.c_synchronize()
 
+        # Save a reference to the MR and stream used for allocation
+        self.mr = get_current_device_resource()
+        self.stream = stream
+
     def __len__(self):
         return self.size
 
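From the user's side the saved references are invisible; a buffer allocated on a user-created stream simply keeps that stream (and the current MR) alive for as long as it exists. A short usage sketch, using the same `rmm._cuda.stream.Stream` constructor the new test exercises:

```python
import rmm
import rmm._cuda.stream

stream = rmm._cuda.stream.Stream()  # a new non-default stream
buf = rmm.DeviceBuffer(size=10, stream=stream)

# Dropping our handle is safe: `buf` still references the stream object,
# so it is not destroyed while `buf` might need it for deallocation.
del stream
del buf
```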
2 changes: 2 additions & 0 deletions python/rmm/_lib/memory_resource.pxd
@@ -104,3 +104,5 @@ cdef class LoggingResourceAdaptor(MemoryResource):
     cpdef MemoryResource get_upstream(self)
     cpdef get_file_name(self)
     cpdef flush(self)
+
+cpdef MemoryResource get_current_device_resource()
2 changes: 1 addition & 1 deletion python/rmm/_lib/memory_resource.pyx
@@ -437,7 +437,7 @@ cpdef get_per_device_resource_type(int device):
     return type(get_per_device_resource(device))
 
 
-cpdef get_current_device_resource():
+cpdef MemoryResource get_current_device_resource():
     """
     Get the memory resource used for RMM device allocations on the current
     device.
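Typing the `cpdef` return as `MemoryResource` lets Cython callers, such as `DeviceBuffer.__cinit__` above, assign the result to a typed attribute without a cast, while the Python-level call is unchanged. A quick check, assuming `rmm.mr` re-exports `MemoryResource` as the Python package does:

```python
import rmm

mr = rmm.mr.get_current_device_resource()
assert isinstance(mr, rmm.mr.MemoryResource)
```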
48 changes: 38 additions & 10 deletions python/rmm/tests/test_rmm.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2020, NVIDIA CORPORATION.
+import gc
 import os
 import sys
 from itertools import product
@@ -8,6 +9,7 @@
 from numba import cuda
 
 import rmm
+import rmm._cuda.stream
 
 if sys.version_info < (3, 8):
     try:
@@ -20,6 +22,17 @@
 cuda.set_memory_manager(rmm.RMMNumbaManager)
 
 
+@pytest.fixture(scope="function", autouse=True)
+def rmm_auto_reinitialize():
+
+    # Run the test
+    yield
+
+    # Automatically reinitialize the current memory resource after running each
+    # test
+    rmm.reinitialize()
+
+
 def array_tester(dtype, nelem, alloc):
     # data
     h_in = np.full(nelem, 3.2, dtype)
@@ -70,7 +83,6 @@ def test_rmm_modes(dtype, nelem, alloc, managed, pool):
     assert rmm.is_initialized()
 
     array_tester(dtype, nelem, alloc)
-    rmm.reinitialize()
 
 
 @pytest.mark.parametrize("dtype", _dtypes)
@@ -92,7 +104,6 @@ def test_rmm_csv_log(dtype, nelem, alloc, tmpdir):
         assert csv.find(b"Time,Action,Pointer,Size,Stream") >= 0
     finally:
         os.remove(fname)
-        rmm.reinitialize()
 
 
 @pytest.mark.parametrize("size", [0, 5])
@@ -109,7 +120,7 @@ def test_rmm_device_buffer(size):
     assert len(b) == b.size
     assert b.nbytes == b.size
     assert b.capacity() >= b.size
-    assert sys.getsizeof(b) == b.size
+    assert b.__sizeof__() == b.size
 
     # Test `__cuda_array_interface__`
     keyset = {"data", "shape", "strides", "typestr", "version"}
@@ -299,7 +310,6 @@ def test_pool_memory_resource(dtype, nelem, alloc):
     rmm.mr.set_current_device_resource(mr)
     assert rmm.mr.get_current_device_resource_type() is type(mr)
     array_tester(dtype, nelem, alloc)
-    rmm.reinitialize()
 
 
 @pytest.mark.parametrize("dtype", _dtypes)
@@ -319,7 +329,6 @@ def test_fixed_size_memory_resource(dtype, nelem, alloc, upstream):
     rmm.mr.set_current_device_resource(mr)
     assert rmm.mr.get_current_device_resource_type() is type(mr)
     array_tester(dtype, nelem, alloc)
-    rmm.reinitialize()
 
 
 @pytest.mark.parametrize("dtype", _dtypes)
@@ -350,15 +359,13 @@ def test_binning_memory_resource(dtype, nelem, alloc, upstream_mr):
     rmm.mr.set_current_device_resource(mr)
     assert rmm.mr.get_current_device_resource_type() is type(mr)
     array_tester(dtype, nelem, alloc)
-    rmm.reinitialize()
 
 
 def test_reinitialize_max_pool_size():
     rmm.reinitialize(
         pool_allocator=True, initial_pool_size=0, maximum_pool_size=1 << 23
     )
     rmm.DeviceBuffer().resize((1 << 23) - 1)
-    rmm.reinitialize()
 
 
 def test_reinitialize_max_pool_size_exceeded():
@@ -367,7 +374,6 @@ def test_reinitialize_max_pool_size_exceeded():
     )
     with pytest.raises(MemoryError):
         rmm.DeviceBuffer().resize(1 << 24)
-    rmm.reinitialize()
 
 
 def test_reinitialize_initial_pool_size_gt_max():
@@ -378,7 +384,30 @@ def test_reinitialize_initial_pool_size_gt_max():
         maximum_pool_size=1 << 10,
     )
     assert "Initial pool size exceeds the maximum pool size" in str(e.value)
-    rmm.reinitialize()
 
 
+def test_mr_devicebuffer_lifetime():
+    # Test ensures MR/Stream lifetime is longer than DeviceBuffer. Even if all
+    # references go out of scope
+    # Create new Pool MR
+    rmm.mr.set_current_device_resource(
+        rmm.mr.PoolMemoryResource(rmm.mr.get_current_device_resource())
+    )
+
+    # Creates a new non-default stream
+    stream = rmm._cuda.stream.Stream()
+
+    # Allocate DeviceBuffer with Pool and Stream
+    a = rmm.DeviceBuffer(size=10, stream=stream)
+
+    # Change current MR. Will cause Pool to go out of scope
+    rmm.mr.set_current_device_resource(rmm.mr.CudaMemoryResource())
+
+    # Force collection to ensure objects are cleaned up
+    gc.collect()
+
+    # Delete a. Used to crash before. Pool MR should still be alive
+    del a
+
+
 @pytest.mark.parametrize("dtype", _dtypes)
@@ -404,4 +433,3 @@ def test_rmm_enable_disable_logging(dtype, nelem, alloc, tmpdir):
         os.remove(fname)
 
     rmm.disable_logging()
-    rmm.reinitialize()
