Merge pull request #226 from bouweandela/add-compiler-cupy-array-support
Add support for passing cupy arrays to "C" lang
benvanwerkhoven authored Nov 27, 2023
2 parents 66428e3 + 303ef3a commit e76b774
Showing 4 changed files with 311 additions and 17 deletions.
4 changes: 3 additions & 1 deletion doc/source/hostcode.rst
@@ -15,6 +15,8 @@ There are few differences with tuning just a single CUDA or OpenCL kernel, to li
* You have to specify the lang="C" option
* The C function should return a ``float``
* You have to do your own timing and error handling in C
+* Data is not automatically copied to and from device memory. To use an array in host memory, pass in a :mod:`numpy` array. To use an array
+  in device memory, pass in a :mod:`cupy` array.

You have to specify the language as "C" because the Kernel Tuner will be calling a host function. This means that the Kernel
Tuner will have to interface with C and in fact uses a different backend. This also means you can use this way of tuning
@@ -94,7 +96,7 @@ compiled C code. This way, you don't have to compute the grid size in C, you can

The filter is not passed separately as a constant memory argument, because the CudaMemcpyToSymbol operation is now performed by the C host function. Also,
because the code is compiled differently, we have no direct reference to the compiled module that is uploaded to the device and therefore we can not perform this
-operation directly from Python. If you are tuning host code, you have to perform all memory allocations, frees, and memcpy operations inside the C host code,
+operation directly from Python. If you are tuning host code, you have the option to perform all memory allocations, frees, and memcpy operations inside the C host code,
that's the purpose of host code after all. That is also why you have to do the timing yourself in C, as you may not want to include the time spent on memory
allocations and other setup into your time measurements.
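
An editorial sketch, not part of this commit, of what the new bullet above means in practice. The host function name "vector_add_host" and source file "vector_add.cu" are hypothetical; the host function is assumed to do its own timing and return a float, as described above. The cupy arrays are handed to the "C" backend as device pointers, while numpy scalars are passed by value:

import cupy as cp
import numpy
import kernel_tuner

size = numpy.int32(1000000)
a = cp.random.randn(int(size)).astype(cp.float32)  # device memory (cupy)
b = cp.random.randn(int(size)).astype(cp.float32)  # device memory (cupy)
c = cp.zeros_like(a)                               # device memory (cupy)
args = [c, a, b, size]

tune_params = {"block_size_x": [128, 256, 512]}

# lang="C" makes the Kernel Tuner compile and call the C host function,
# which returns the measured time as a float
results = kernel_tuner.tune_kernel("vector_add_host", "vector_add.cu",
                                   int(size), args, tune_params, lang="C")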

90 changes: 90 additions & 0 deletions examples/cuda/pnpoly_cupy.py
@@ -0,0 +1,90 @@
#!/usr/bin/env python
""" Point-in-Polygon host/device code tuner
This program is used for auto-tuning the host and device code of a CUDA program
for computing the point-in-polygon problem for very large datasets and large
polygons.
The time measurements used as a basis for tuning include the time spent on
data transfers between host and device memory. The host code uses device mapped
host memory to overlap communication between host and device with kernel
execution on the GPU. Because each input is read only once and each output
is written only once, this implementation almost fully overlaps all
communication and the kernel execution time dominates the total execution time.
The code has the option to precompute all polygon line slopes on the CPU and
reuse those results on the GPU, instead of recomputing them on the GPU all
the time. The time spent on precomputing these values on the CPU is also
taken into account by the time measurement in the code.
This code was written for use with the Kernel Tuner. See:
https://github.com/benvanwerkhoven/kernel_tuner
Author: Ben van Werkhoven <[email protected]>
"""
from collections import OrderedDict
import json
import logging

import cupy as cp
import cupyx as cpx
import kernel_tuner
import numpy


def allocator(size: int) -> cp.cuda.PinnedMemoryPointer:
"""Allocate context-portable device mapped host memory."""
flags = cp.cuda.runtime.hostAllocPortable | cp.cuda.runtime.hostAllocMapped
mem = cp.cuda.PinnedMemory(size, flags=flags)
return cp.cuda.PinnedMemoryPointer(mem, offset=0)


def tune():

#set the number of points and the number of vertices
size = numpy.int32(2e7)
problem_size = (size, 1)
vertices = 600

#allocate context-portable device mapped host memory
cp.cuda.set_pinned_memory_allocator(allocator)

#generate input data
points = cpx.empty_pinned(shape=(2*size,), dtype=numpy.float32)
points[:] = numpy.random.randn(2*size).astype(numpy.float32)

bitmap = cpx.zeros_pinned(shape=(size,), dtype=numpy.int32)
#as test input we use a circle with radius 1 as polygon and
#a large set of normally distributed points around 0,0
vertex_seeds = numpy.sort(numpy.random.rand(vertices)*2.0*numpy.pi)[::-1]
vertex_x = numpy.cos(vertex_seeds)
vertex_y = numpy.sin(vertex_seeds)
vertex_xy = cpx.empty_pinned(shape=(2*vertices,), dtype=numpy.float32)
vertex_xy[:] = numpy.array( list(zip(vertex_x, vertex_y)) ).astype(numpy.float32).ravel()

#kernel arguments
args = [bitmap, points, vertex_xy, size]

#setup tunable parameters
tune_params = OrderedDict()
tune_params["block_size_x"] = [32*i for i in range(1,32)] #multiple of 32
tune_params["tile_size"] = [1] + [2*i for i in range(1,11)]
tune_params["between_method"] = [0, 1, 2, 3]
tune_params["use_precomputed_slopes"] = [0, 1]
tune_params["use_method"] = [0, 1]

#tell the Kernel Tuner how to compute the grid dimensions from the problem_size
grid_div_x = ["block_size_x", "tile_size"]

#start tuning
results = kernel_tuner.tune_kernel("cn_pnpoly_host", ['pnpoly_host.cu', 'pnpoly.cu'],
problem_size, args, tune_params,
grid_div_x=grid_div_x, lang="C", compiler_options=["-arch=sm_52"], verbose=True, log=logging.DEBUG)

return results


if __name__ == "__main__":
results = tune()
with open("pnpoly.json", 'w') as fp:
json.dump(results, fp)
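
For contrast, a hedged sketch, not part of this commit, of building the same argument list with the inputs kept in device memory via the new cupy support, instead of device-mapped pinned host memory. Whether pnpoly_host.cu accepts plain device pointers depends on that host code; this only shows how the arguments would be prepared:

import cupy as cp
import numpy

size = numpy.int32(2e7)
vertices = 600

# same inputs as above, but uploaded to device memory with cupy
points = cp.asarray(numpy.random.randn(2 * size).astype(numpy.float32))
bitmap = cp.zeros(int(size), dtype=cp.int32)
vertex_seeds = numpy.sort(numpy.random.rand(vertices) * 2.0 * numpy.pi)[::-1]
vertex_xy = cp.asarray(numpy.array(list(zip(numpy.cos(vertex_seeds),
                       numpy.sin(vertex_seeds)))).astype(numpy.float32).ravel())
args = [bitmap, points, vertex_xy, size]  # cupy arrays arrive as device pointers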
74 changes: 62 additions & 12 deletions kernel_tuner/backends/compiler.py
@@ -21,6 +21,39 @@
SkippableFailure,
)

try:
import cupy as cp
except ImportError:
cp = None


def is_cupy_array(array):
"""Check if something is a cupy array.
:param array: A Python object.
:type array: typing.Any
:returns: True if cupy can be imported and the object is a cupy.ndarray.
:rtype: bool
"""
return cp is not None and isinstance(array, cp.ndarray)


def get_array_module(*args):
"""Return the array module for arguments.
This function is used to implement CPU/GPU generic code. If the cupy module can be imported
and at least one of the arguments is a cupy.ndarray object, the cupy module is returned.
:param args: Values to determine whether NumPy or CuPy should be used.
:type args: numpy.ndarray or cupy.ndarray
:returns: cupy or numpy is returned based on the types of the arguments.
:rtype: types.ModuleType
"""
return np if cp is None else cp.get_array_module(*args)
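
A quick editorial illustration, not part of the commit, of how these two helpers behave, assuming cupy is installed:

import numpy as np
import cupy as cp

host = np.zeros(4, dtype=np.float32)
dev = cp.zeros(4, dtype=cp.float32)

assert not is_cupy_array(host)
assert is_cupy_array(dev)
assert get_array_module(host) is np
assert get_array_module(dev) is cp
assert get_array_module(host, dev) is cp  # one cupy argument is enough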


dtype_map = {
"int8": C.c_int8,
"int16": C.c_int16,
@@ -103,18 +136,18 @@ def ready_argument_list(self, arguments):
:param arguments: List of arguments to be passed to the C function.
The order should match the argument list on the C function.
-Allowed values are np.ndarray, and/or np.int32, np.float32, and so on.
-:type arguments: list(numpy objects)
+Allowed values are np.ndarray, cupy.ndarray, and/or np.int32, np.float32, and so on.
+:type arguments: list(numpy or cupy objects)
:returns: A list of arguments that can be passed to the C function.
:rtype: list(Argument)
"""
ctype_args = [None for _ in arguments]

for i, arg in enumerate(arguments):
-if not isinstance(arg, (np.ndarray, np.number)):
+if not (isinstance(arg, (np.ndarray, np.number)) or is_cupy_array(arg)):
raise TypeError(
"Argument is not numpy ndarray or numpy scalar %s" % type(arg)
f"Argument is not numpy or cupy ndarray or numpy scalar but a {type(arg)}"
)
dtype_str = str(arg.dtype)
if isinstance(arg, np.ndarray):
@@ -129,6 +162,8 @@ def ready_argument_list(self, arguments):
raise TypeError("unknown dtype for ndarray")
elif isinstance(arg, np.generic):
data_ctypes = dtype_map[dtype_str](arg)
+elif is_cupy_array(arg):
+    data_ctypes = C.c_void_p(arg.data.ptr)
ctype_args[i] = Argument(numpy=arg, ctypes=data_ctypes)
return ctype_args
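
Editorial note, not part of the commit: a cupy argument is not copied to the host here; only its raw device pointer is wrapped, so the compiled C host function receives an address it can hand directly to kernel launches or CUDA runtime calls. A minimal illustration:

import ctypes as C
import cupy as cp

a = cp.arange(4, dtype=cp.float32)
ptr = C.c_void_p(a.data.ptr)  # device address, usable as a float* in C host code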

@@ -326,29 +361,44 @@ def memset(self, allocation, value, size):
:param size: The size of to the allocation unit in bytes
:type size: int
"""
-C.memset(allocation.ctypes, value, size)
+if is_cupy_array(allocation.numpy):
+    cp.cuda.runtime.memset(allocation.numpy.data.ptr, value, size)
+else:
+    C.memset(allocation.ctypes, value, size)

def memcpy_dtoh(self, dest, src):
"""a simple memcpy copying from an Argument to a numpy array
-:param dest: A numpy array to store the data
-:type dest: np.ndarray
+:param dest: A numpy or cupy array to store the data
+:type dest: np.ndarray or cupy.ndarray
:param src: An Argument for some memory allocation
:type src: Argument
"""
-dest[:] = src.numpy
+if isinstance(dest, np.ndarray) and is_cupy_array(src.numpy):
+    # Implicit conversion to a NumPy array is not allowed.
+    value = src.numpy.get()
+else:
+    value = src.numpy
+xp = get_array_module(dest)
+dest[:] = xp.asarray(value)

def memcpy_htod(self, dest, src):
"""a simple memcpy copying from a numpy array to an Argument
:param dest: An Argument for some memory allocation
-:type dst: Argument
+:type dest: Argument
-:param src: A numpy array containing the source data
-:type src: np.ndarray
+:param src: A numpy or cupy array containing the source data
+:type src: np.ndarray or cupy.ndarray
"""
-dest.numpy[:] = src
+if isinstance(dest.numpy, np.ndarray) and is_cupy_array(src):
+    # Implicit conversion to a NumPy array is not allowed.
+    value = src.get()
+else:
+    value = src
+xp = get_array_module(dest.numpy)
+dest.numpy[:] = xp.asarray(value)
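
An editorial sketch, not part of the commit, of the cupy operations the patched memset, memcpy_dtoh, and memcpy_htod rely on, assuming cupy is installed:

import numpy as np
import cupy as cp

host = np.empty(4, dtype=np.float32)
dev = cp.arange(4, dtype=cp.float32)

# device -> host needs an explicit .get(); implicit cupy-to-numpy
# conversion raises a TypeError, which is why memcpy_dtoh special-cases it
host[:] = dev.get()

# host -> device: cp.asarray uploads the numpy data to the device
dev[:] = cp.asarray(host)

# device-side memset on the raw pointer, as in the patched memset()
cp.cuda.runtime.memset(dev.data.ptr, 0, dev.nbytes)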

def cleanup_lib(self):
"""unload the previously loaded shared library"""
