[TPU] Reduce compilation time & Upgrade PyTorch XLA version #6856

Merged 80 commits on Jul 27, 2024

Commits (80, all authored by WoosukKwon)

76fc072  Add & warnings (Jun 24, 2024)
27a5ad8  Add in dummy_run (Jun 24, 2024)
5ab6f65  Add is_driver_worker (Jun 24, 2024)
c4e79a0  Make TPUExecutor similar to GPUExecutor (Jun 24, 2024)
ff81993  Add multiprocessing-based TPU executor (Jun 24, 2024)
16e80b2  Use TPU to initialize Ray cluster (Jun 24, 2024)
05884ce  Add pjrt proc init (Jun 24, 2024)
20d23eb  Add Ray TPU executor (Jun 24, 2024)
5d4df21  Use Ray TPU executor for tp (Jun 24, 2024)
6b2c76c  Minor (Jun 24, 2024)
d91446b  Fix TPUWorker.execute_model (Jun 24, 2024)
ab1595d  Add is_driver_worker & input broadcast (Jun 24, 2024)
4b45393  Call xm._init_world_size_ordinal (Jun 24, 2024)
86451a2  Bug fix on vocab (Jun 24, 2024)
0539299  Use all gather for TPU (Jun 24, 2024)
b35917c  Support TPU in GroupCoordinator (Jun 24, 2024)
b9a84bc  Delete multiproc TPU executor (Jun 25, 2024)
c756b76  Minor (Jun 25, 2024)
16e9934  [Bugfix][TPU] Fix CPU cache allocation & swapping (Jun 26, 2024)
e25f470  Merge branch 'fix-tpu-swpa' into tpu-n (Jun 26, 2024)
ca6d1d6  yapf (Jun 26, 2024)
cd4f68d  Add Ray to TPU dependency (Jun 26, 2024)
5df4164  Merge branch 'main' into tpu-n (Jun 26, 2024)
546987a  Fix (Jun 26, 2024)
330be6e  Fix (Jun 26, 2024)
b45ed24  Merge branch 'main' into tpu-n (Jun 29, 2024)
8fab9fd  Add use_all_gather to LoRA (Jun 29, 2024)
c4cbe9f  Fix (Jun 29, 2024)
2871c7c  Merge branch 'main' into tpu-n (Jun 30, 2024)
db7adc7  Add an assert for dim == -1 (Jun 30, 2024)
696790d  is_tpu -> use_xla (Jun 30, 2024)
8a08896  Merge branch 'main' into tpu-n (Jun 30, 2024)
36f9070  Merge branch 'main' into tpu-n (Jul 1, 2024)
28afe56  yapf (Jul 2, 2024)
60bf64d  Add hack in vocab (Jul 2, 2024)
0fbb050  Merge branch 'main' into tpu-n (Jul 7, 2024)
ddf4cbe  Merge branch 'main' into tpu-n (Jul 7, 2024)
cd4842d  Fix multi-modal support (Jul 9, 2024)
54e637b  Merge branch 'main' into tpu-n (Jul 9, 2024)
73ed611  Merge branch 'main' into tpu-n (Jul 10, 2024)
717b3fa  Merge branch 'main' into tpu-n (Jul 15, 2024)
6b0c35d  Merge branch 'main' into tpu-n (Jul 17, 2024)
7f583ba  Merge branch 'main' into tpu-n (Jul 18, 2024)
106864d  Remove unused (Jul 18, 2024)
223661f  Minor (Jul 18, 2024)
5bd67bc  Merge branch 'main' into tpu-n (Jul 21, 2024)
ab7cccf  Fix comm error (Jul 21, 2024)
4e0c90a  Use custom inference_mode (Jul 21, 2024)
a2358ed  Remove hack in vocab embedding (Jul 21, 2024)
ac21351  Use patch (Jul 21, 2024)
ba76d9e  Update inference_mode (Jul 21, 2024)
452c321  use_all_gather -> use_gather (Jul 21, 2024)
dcb63b7  Fix patch (Jul 21, 2024)
825cc44  Fix typo (Jul 21, 2024)
f27ef99  Merge branch 'main' into tpu-n (Jul 22, 2024)
9730288  Remove inference_mode (Jul 22, 2024)
631b08b  Add no_grad (Jul 23, 2024)
d65a7d0  Merge branch 'main' into tpu-n (Jul 23, 2024)
755fe0b  Merge branch 'main' into tpu-n (Jul 24, 2024)
d5fadfd  Merge branch 'main' into tpu-n (Jul 26, 2024)
af3a259  [TPU] Support collective communications in XLA devices (Jul 26, 2024)
0f2abea  Use current_platform (Jul 26, 2024)
8ebea7e  is_xla -> is_tpu (Jul 26, 2024)
782b182  Define TPU communicator (Jul 26, 2024)
76fd300  Merge branch 'main' into tpu-n (Jul 26, 2024)
75f842b  Merge branch 'add-xla-comm' into tpu-n (Jul 26, 2024)
8087227  Fix (Jul 26, 2024)
f04e179  Address comments (Jul 26, 2024)
f493c89  Device init (Jul 26, 2024)
f14b085  Fix patch (Jul 26, 2024)
1668582  Merge branch 'add-xla-comm' into tpu-n (Jul 26, 2024)
f9df97d  0726 (Jul 26, 2024)
9994742  xr (Jul 26, 2024)
e0d3232  Add dynamic=True (Jul 27, 2024)
2f6f54f  Remove import (Jul 27, 2024)
8bb1159  yapf (Jul 27, 2024)
c11e129  Merge branch 'main' into upgrade-xla (Jul 27, 2024)
fafda57  Add comment & doc (Jul 27, 2024)
79c45d5  Minor (Jul 27, 2024)
4f0a23c  Minor (Jul 27, 2024)

Files changed

2 changes: 1 addition & 1 deletion Dockerfile.tpu
@@ -1,4 +1,4 @@
ARG NIGHTLY_DATE="20240713"
ARG NIGHTLY_DATE="20240726"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE
9 changes: 8 additions & 1 deletion docs/source/getting_started/tpu-installation.rst
@@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y

$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240713"
$ export DATE="+20240726"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl

@@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds:
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop


+.. note::
+
+    Since TPU relies on XLA, which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each shape.
+    Compilation may take 20~30 minutes on the first run.
+    Afterwards, however, it drops to ~5 minutes because the XLA graphs are cached on disk (in :code:`VLLM_XLA_CACHE_PATH`, or :code:`~/.cache/vllm/xla_cache` by default).
+
+
.. tip::

If you encounter the following error:
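The note added above summarizes the caching behavior this PR documents: the XLA graph compiled for each bucketized input shape is written to the cache directory, so only the first run pays the full compilation cost. As an illustration only (an editor's sketch, not part of the diff), the snippet below shows how one might point that cache at a persistent directory before constructing the engine on a TPU VM; the cache path and model name are placeholders, and VLLM_XLA_CACHE_PATH is the environment variable named in the note.

# Sketch only: persist the XLA compilation cache across runs on a TPU VM.
# Assumes vLLM was built with VLLM_TARGET_DEVICE="tpu" as described above;
# the cache path and model name are placeholders.
import os

# Must be set before vLLM initializes the TPU runtime.
os.environ["VLLM_XLA_CACHE_PATH"] = "/mnt/disks/persist/xla_cache"

from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2b")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
for output in outputs:
    print(output.outputs[0].text)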
1 change: 0 additions & 1 deletion vllm/attention/backends/pallas.py
Expand Up @@ -3,7 +3,6 @@

import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
-import torch_xla.experimental.dynamo_set_buffer_donor

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
3 changes: 2 additions & 1 deletion vllm/distributed/device_communicators/tpu_communicator.py
@@ -6,6 +6,7 @@

if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
from torch_xla._internal import pjrt


@@ -20,7 +21,7 @@ def __init__(self, group: ProcessGroup):
local_rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
pjrt.initialize_multiprocess(local_rank, world_size)
-        xm._init_world_size_ordinal()
+        xr._init_world_size_ordinal()

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
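Read together, the hunks above describe the whole communicator: PJRT is initialized from the ranks of the torch.distributed group, and this PR switches the ordinal setup from xm to the xr (torch_xla.runtime) API. The following is a minimal reconstruction, an editor's sketch rather than the verbatim file, of what TpuCommunicator looks like after this change; the all_gather method is assumed from the "Use all gather for TPU" and "Add an assert for dim == -1" commits rather than shown in this excerpt.

# Sketch of the TPU communicator after this PR, assembled from the hunks
# above. The all_gather method is an assumption based on the commit history.
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform

if current_platform.is_tpu():
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt


class TpuCommunicator:

    def __init__(self, group: ProcessGroup):
        # PJRT is initialized with the ranks of the torch.distributed group,
        # and xr._init_world_size_ordinal() makes XLA's ordinal assignment
        # consistent with those ranks.
        local_rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        pjrt.initialize_multiprocess(local_rank, world_size)
        xr._init_world_size_ordinal()

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        return xm.all_reduce(xm.REDUCE_SUM, x)

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # Assumed, not shown above: TPU all-gather over the last dimension.
        assert dim == -1, "TPU all-gather is sketched for the last dim only."
        return xm.all_gather(x, dim=dim)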
15 changes: 13 additions & 2 deletions vllm/worker/tpu_model_runner.py
@@ -7,6 +7,7 @@
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@@ -127,7 +128,7 @@ def load_model(self) -> None:
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
-        xm_tp_rank = xm.get_ordinal()
+        xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
@@ -146,7 +147,17 @@ def load_model(self) -> None:
xm.wait_device_ops()

model = ModelWrapper(model)
-        self.model = torch.compile(model, backend="openxla", fullgraph=True)
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Setting dynamic=True can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph still requires static shapes and needs to be
+        # re-compiled for each different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs on disk (VLLM_XLA_CACHE_PATH).
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=True)

def _dummy_run(
self,
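The NOTE added above is the heart of the compilation-time reduction: torch.compile with dynamic=True traces a single FX graph that is reused across shapes, while XLA still compiles (and caches) one graph per concrete shape. A standalone toy version of the same pattern, independent of vLLM's ModelWrapper and offered only as an editor's sketch for a TPU VM with torch_xla installed, might look like this:

# Toy sketch of the two-stage compilation pattern used above (not vLLM code).
# Stage 1: torch.compile with dynamic=True reuses a single FX graph across
#          input shapes. Stage 2: XLA still compiles one graph per concrete
#          shape, which is why vLLM buckets shapes and caches the XLA graphs.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm


class ToyModel(nn.Module):

    def __init__(self, hidden: int = 128):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


device = xm.xla_device()
model = ToyModel().to(device)
compiled = torch.compile(model, backend="openxla", fullgraph=True, dynamic=True)

# Each distinct batch size triggers an XLA compilation the first time it is
# seen; repeating a size reuses the already-compiled graph.
for batch_size in (8, 16, 8):
    x = torch.randn(batch_size, 128, device=device)
    y = compiled(x)
    xm.mark_step()  # materialize the lazily traced XLA computation
    print(batch_size, y.shape)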
1 change: 0 additions & 1 deletion vllm/worker/tpu_worker.py
@@ -3,7 +3,6 @@

import torch
import torch_xla.core.xla_model as xm
-import torch_xla.experimental.dynamo_set_buffer_donor # noqa: F401
import torch_xla.runtime as xr

import vllm.envs as envs