diff --git a/doc/source/hostcode.rst b/doc/source/hostcode.rst
index 419c88ae3..cb8940465 100644
--- a/doc/source/hostcode.rst
+++ b/doc/source/hostcode.rst
@@ -15,6 +15,8 @@ There are few differences with tuning just a single CUDA or OpenCL kernel, to li
  * You have to specify the lang="C" option
  * The C function should return a ``float``
  * You have to do your own timing and error handling in C
+ * Data is not automatically copied to and from device memory. To use an array in host memory, pass in a :mod:`numpy` array. To use an array
+   in device memory, pass in a :mod:`cupy` array.
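+
+For example, a minimal sketch (``vector_add`` and ``vector_add_host.cu`` are illustrative names; the C host function is expected to launch the kernel itself and to copy any host-resident arrays to and from the device)::
+
+    import cupy as cp
+    import numpy as np
+    import kernel_tuner
+
+    n = np.int32(1024)
+    a = np.random.randn(n).astype(np.float32)  # host memory, managed by the C host code
+    b = cp.asarray(a)                          # device memory, passed as a raw device pointer
+    c = cp.zeros(n, dtype=np.float32)          # device memory
+
+    kernel_tuner.run_kernel("vector_add", "vector_add_host.cu", n,
+                            [c, a, b, n], {"block_size_x": 128}, lang="C")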
 
 You have to specify the language as "C" because the Kernel Tuner will be calling a host function. This means that the Kernel
 Tuner will have to interface with C and in fact uses a different backend. This also means you can use this way of tuning
@@ -94,7 +96,7 @@ compiled C code. This way, you don't have to compute the grid size in C, you can
 
 The filter is not passed separately as a constant memory argument, because the CudaMemcpyToSymbol operation is now performed by the C host function. Also, 
 because the code is compiled differently, we have no direct reference to the compiled module that is uploaded to the device and therefore we can not perform this 
-operation directly from Python. If you are tuning host code, you have to perform all memory allocations, frees, and memcpy operations inside the C host code, 
+operation directly from Python. If you are tuning host code, you have the option to perform all memory allocations, frees, and memcpy operations inside the C host code;
 that's the purpose of host code after all. That is also why you have to do the timing yourself in C, as you may not want to include the time spent on memory 
 allocations and other setup into your time measurements.
 
diff --git a/examples/cuda/pnpoly_cupy.py b/examples/cuda/pnpoly_cupy.py
new file mode 100755
index 000000000..e24d8523a
--- /dev/null
+++ b/examples/cuda/pnpoly_cupy.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+""" Point-in-Polygon host/device code tuner
+
+This program is used for auto-tuning the host and device code of a CUDA program
+for computing the point-in-polygon problem for very large datasets and large
+polygons.
+
+The time measurements used as a basis for tuning include the time spent on
+data transfers between host and device memory. The host code uses device mapped
+host memory to overlap communication between host and device with kernel
+execution on the GPU. Because each input is read only once and each output
+is written only once, this implementation almost fully overlaps all
+communication and the kernel execution time dominates the total execution time.
+
+The code has the option to precompute all polygon line slopes on the CPU and
+reuse those results on the GPU, instead of recomputing them on the GPU all
+the time. The time spent on precomputing these values on the CPU is also
+taken into account by the time measurement in the code.
+
+This code was written for use with the Kernel Tuner. See:
+     https://github.com/benvanwerkhoven/kernel_tuner
+
+Author: Ben van Werkhoven <b.vanwerkhoven@esciencecenter.nl>
+"""
+from collections import OrderedDict
+import json
+import logging
+
+import cupy as cp
+import cupyx as cpx
+import kernel_tuner
+import numpy
+
+
+def allocator(size: int) -> cp.cuda.PinnedMemoryPointer:
+    """Allocate context-portable device mapped host memory."""
+    flags = cp.cuda.runtime.hostAllocPortable | cp.cuda.runtime.hostAllocMapped
+    mem = cp.cuda.PinnedMemory(size, flags=flags)
+    return cp.cuda.PinnedMemoryPointer(mem, offset=0)
+
+
+def tune():
+
+    #set the number of points and the number of vertices
+    size = numpy.int32(2e7)
+    problem_size = (size, 1)
+    vertices = 600
+
+    #allocate context-portable device mapped host memory
+    cp.cuda.set_pinned_memory_allocator(allocator)
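+    #the cupyx *_pinned allocations below are served by this allocator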
+
+    #generate input data
+    points = cpx.empty_pinned(shape=(2*size,), dtype=numpy.float32)
+    points[:] = numpy.random.randn(2*size).astype(numpy.float32)
+
+    bitmap = cpx.zeros_pinned(shape=(size,), dtype=numpy.int32)
+    #as test input we use a circle with radius 1 as polygon and
+    #a large set of normally distributed points around 0,0
+    vertex_seeds = numpy.sort(numpy.random.rand(vertices)*2.0*numpy.pi)[::-1]
+    vertex_x = numpy.cos(vertex_seeds)
+    vertex_y = numpy.sin(vertex_seeds)
+    vertex_xy = cpx.empty_pinned(shape=(2*vertices,), dtype=numpy.float32)
+    vertex_xy[:] = numpy.array(list(zip(vertex_x, vertex_y))).astype(numpy.float32).ravel()
+
+    #kernel arguments
+    args = [bitmap, points, vertex_xy, size]
+
+    #setup tunable parameters
+    tune_params = OrderedDict()
+    tune_params["block_size_x"] = [32*i for i in range(1,32)]  #multiple of 32
+    tune_params["tile_size"] = [1] + [2*i for i in range(1,11)]
+    tune_params["between_method"] = [0, 1, 2, 3]
+    tune_params["use_precomputed_slopes"] = [0, 1]
+    tune_params["use_method"] = [0, 1]
+
+    #tell the Kernel Tuner how to compute the grid dimensions from the problem_size
+    grid_div_x = ["block_size_x", "tile_size"]
+
+    #start tuning
+    results = kernel_tuner.tune_kernel("cn_pnpoly_host", ['pnpoly_host.cu', 'pnpoly.cu'],
+        problem_size, args, tune_params,
+        grid_div_x=grid_div_x, lang="C", compiler_options=["-arch=sm_52"], verbose=True, log=logging.DEBUG)
+
+    return results
+
+
+if __name__ == "__main__":
+    results = tune()
+    with open("pnpoly.json", 'w') as fp:
+        json.dump(results, fp)
diff --git a/kernel_tuner/backends/compiler.py b/kernel_tuner/backends/compiler.py
index a7e15c577..2cccae523 100644
--- a/kernel_tuner/backends/compiler.py
+++ b/kernel_tuner/backends/compiler.py
@@ -21,6 +21,39 @@
     SkippableFailure,
 )
 
+try:
+    import cupy as cp
+except ImportError:
+    cp = None
+
+
+def is_cupy_array(array):
+    """Check if something is a cupy array.
+
+    :param array: A Python object.
+    :type array: typing.Any
+
+    :returns: True if cupy can be imported and the object is a cupy.ndarray.
+    :rtype: bool
+    """
+    return cp is not None and isinstance(array, cp.ndarray)
+
+
+def get_array_module(*args):
+    """Return the array module for arguments.
+
+    This function is used to implement CPU/GPU generic code. If the cupy module can be imported
+    and at least one of the arguments is a cupy.ndarray object, the cupy module is returned.
+
+    :param args: Values to determine whether NumPy or CuPy should be used.
+    :type args: numpy.ndarray or cupy.ndarray
+
+    :returns: cupy or numpy is returned based on the types of the arguments.
+    :rtype: types.ModuleType
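+
+    A sketch of the intended usage pattern::
+
+        xp = get_array_module(dest)  # numpy or cupy, matching dest
+        dest[:] = xp.asarray(value)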
+    """
+    return np if cp is None else cp.get_array_module(*args)
+
+
 dtype_map = {
     "int8": C.c_int8,
     "int16": C.c_int16,
@@ -103,8 +136,8 @@ def ready_argument_list(self, arguments):
 
         :param arguments: List of arguments to be passed to the C function.
             The order should match the argument list on the C function.
-            Allowed values are np.ndarray, and/or np.int32, np.float32, and so on.
-        :type arguments: list(numpy objects)
+            Allowed values are np.ndarray, cupy.ndarray, and/or np.int32, np.float32, and so on.
+        :type arguments: list(numpy or cupy objects)
 
         :returns: A list of arguments that can be passed to the C function.
         :rtype: list(Argument)
@@ -112,9 +145,9 @@ def ready_argument_list(self, arguments):
         ctype_args = [None for _ in arguments]
 
         for i, arg in enumerate(arguments):
-            if not isinstance(arg, (np.ndarray, np.number)):
+            if not (isinstance(arg, (np.ndarray, np.number)) or is_cupy_array(arg)):
                 raise TypeError(
-                    "Argument is not numpy ndarray or numpy scalar %s" % type(arg)
+                    f"Argument is not numpy or cupy ndarray or numpy scalar but a {type(arg)}"
                 )
             dtype_str = str(arg.dtype)
             if isinstance(arg, np.ndarray):
@@ -129,6 +162,8 @@ def ready_argument_list(self, arguments):
                     raise TypeError("unknown dtype for ndarray")
             elif isinstance(arg, np.generic):
                 data_ctypes = dtype_map[dtype_str](arg)
+            elif is_cupy_array(arg):
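+                # cupy arrays stay in device memory; hand the C host code the
+                # raw device pointer as a void*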
+                data_ctypes = C.c_void_p(arg.data.ptr)
             ctype_args[i] = Argument(numpy=arg, ctypes=data_ctypes)
         return ctype_args
 
@@ -326,29 +361,44 @@ def memset(self, allocation, value, size):
-        :param size: The size of to the allocation unit in bytes
+        :param size: The size of the allocation in bytes
         :type size: int
         """
-        C.memset(allocation.ctypes, value, size)
+        if is_cupy_array(allocation.numpy):
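+            # device allocation: clear it through the CUDA runtime instead of libc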
+            cp.cuda.runtime.memset(allocation.numpy.data.ptr, value, size)
+        else:
+            C.memset(allocation.ctypes, value, size)
 
     def memcpy_dtoh(self, dest, src):
         """a simple memcpy copying from an Argument to a numpy array
 
-        :param dest: A numpy array to store the data
-        :type dest: np.ndarray
+        :param dest: A numpy or cupy array to store the data
+        :type dest: np.ndarray or cupy.ndarray
 
         :param src: An Argument for some memory allocation
         :type src: Argument
         """
-        dest[:] = src.numpy
+        if isinstance(dest, np.ndarray) and is_cupy_array(src.numpy):
+            # Implicit conversion to a NumPy array is not allowed.
+            value = src.numpy.get()
+        else:
+            value = src.numpy
+        xp = get_array_module(dest)
+        dest[:] = xp.asarray(value)
 
     def memcpy_htod(self, dest, src):
         """a simple memcpy copying from a numpy array to an Argument
 
         :param dest: An Argument for some memory allocation
-        :type dst: Argument
+        :type dest: Argument
 
-        :param src: A numpy array containing the source data
-        :type src: np.ndarray
+        :param src: A numpy or cupy array containing the source data
+        :type src: np.ndarray or cupy.ndarray
         """
-        dest.numpy[:] = src
+        if isinstance(dest.numpy, np.ndarray) and is_cupy_array(src):
+            # Implicit conversion to a NumPy array is not allowed.
+            value = src.get()
+        else:
+            value = src
+        xp = get_array_module(dest.numpy)
+        dest.numpy[:] = xp.asarray(value)
 
     def cleanup_lib(self):
         """unload the previously loaded shared library"""
diff --git a/test/test_compiler_functions.py b/test/test_compiler_functions.py
index 475719fb0..5060e2203 100644
--- a/test/test_compiler_functions.py
+++ b/test/test_compiler_functions.py
@@ -11,11 +11,12 @@
     from unittest.mock import patch, Mock
 
 import kernel_tuner
-from kernel_tuner.backends.compiler import CompilerFunctions, Argument
+from kernel_tuner.backends.compiler import CompilerFunctions, Argument, is_cupy_array, get_array_module
 from kernel_tuner.core import KernelSource, KernelInstance
 from kernel_tuner import util
 
-from .context import skip_if_no_gfortran, skip_if_no_gcc, skip_if_no_openmp
+from .context import skip_if_no_gfortran, skip_if_no_gcc, skip_if_no_openmp, skip_if_no_cupy
+from .test_runners import env as cuda_env  # noqa: F401
 
 
 @skip_if_no_gcc
@@ -108,6 +109,29 @@ def test_ready_argument_list5():
     assert all(output[0].numpy == arg1)
 
 
+@skip_if_no_cupy
+def test_ready_argument_list6():
+    import cupy as cp
+
+    arg = cp.array([1, 2, 3], dtype=np.float32)
+    arguments = [arg]
+
+    cfunc = CompilerFunctions()
+    output = cfunc.ready_argument_list(arguments)
+    print(output)
+
+    assert len(output) == 1
+    assert output[0].numpy is arg
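+    # wrap the raw pointer stored in the ctypes argument in a cupy view to
+    # verify it addresses the original device buffer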
+    mem = cp.cuda.UnownedMemory(
+        ptr=output[0].ctypes.value,
+        size=arg.nbytes,
+        owner=None,
+    )
+    ptr = cp.cuda.MemoryPointer(mem, 0)
+    output_arg = cp.ndarray(shape=arg.shape, dtype=arg.dtype, memptr=ptr)
+    assert cp.all(output_arg == arg)
+
+
 @skip_if_no_gcc
 def test_byte_array_arguments():
     arg1 = np.array([1, 2, 3]).astype(np.int8)
@@ -206,8 +230,29 @@ def test_memset():
     assert all(x == np.zeros(4))
 
 
-@skip_if_no_gcc
+@skip_if_no_cupy
 def test_memcpy_dtoh():
+    import cupy as cp
+
+    a = [1, 2, 3, 4]
+    x = cp.asarray(a, dtype=np.float32)
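+    # build the Argument by hand, mirroring what ready_argument_list
+    # produces for a cupy array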
+    x_c = C.c_void_p(x.data.ptr)
+    arg = Argument(numpy=x, ctypes=x_c)
+    output = np.zeros(len(x), dtype=x.dtype)
+
+    cfunc = CompilerFunctions()
+    cfunc.memcpy_dtoh(output, arg)
+
+    print(f"{type(x)=} {x=}")
+    print(f"{type(a)=} {a=}")
+    print(f"{type(output)=} {output=}")
+
+    assert all(output == a)
+    assert all(x.get() == a)
+
+
+@skip_if_no_gcc
+def test_memcpy_host_dtoh():
     a = [1, 2, 3, 4]
     x = np.array(a).astype(np.float32)
     x_c = x.ctypes.data_as(C.POINTER(C.c_float))
@@ -224,8 +269,44 @@ def test_memcpy_dtoh():
     assert all(x == a)
 
 
-@skip_if_no_gcc
+@skip_if_no_cupy
+def test_memcpy_device_dtoh():
+    import cupy as cp
+
+    a = [1, 2, 3, 4]
+    x = cp.asarray(a, dtype=np.float32)
+    x_c = C.c_void_p(x.data.ptr)
+    arg = Argument(numpy=x, ctypes=x_c)
+    output = cp.zeros_like(x)
+
+    cfunc = CompilerFunctions()
+    cfunc.memcpy_dtoh(output, arg)
+
+    print(f"{type(x)=} {x=}")
+    print(f"{type(a)=} {a=}")
+    print(f"{type(output)=} {output=}")
+
+    assert all(output.get() == a)
+    assert all(x.get() == a)
+
+
+@skip_if_no_cupy
 def test_memcpy_htod():
+    import cupy as cp
+
+    a = [1, 2, 3, 4]
+    src = np.array(a, dtype=np.float32)
+    x = cp.zeros(len(src), dtype=src.dtype)
+    x_c = C.c_void_p(x.data.ptr)
+    arg = Argument(numpy=x, ctypes=x_c)
+
+    cfunc = CompilerFunctions()
+    cfunc.memcpy_htod(arg, src)
+
+    assert all(arg.numpy.get() == a)
+
+
+@skip_if_no_gcc
+def test_memcpy_host_htod():
     a = [1, 2, 3, 4]
     src = np.array(a).astype(np.float32)
     x = np.zeros_like(src)
@@ -238,6 +319,22 @@ def test_memcpy_htod():
     assert all(arg.numpy == a)
 
 
+@skip_if_no_cupy
+def test_memcpy_device_htod():
+    import cupy as cp
+
+    a = [1, 2, 3, 4]
+    src = cp.array(a, dtype=np.float32)
+    x = cp.zeros(len(src), dtype=src.dtype)
+    x_c = C.c_void_p(x.data.ptr)
+    arg = Argument(numpy=x, ctypes=x_c)
+
+    cfunc = CompilerFunctions()
+    cfunc.memcpy_htod(arg, src)
+
+    assert all(arg.numpy.get() == a)
+
+
 @skip_if_no_gfortran
 def test_complies_fortran_function_no_module():
     kernel_string = """
@@ -335,3 +432,58 @@ def test_benchmark(env):
     assert all(["nthreads" in result for result in results])
     assert all(["time" in result for result in results])
     assert all([result["time"] > 0.0 for result in results])
+
+
+@skip_if_no_cupy
+def test_is_cupy_array():
+    import cupy as cp
+
+    assert is_cupy_array(cp.array([1.0]))
+    assert not is_cupy_array(np.array([1.0]))
+
+
+def test_is_cupy_array_no_cupy():
+    assert not is_cupy_array(np.array([1.0]))
+
+
+@skip_if_no_cupy
+def test_get_array_module():
+    import cupy as cp
+
+    assert get_array_module(cp.array([1.0])) is cp
+    assert get_array_module(np.array([1.0])) is np
+
+
+@skip_if_no_cupy
+@skip_if_no_gcc
+def test_run_kernel():
+    import cupy as cp
+
+    kernel_string = """
+    __global__ void vector_add_kernel(float *c, const float *a, const float *b, int n) {
+        int i = blockIdx.x * block_size_x + threadIdx.x;
+        if (i<n) {
+            c[i] = a[i] + b[i];
+        }
+    }
+
+    extern "C" void vector_add(float *c, const float *a, const float *b, int n) {
+        dim3 dimGrid(n);
+        dim3 dimBlock(block_size_x);
+        vector_add_kernel<<<dimGrid, dimBlock>>>(c, a, b, n);
+    }
+    """
+    a = cp.asarray([1, 2.0], dtype=np.float32)
+    b = cp.asarray([3, 4.0], dtype=np.float32)
+    c = cp.zeros_like(b)
+    n = np.int32(len(c))
+
+    result = kernel_tuner.run_kernel(
+        kernel_name="vector_add",
+        kernel_source=kernel_string,
+        problem_size=n,
+        arguments=[c, a, b, n],
+        params={"block_size_x": 1},
+        lang="C",
+    )
+    assert cp.all((a + b) == c)