diff --git a/dask_cuda/benchmarks/utils.py b/dask_cuda/benchmarks/utils.py
index 4ee44820e..9a185a81f 100644
--- a/dask_cuda/benchmarks/utils.py
+++ b/dask_cuda/benchmarks/utils.py
@@ -34,6 +34,16 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
         type=str,
         help="Write dask profile report (E.g. dask-report.html)",
     )
+    parser.add_argument(
+        "--device-memory-limit",
+        default=None,
+        type=parse_bytes,
+        help="Size of the CUDA device LRU cache, which is used to determine when the "
+        "worker starts spilling to host memory. Can be an integer (bytes), float "
+        "(fraction of total device memory), string (like ``'5GB'`` or ``'5000M'``), or "
+        "``'auto'``, 0, or ``None`` to disable spilling to host (i.e. allow full "
+        "device memory usage).",
+    )
     parser.add_argument(
         "--rmm-pool-size",
         default=None,
@@ -203,6 +213,8 @@ def get_cluster_options(args):
 
         if args.enable_rdmacm:
             worker_options["enable_rdmacm"] = ""
+        if args.device_memory_limit:
+            worker_options["device_memory_limit"] = args.device_memory_limit
         if args.ucx_net_devices:
             worker_options["ucx_net_devices"] = args.ucx_net_devices
 
@@ -229,6 +241,7 @@ def get_cluster_options(args):
             "enable_nvlink": args.enable_nvlink,
             "enable_rdmacm": args.enable_rdmacm,
             "interface": args.interface,
+            "device_memory_limit": args.device_memory_limit,
         }
     if args.no_silence_logs:
         cluster_kwargs["silence_logs"] = False
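
For context (not part of the diff), a minimal sketch of how a value accepted by the new flag is typically consumed downstream: `parse_bytes` from `dask.utils` converts a human-readable size into an integer byte count, and `LocalCUDACluster` accepts it as `device_memory_limit` to decide when device-to-host spilling starts. The cluster setup below is illustrative, not code from this change.

```python
# Hypothetical usage sketch; mirrors how the benchmark's parsed value would
# reach a cluster, but is not taken from this PR.
from dask.utils import parse_bytes

from dask_cuda import LocalCUDACluster

# parse_bytes converts strings like "5GB" or "5000M" into integer bytes.
limit = parse_bytes("5GB")  # 5_000_000_000

# Workers begin spilling device memory to host once usage exceeds the limit.
cluster = LocalCUDACluster(device_memory_limit=limit)
```

With this change, a benchmark invocation can forward the limit to its workers, e.g. `python -m dask_cuda.benchmarks.local_cudf_merge --device-memory-limit 5GB` (benchmark module name shown for illustration); omitting the flag keeps the previous default behavior.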