XLA backend
The Lc0 XLA backend uses the OpenXLA compiler to generate code that executes a neural network. The backend takes ONNX as input and converts it to HLO, the format that XLA understands.
OpenXLA exposes a C API called PjRT, and the Lc0 XLA backend runs the network through it.
Advantages:
- On NVIDIA GPUs, it runs networks at speeds comparable to the handwritten cuda/cudnn backend (and sometimes faster).
- It only requires ONNX as input (either external, or converted by the Lc0 ONNX converter), so it can run networks for which other backends are not written yet.
- It optimizes for the particular GPU (it produces different code for A100 vs. H100, and for different memory sizes).
- It supports running on TPUs.
- It supports multi-GPU / multi-host setups for huge nets (not sure if we will ever need it).
Drawbacks:
- Painful to build.
- Linux only (in theory it runs on Windows through WSL).
- Slow backend startup: it compiles the network when loading it, which may take a few minutes.
To build it, just add the -Dxla parameter to ./build.sh. There are no build dependencies, but at runtime it needs a PjRT plugin (a .so file).
Most likely, you'll have to build the plugin yourself, but you may try the prebuilt XLA library in the elixir-nx repository. That project builds the entire XLA repository, though, so the resulting .so file is about 20% larger, and no one has tested yet whether the GetPjrtApi() function is actually exported.
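One way to answer that question for any downloaded .so is to load it and look the symbol up. The sketch below is a hypothetical helper (exports_symbol is not part of Lc0). It uses Python's ctypes, which calls dlopen() and dlsym() under the hood. Loading the library also requires its own dependencies (e.g. the CUDA libraries) to be resolvable, so "could not load" is reported separately from "not exported".

```python
import ctypes
from typing import Optional


def exports_symbol(library_path: str, symbol: str = "GetPjrtApi") -> Optional[bool]:
    """Check whether a shared library exports `symbol`.

    Returns None if the library cannot be loaded at all (missing file or
    unresolved dependencies such as CUDA libraries), otherwise True/False.
    """
    try:
        # CDLL calls dlopen(); use a path containing a slash (e.g. ./x.so)
        # for a file in the current directory.
        lib = ctypes.CDLL(library_path)
    except OSError:
        return None
    # Attribute access calls dlsym(); a missing symbol raises AttributeError.
    return hasattr(lib, symbol)


# Usage (hypothetical path to the downloaded library):
# print(exports_symbol("./path/to/downloaded_plugin.so"))
```

A result of True means the library can be passed to the backend as a PjRT plugin, as far as the entry point is concerned.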
Alternatively, you can use the prebuilt plugin that ships with the CUDA-enabled versions of Jax:
- Enter a Python virtual environment, e.g.:
python3 -m virtualenv venv; source venv/bin/activate
- Install the Jax PjRT plugin package (a cuda11 variant is also available):
pip install jax-cuda12-pjrt
- Copy the plugin file found at
venv/lib/python3.10/site-packages/jax_plugins/xla_cuda12/xla_cuda_plugin.so
(you may need to adapt the path, e.g. to your Python version). You may rename the copy to pjrt_c_api_gpu_plugin.so.
- As of writing, the GetPjrtApi symbol is confirmed to be exported, so you should be able to use this prebuilt library.
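If you prefer to script the copy step, the following is a minimal Python sketch. It assumes the plugin sits at the jax_plugins/xla_cuda12/xla_cuda_plugin.so location given above. It also assumes you run it with the virtualenv's own python, so that sysconfig reports that environment's site-packages. The helper name copy_jax_pjrt_plugin and the example destination are hypothetical.

```python
import shutil
import sysconfig
from pathlib import Path

# Plugin location inside site-packages, taken from the path above
# (jax-cuda12-pjrt); it may differ between package versions.
PLUGIN_RELPATH = Path("jax_plugins") / "xla_cuda12" / "xla_cuda_plugin.so"


def copy_jax_pjrt_plugin(dest_dir, site_packages=None,
                         new_name="pjrt_c_api_gpu_plugin.so"):
    """Copy Jax's CUDA PjRT plugin to dest_dir/new_name and return that path."""
    if site_packages is None:
        # Binary wheels install into "platlib" of the running interpreter,
        # so run this with the virtualenv's python.
        site_packages = sysconfig.get_paths()["platlib"]
    src = Path(site_packages) / PLUGIN_RELPATH
    if not src.is_file():
        raise FileNotFoundError(f"PjRT plugin not found at {src}; adapt the path")
    dest = Path(dest_dir) / new_name
    shutil.copy2(src, dest)
    return dest


# Usage (from inside the activated venv; hypothetical destination):
# copy_jax_pjrt_plugin("/path/to/lc0/dir")
```

Failing loudly when the file is missing is deliberate, because the package layout is the part most likely to change between Jax releases.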
To build the plugin yourself:
- Clone the openxla/xla repo:
git clone https://github.com/openxla/xla.git
- Install either Bazel 6.5.0 or (better) Bazelisk, which installs the correct version of Bazel automatically.
- Inside the repo, run ./configure.py with the desired parameters (see ./configure.py --help), for example:
TF_CUDA_PATHS=/opt/cuda,/usr ./configure.py --backend CUDA --cudnn_version=8 --cuda_compute_capabilities=5.2
- Build the plugin (this takes several hours):
bazel build -c opt //xla/pjrt/c:pjrt_c_api_gpu_plugin.so
- The resulting plugin will be located at
<xla-repo>/bazel-bin/xla/pjrt/c/pjrt_c_api_gpu_plugin.so
Either put it into the same directory as your ./lc0 binary, or specify the path to it in the Lc0 backend options.
The PjRT plugin for running on Google TPUs is called libtpu.so. No one has tried it yet, though.
It's available at https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/2023-09-12/libtpu.so (or check for newer versions here).
Specify --backend=xla as usual. Supported options (passed through --backend-opts):
- device=0: use device number 0 (it's called device rather than gpu because it's not necessarily a GPU; it may be e.g. a TPU).
- plugin_path=./pjrt_c_api_gpu_plugin.so: path to the PjRT plugin. Note that the filename follows the dlopen() convention. If it contains a slash (e.g. starts with ./), it's a relative or absolute path to a file. If it doesn't contain a slash, only a few system-defined locations (e.g. the directories in LD_LIBRARY_PATH) are searched; notably, the current directory is not.
- data_type=f32: the data type to use inside the NN (currently supported are f32, f16 and bf16). In all cases the backend currently communicates with the GPU using f32, though.
- max_batch=512 and steps=16: define which batch sizes to compile the code for. Currently the XLA backend only supports static batch sizes, and chooses the smallest suitable batch size for evaluation. E.g. if max_batch=9 and steps=3, kernels for batch sizes 3, 6 and 9 will be generated.
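The batch-size selection described above can be sketched as follows. This is an illustration under an assumption, not the backend's code. It assumes steps is the number of evenly spaced sizes (max_batch * i / steps for i = 1..steps). That matches the max_batch=9, steps=3 example, but the backend may space them differently (reading steps as a step size gives the same 3, 6, 9 there). compiled_batch_sizes, pick_batch_size and backend_opts are hypothetical helpers. backend_opts assumes the usual comma-separated key=value form of --backend-opts.

```python
def compiled_batch_sizes(max_batch: int, steps: int) -> list:
    """Batch sizes to compile for.

    Assumption: `steps` evenly spaced sizes ending at max_batch, i.e.
    max_batch * i // steps for i = 1..steps.
    """
    return [max_batch * i // steps for i in range(1, steps + 1)]


def pick_batch_size(n: int, sizes: list) -> int:
    """Smallest compiled batch size that can hold n positions."""
    for size in sorted(sizes):
        if size >= n:
            return size
    raise ValueError(f"batch of {n} exceeds the largest compiled size")


def backend_opts(**opts) -> str:
    """Join options into a comma-separated key=value string (assumed format)."""
    return ",".join(f"{key}={value}" for key, value in opts.items())


sizes = compiled_batch_sizes(max_batch=9, steps=3)
print(sizes)                      # [3, 6, 9]
print(pick_batch_size(4, sizes))  # 6
print(backend_opts(device=0, data_type="f16", max_batch=9, steps=3))
# device=0,data_type=f16,max_batch=9,steps=3
```

Under the same assumption, max_batch=512 and steps=16 would compile sizes 32, 64, ..., 512, and a batch of 40 positions would run with the 64-size kernel. More steps means less wasted padding per evaluation, at the cost of longer startup compilation.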
To inspect the generated HLO, the leela2onnx command has several options, for example:
- --hlo-text-output=filename.txt: text dump of the HLO.
- --hlo-proto-output=filename.pb: dump of the HLO as a proto. This can be fed to all the XLA tools (e.g. hlo-opt, to see the kernels it generates).
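When comparing a --hlo-text-output dump against a reference HLO (for example one produced by jaxonnxruntime), a quick opcode histogram helps spot which ops differ. The following is a rough sketch, not a real HLO parser. It assumes instruction lines have the form "name = shape opcode(...)" and extracts the opcode with a regular expression. count_hlo_opcodes is a hypothetical helper.

```python
import re
from collections import Counter

# Matches "<name> = <shape> <opcode>(" in an HLO text instruction line, e.g.
#   %add.3 = f32[14,64]{1,0} add(f32[14,64]{1,0} %a, f32[14,64]{1,0} %b)
# The non-greedy ".*?" skips over the result shape, which may be a tuple
# containing spaces. Attributes such as "window={size=3x3" never match
# because their "=" is not followed by whitespace.
_INSTRUCTION = re.compile(r"=\s+.*?\s([a-z][a-z0-9-]*)\(")


def count_hlo_opcodes(hlo_text: str) -> Counter:
    """Count how often each opcode appears in an HLO text dump."""
    counts = Counter()
    for line in hlo_text.splitlines():
        match = _INSTRUCTION.search(line)
        if match:
            counts[match.group(1)] += 1
    return counts
```

For example, count_hlo_opcodes(open("filename.txt").read()).most_common() lists the opcodes of a leela2onnx dump, most frequent first.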
If you need to add a converter for an ONNX op that is not currently supported (or not fully supported), here is how it can be done:
- Convert your net to ONNX, e.g.
./lc0 leela2onnx --input=t79.pb.gz --output=t79.onnx
- Convert your ONNX net to HLO (in text and proto form) using jaxonnxruntime. You'll need the .hlo file to see how the op is implemented in HLO (for the onnx2hlo.cc code), and the .hlo.proto file to find out how new HLO opcodes (if any) are encoded in proto (for the hlo_builder.cc code):
```python
import jax
import numpy as np
import onnx
from jaxonnxruntime import config
from jaxonnxruntime.call_onnx import call_onnx_model

config.update('jaxort_only_allow_initializers_as_static_args', False)

MODEL_FILE = 't79.onnx'
BATCH_SIZE = 14
# Lc0 networks take 112 input planes of 8x8 squares.
INPUTS = [np.zeros((BATCH_SIZE, 112, 8, 8), dtype=np.float32)]

onnx_model = onnx.load(MODEL_FILE)
# call_onnx_model returns the model function and its parameters.
model_func, model_params = call_onnx_model(onnx_model, INPUTS)
lowered = jax.jit(model_func).lower(model_params, INPUTS)
hlo = lowered.compiler_ir('hlo')

# Text form: shows how the ops are expressed in HLO.
with open(MODEL_FILE + ".hlo", "w") as f:
    f.write(hlo.as_hlo_text())
# Serialized HloModuleProto: shows how the opcodes are encoded in proto.
with open(MODEL_FILE + ".hlo.proto", "wb") as f:
    f.write(hlo.as_serialized_hlo_module_proto())
```
- Convert the ONNX and HLO protos to text format.
ONNX:
cat t75.onnx | protoc -I ~/path/to/lc0/src/neural/onnx/ ~/path/to/lc0/src/neural/onnx/onnx.proto --decode=pblczero.ModelProto > t75.onnx.asciiproto
HLO:
cat t75.hlo.proto | protoc -I ~/path/to/xla/ ~/path/to/xla/xla/service/hlo.proto --decode=xla.HloModuleProto > t75.hlo.asciiproto
- Add missing HLO opcodes to hlo_builder.cc and ONNX ops to onnx2hlo.cc. When adding new HLO opcodes, it's often necessary to add more fields to hlo.proto. When that happens, also update print_hlo.cc to show them.
Also see the description of the initial PR for more details.
- To iterate on changes, it's useful to use the (hidden) --hlo-allow-partial-result flag of leela2onnx, which shows the partial output of the HLO conversion even if there is an error, for example:
./lc0 leela2onnx --input=BT3-768x15x24h-swa-2790000.pb.gz --hlo-text-output=- --hlo-allow-partial-result --hlo-batch-size=333 --onnx-data-type=bf16
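Before starting on a converter, it can help to list which ops in your ONNX model are not yet handled. The following is a minimal sketch. It takes any iterable of nodes with an op_type attribute, such as onnx.load("t79.onnx").graph.node, plus a set of op names you know onnx2hlo.cc supports. That set has to be filled in by hand; this sketch does not read it from the Lc0 source. unsupported_onnx_ops is a hypothetical helper.

```python
from collections import Counter


def unsupported_onnx_ops(nodes, supported):
    """Count op types in `nodes` that are not in the `supported` set.

    `nodes` is any iterable of objects with an `op_type` attribute, such as
    onnx.load("t79.onnx").graph.node. Ops inside subgraphs (e.g. If/Loop
    bodies) are not visited by this sketch.
    """
    return Counter(node.op_type for node in nodes if node.op_type not in supported)
```

For example, unsupported_onnx_ops(onnx.load("t79.onnx").graph.node, {"Conv", "Add", "Relu"}).most_common() lists the missing ops, most frequent first, which is a reasonable order in which to write converters.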