-
Notifications
You must be signed in to change notification settings - Fork 96
/
Copy pathclient_initialize.py
57 lines (45 loc) · 1.27 KB
/
client_initialize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import click
import cupy
from dask import array as da
from dask.distributed import Client
from dask_cuda.initialize import initialize
@click.command(context_settings=dict(ignore_unknown_options=True))
@click.argument("address", required=True, type=str)
@click.option(
    "--enable-nvlink/--disable-nvlink",
    default=False,
    help="Enable NVLink communication",
)
@click.option(
    "--enable-infiniband/--disable-infiniband",
    default=False,
    help="Enable InfiniBand communication with RDMA",
)
@click.option(
    "--net-devices",
    default="mlx5_0:1",
    show_default=True,
    help="UCX network device(s) to use when InfiniBand is enabled",
)
def main(address, enable_nvlink, enable_infiniband, net_devices="mlx5_0:1"):
    """Connect a UCX-configured Dask client to ADDRESS, run a CuPy-backed
    Dask array reduction, then shut the cluster down.

    Parameters
    ----------
    address : str
        Scheduler address passed to ``dask.distributed.Client``.
    enable_nvlink : bool
        Forwarded to ``dask_cuda.initialize.initialize``.
    enable_infiniband : bool
        Forwarded to ``initialize``; also controls whether ``net_devices``
        is passed through.
    net_devices : str
        UCX network device(s) used only when InfiniBand is enabled.
        Defaults to ``"mlx5_0:1"``, the value this script previously
        hard-coded.
    """
    # RDMACM stays disabled: the original author noted it was "not working
    # right now". Expose it as an option once it is supported.
    enable_rdmacm = False
    # Only pin a UCX network device when InfiniBand is requested;
    # otherwise let initialize() fall back to its own default.
    ucx_net_devices = net_devices if enable_infiniband else None

    # Configure UCX for this (client) process before the Client connects.
    initialize(
        enable_tcp_over_ucx=True,
        enable_nvlink=enable_nvlink,
        enable_infiniband=enable_infiniband,
        enable_rdmacm=enable_rdmacm,
        net_devices=ucx_net_devices,
    )

    # Create the client outside the try block: if connecting fails there is
    # no cluster handle to shut down.
    client = Client(address)
    try:
        # User code here. Chunks are generated with CuPy's RandomState, so the
        # array blocks are CuPy (GPU) arrays rather than NumPy arrays.
        rs = da.random.RandomState(RandomState=cupy.random.RandomState)
        x = rs.random((10000, 10000), chunks=1000)
        x.sum().compute()
    finally:
        # Always shut down the cluster, even if the computation raised,
        # so the scheduler and workers are not left running.
        client.shutdown()
# Invoke the click command only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()