diff --git a/Dockerfile.tpu b/Dockerfile.tpu
index be7dbe63cb237..4fc14d6bd186c 100644
--- a/Dockerfile.tpu
+++ b/Dockerfile.tpu
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240713"
+ARG NIGHTLY_DATE="20240726"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

 FROM $BASE_IMAGE
diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst
index 5e2f514a4a509..2e6c522422c22 100644
--- a/docs/source/getting_started/tpu-installation.rst
+++ b/docs/source/getting_started/tpu-installation.rst
@@ -56,7 +56,7 @@ First, install the dependencies:
     $ pip uninstall torch torch-xla -y

     $ # Install PyTorch and PyTorch XLA.
-    $ export DATE="+20240713"
+    $ export DATE="+20240726"
     $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
     $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl

@@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds:

     $ VLLM_TARGET_DEVICE="tpu" python setup.py develop

+.. note::
+
+    Since TPU relies on XLA, which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each shape.
+    The compilation may take 20~30 minutes on the first run.
+    Afterwards, it drops to ~5 minutes because the XLA graphs are cached on disk (in :code:`VLLM_XLA_CACHE_PATH`, or :code:`~/.cache/vllm/xla_cache` by default).
+
+
 .. tip::

     If you encounter the following error:
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index c53a2f91b89d7..2269ac2606e89 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -3,7 +3,6 @@

 import torch
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.
-import torch_xla.experimental.dynamo_set_buffer_donor

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index 69a9a516f3ebe..16525887cf4eb 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -6,6 +6,7 @@

 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
     from torch_xla._internal import pjrt


@@ -20,7 +21,7 @@ def __init__(self, group: ProcessGroup):
         local_rank = dist.get_rank(group)
         world_size = dist.get_world_size(group)
         pjrt.initialize_multiprocess(local_rank, world_size)
-        xm._init_world_size_ordinal()
+        xr._init_world_size_ordinal()

     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         return xm.all_reduce(xm.REDUCE_SUM, x)
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index e5bb101fc7df4..1692094af8c41 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr

 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@@ -127,7 +128,7 @@ def load_model(self) -> None:
         # determine the order of concatenating the output tensors.
         # As a workaround, we use the xm's rank assignment only when loading
         # the embedding weights.
-        xm_tp_rank = xm.get_ordinal()
+        xm_tp_rank = xr.global_ordinal()
         with patch(
                 "vllm.model_executor.layers.vocab_parallel_embedding."
                 "get_tensor_model_parallel_rank",
@@ -146,7 +147,17 @@ def load_model(self) -> None:
         xm.wait_device_ops()

         model = ModelWrapper(model)
-        self.model = torch.compile(model, backend="openxla", fullgraph=True)
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Setting dynamic=True can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph still requires static shapes and needs to be
+        # recompiled for each different shape. This overhead is unavoidable on
+        # the first run, but is skipped afterwards because the XLA graphs are
+        # cached on disk (VLLM_XLA_CACHE_PATH).
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=True)

     def _dummy_run(
         self,
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index c88aba7ae08cd..17fa5c35457c2 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -3,7 +3,6 @@

 import torch
 import torch_xla.core.xla_model as xm
-import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr

 import vllm.envs as envs
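For context on the NOTE added to load_model above, the sketch below illustrates (outside of vLLM) how the two compilation stages interact when torch.compile is used with the openxla backend. It is a minimal, hypothetical example assuming a TPU VM with torch and torch_xla installed; TinyModel, the feature size, and the batch sizes are illustrative and are not part of this diff.

# Standalone sketch of the two-stage compilation described in the NOTE above.
# Assumption: a TPU VM with torch and torch_xla installed.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


device = xm.xla_device()
model = TinyModel().to(device)

# Stage 1: torch.compile traces the model into an FX graph once.
# dynamic=True lets that single FX graph be reused across input shapes,
# avoiding repeated Dynamo re-tracing.
compiled = torch.compile(model, backend="openxla", fullgraph=True, dynamic=True)

# Stage 2: XLA still compiles one executable per concrete (static) shape,
# so each new batch size below triggers an XLA compilation the first time
# it is seen; later calls with the same shape reuse the compiled executable.
for batch_size in (8, 16, 32):
    x = torch.randn(batch_size, 128, device=device)
    out = compiled(x)
    xm.wait_device_ops()  # Block until the XLA execution finishes.

vLLM applies the same idea to its bucketized input shapes: the FX-level work is done once thanks to dynamic=True, while the per-bucket XLA executables are compiled on the first run and then served from the on-disk cache (VLLM_XLA_CACHE_PATH) on subsequent runs.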