-
Notifications
You must be signed in to change notification settings - Fork 234
/
Copy pathinitialization.jl
164 lines (134 loc) · 6.09 KB
/
initialization.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# initialization
# XXX: we currently allow loading CUDA.jl even if the package is not functional, because
# downstream packages can only unconditionally depend on CUDA.jl. that's why we have
# the errors be non-fatal, sometimes even silencing them, and why we have the
# `functional()` API that allows checking for successful initialization.
# TODO: once we have conditional dependencies, remove this complexity and have __init__ fail
# set to `true` only at the very end of a successful `__init__`
const _initialized = Ref{Bool}(false)
# short description of why initialization failed; stays unassigned if no reason was recorded
const _initialization_error = Ref{String}()

"""
    functional(show_reason::Bool=false)

Check if the package has been configured successfully and is ready to use.

Return `true` when initialization completed, and `false` otherwise. If `show_reason` is
`true` and the package is not functional, an error is thrown instead of returning `false`;
its message is the recorded initialization error, or a generic message when none was
recorded.

This call is intended for packages that support conditionally using an available GPU. If you
fail to check whether CUDA is functional, actual use of functionality might warn and error.
"""
function functional(show_reason::Bool=false)
    if _initialized[]
        return true
    end

    if show_reason
        # the error reference is only assigned on failure paths that recorded a reason
        reason = if isassigned(_initialization_error)
            _initialization_error[]
        else
            "unknown initialization error"
        end
        error(reason)
    end

    return false
end
# Module initializer: validates the CUDA driver and runtime, initializes the CUDA API, and
# performs one-time setup (device overrides, REPL synchronization hook, LLVM options).
#
# Handled failure conditions do not throw: they store a short reason in
# `_initialization_error` and return early, leaving `_initialized[]` false so that
# `functional()` reports the package as unusable. Only unexpected errors while querying
# the runtime version are rethrown.
function __init__()
    # nothing to set up while generating precompilation output
    precompiling = ccall(:jl_generating_output, Cint, ()) != 0
    precompiling && return

    # TODO: make most of these errors fatal, and remove functional(),
    #       once we have conditional dependencies

    # check that we have a driver (deliberately silent: only the reason is recorded)
    if !isdefined(CUDA_Driver_jll, :libcuda)
        _initialization_error[] = "CUDA driver not found"
        return
    end
    global libcuda = CUDA_Driver_jll.libcuda

    driver = driver_version()
    if driver < v"10.2"
        @error "This version of CUDA.jl only supports NVIDIA drivers for CUDA 10.2 or higher (yours is for CUDA $driver)"
        _initialization_error[] = "CUDA driver too old"
        return
    end
    if driver < v"11.2"
        @warn """The NVIDIA driver on this system only supports up to CUDA $driver.
                 For performance reasons, it is recommended to upgrade to a driver that supports CUDA 11.2 or higher."""
    end

    # check that we have a runtime
    if !CUDA_Runtime.is_available()
        @error """No CUDA Runtime library found. This can have several reasons:
                  * you are using an unsupported platform: CUDA.jl only supports Linux (x86_64, aarch64, ppc64le), and Windows (x86_64).
                  refer to the documentation for instructions on how to use a custom CUDA runtime.
                  * you precompiled CUDA.jl in an environment where the CUDA driver was not available.
                  in that case, you need to specify (during pre compilation) which version of CUDA to use.
                  refer to the documentation for instructions on how to use `CUDA.set_runtime_version!`.
                  * you requested use of a local CUDA toolkit, but not all components were discovered.
                  try running with JULIA_DEBUG=CUDA_Runtime_Discovery for more information."""
        _initialization_error[] = "CUDA runtime not found"
        return
    end

    # querying the runtime version can fail when there is no device; treat that as a
    # regular (non-fatal) initialization failure and propagate anything else
    runtime = try
        runtime_version()
    catch err
        if err isa CuError && err.code == ERROR_NO_DEVICE
            @error "No CUDA-capable device found"
            _initialization_error[] = "No CUDA-capable device found"
            return
        end
        rethrow()
    end

    # the message states the same minimum version as the condition that triggers it
    if runtime < v"10.2"
        @error "This version of CUDA.jl only supports CUDA 10.2 or higher (your toolkit provides CUDA $runtime)"
        _initialization_error[] = "CUDA runtime too old"
        return
    end
    if runtime.major > driver.major
        @warn """You are using CUDA $runtime with a driver that only supports up to $(driver.major).x.
                 It is recommended to upgrade your driver, or switch to automatic installation of CUDA."""
    end

    # finally, initialize CUDA
    try
        cuInit(0)
    catch err
        @error "Failed to initialize CUDA" exception=(err,catch_backtrace())
        _initialization_error[] = "CUDA initialization failed"
        return
    end

    # register device overrides
    eval(Expr(:block, overrides...))
    empty!(overrides)

    @require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin
        include("device/intrinsics/special_math.jl")
        eval(Expr(:block, overrides...))
        empty!(overrides)
    end

    # ensure that operations executed by the REPL back-end finish before returning,
    # because displaying values happens on a different task (CUDA.jl#831).
    # the binding may exist without holding a backend, so also guard against `nothing`.
    if isdefined(Base, :active_repl_backend) && Base.active_repl_backend !== nothing
        push!(Base.active_repl_backend.ast_transforms, synchronize_cuda_tasks)
    end

    # enable generation of FMA instructions to mimic behavior of nvcc
    LLVM.clopts("-nvptx-fma-level=1")

    # warn about old, deprecated environment variables
    if haskey(ENV, "JULIA_CUDA_USE_BINARYBUILDER")
        @error """JULIA_CUDA_USE_BINARYBUILDER is deprecated, and CUDA.jl always uses artifacts now.
                  To use a local installation, use overrides or preferences to customize the artifact.
                  Please check the CUDA.jl or Pkg.jl documentation for more details."""
    end
    if haskey(ENV, "JULIA_CUDA_VERSION")
        @error """JULIA_CUDA_VERSION is deprecated. Call `CUDA.set_runtime_version!` to use a different version instead."""
    end

    # only reached when every check above passed
    _initialized[] = true
end
# AST transform that `__init__` registers with the REPL back-end: wraps the expression
# `ex` so that, once it has been evaluated (whether it returned or threw), the device is
# synchronized, but only when `task_local_state()` is not `nothing`.
function synchronize_cuda_tasks(ex)
    return quote
        try
            $(ex)
        finally
            # the functions are interpolated as values, so the generated code does not
            # depend on what names are visible where it is evaluated
            if $task_local_state() !== nothing
                $device_synchronize()
            end
        end
    end
end
## convenience functions
# TODO: update docstrings
export has_cuda, has_cuda_gpu
"""
has_cuda()::Bool
Check whether the local system provides an installation of the CUDA driver and runtime.
Use this function if your code loads packages that require CUDA.jl.
```
"""
has_cuda(show_reason::Bool=false) = functional(show_reason)
"""
has_cuda_gpu()::Bool
Check whether the local system provides an installation of the CUDA driver and runtime, and
if it contains a CUDA-capable GPU. See [`has_cuda`](@ref) for more details.
Note that this function initializes the CUDA API in order to check for the number of GPUs.
"""
has_cuda_gpu(show_reason::Bool=false) = has_cuda(show_reason) && length(devices()) > 0