-
Notifications
You must be signed in to change notification settings - Fork 96
/
Copy pathlocal_cuda_cluster.py
77 lines (64 loc) · 1.97 KB
/
local_cuda_cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import click
import cupy
from dask import array as da
from dask.distributed import Client
from dask.utils import parse_bytes
from dask_cuda import LocalCUDACluster
@click.command(context_settings=dict(ignore_unknown_options=True))
@click.option(
    "--enable-nvlink/--disable-nvlink",
    default=False,
    help="Enable NVLink communication",
)
@click.option(
    "--enable-infiniband/--disable-infiniband",
    default=False,
    help="Enable InfiniBand communication with RDMA",
)
@click.option(
    "--interface",
    default=None,
    type=str,
    help="Interface used by scheduler for communication. Must be "
    "specified if NVLink or InfiniBand are enabled.",
)
@click.option(
    "--rmm-pool-size",
    default="1GB",
    type=parse_bytes,
    # The default is "1GB", so a pool is always created from the CLI; the
    # previous help text claimed no pool was created unless specified.
    help="Size of the RMM pool each worker is initialized with "
    "(default: 1GB). This can be an integer (bytes) or string "
    "(like 5GB or 5000M).",
)
def main(
    enable_nvlink, enable_infiniband, interface, rmm_pool_size,
):
    """Run a sample Dask-on-GPU computation on a LocalCUDACluster.

    \f
    Starts a LocalCUDACluster with TCP-over-UCX enabled (plus NVLink and/or
    InfiniBand when requested), attaches a Client, sums a 10000x10000
    random Dask array whose chunks are generated by CuPy, then closes the
    client and the cluster. The form feed above stops click from showing
    this section in ``--help``.

    Parameters are supplied by click from the command-line options declared
    on this function.

    Raises:
        ValueError: NVLink or InfiniBand was enabled without ``--interface``.
    """
    # Validate the options before starting any scheduler/worker processes.
    if (enable_infiniband or enable_nvlink) and not interface:
        raise ValueError(
            "Interface must be specified if NVLink or Infiniband are enabled"
        )

    # NOTE: RDMACM is deliberately kept disabled even with InfiniBand; the
    # original author marked it as "not working right now".
    enable_rdmacm = False
    # Only request UCX network devices when InfiniBand is enabled. "auto" is
    # passed straight to LocalCUDACluster (presumably automatic device
    # selection; confirm against the dask-cuda docs).
    ucx_net_devices = "auto" if enable_infiniband else None

    # Context managers guarantee the client and the cluster (scheduler and
    # workers) are closed even if the computation below raises.
    with LocalCUDACluster(
        enable_tcp_over_ucx=True,
        enable_nvlink=enable_nvlink,
        enable_infiniband=enable_infiniband,
        enable_rdmacm=enable_rdmacm,
        ucx_net_devices=ucx_net_devices,
        interface=interface,
        rmm_pool_size=rmm_pool_size,
    ) as cluster, Client(cluster):
        # user code here: the Client above becomes the default scheduler.
        rs = da.random.RandomState(RandomState=cupy.random.RandomState)
        x = rs.random((10000, 10000), chunks=1000)
        x.sum().compute()
# Run the click command only when executed as a script, not on import.
if __name__ == "__main__":
    main()