[Frontend] [Core] Support for sharded tensorized models #4990

Merged · 15 commits · Jun 12, 2024
Changes from 5 commits
18 changes: 7 additions & 11 deletions examples/tensorize_vllm_model.py
@@ -3,9 +3,6 @@
import json
import os
import uuid
from functools import partial

from tensorizer import stream_io

from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
@@ -58,6 +55,12 @@
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.

To support distributed tensor-parallel models, each model shard is
serialized to a separate file. The tensorizer_uri is then specified as a string
template with a format specifier such as '%03d' that will be rendered with the
shard's rank. Sharded models serialized with this script are named
model-rank-%03d.tensors.

For more information on the available arguments for serializing, run
`python -m examples.tensorize_vllm_model serialize --help`.
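
A minimal end-to-end sketch of the workflow this docstring describes, using the
library API that this PR touches; the model name, shard directory, and 2-GPU
tensor-parallel size are illustrative assumptions, not part of this PR:

```python
# Sketch only: assumes a host with 2 GPUs; paths and model name are placeholders.
from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         tensorize_vllm_model)

engine_args = EngineArgs(model="EleutherAI/pythia-1.4b",
                         tensor_parallel_size=2,
                         enforce_eager=True)
tensorizer_config = TensorizerConfig(
    tensorizer_uri="/tmp/pythia-1.4b/model-rank-%03d.tensors")

# Serialize: each tensor-parallel worker writes its own shard, with '%03d'
# rendered as the worker's rank (model-rank-000.tensors, model-rank-001.tensors).
tensorize_vllm_model(engine_args=engine_args,
                     tensorizer_config=tensorizer_config)

# Deserialize: pass the same templated URI back through
# model_loader_extra_config so every rank can locate its own shard.
llm = LLM(model="EleutherAI/pythia-1.4b",
          load_format="tensorizer",
          tensor_parallel_size=2,
          model_loader_extra_config=tensorizer_config)
```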

@@ -165,6 +168,7 @@ def parse_args():
def deserialize():
llm = LLM(model=args.model,
load_format="tensorizer",
tensor_parallel_size=args.tensor_parallel_size,
model_loader_extra_config=tensorizer_config
)
return llm
@@ -186,14 +190,6 @@ def deserialize():
"s3_endpoint": s3_endpoint
}

_, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))

model_ref = args.model

model_name = model_ref.split("/")[1]
41 changes: 18 additions & 23 deletions tests/tensorizer_loader/test_tensorizer.py
@@ -1,6 +1,5 @@
import gc
import json
import multiprocessing as mp
import os
import pathlib
import subprocess
@@ -264,7 +263,8 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))


@pytest.mark.skip("Requires multiple GPUs")
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Test requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
with pytest.raises(ValueError):
model_ref = "EleutherAI/pythia-1.4b"
@@ -279,62 +279,57 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
s3_endpoint="object.ord1.coreweave.com",
),
tensor_parallel_size=2,
disable_custom_all_reduce=True,
)

@pytest.mark.skip("Requires multiple GPUs")
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Test requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tmp_path):
model_ref = "EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model
base_model = vllm_runner(
model_ref,
disable_custom_all_reduce=True,
enforce_eager=True,
)
outputs = base_model.generate(prompts, sampling_params)

base_model.model.llm_engine.model_executor.shutdown()
del base_model
cleanup()
ray.shutdown()

# load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
key_path = tmp_path / (model_ref + ".key")

config_for_serializing = TensorizerConfig(
tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path,
)
# FIXME: launching multiple multiprocessing servers within the same program
# results in a hang... launch serialization in a separate process as a work
# around
serialization_proc = mp.get_context('spawn').Process(
target=tensorize_vllm_model,
kwargs={
"engine_args": EngineArgs(

tensorize_vllm_model(
engine_args=EngineArgs(
model=model_ref,
tensor_parallel_size=2,
disable_custom_all_reduce=True,
enforce_eager=True,
),
"tensorizer_config": config_for_serializing,
},
tensorizer_config=tensorizer_config,
)
serialization_proc.start()
serialization_proc.join()
assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
assert os.path.isfile(model_path % 1), "Serialization subprocess failed"


config_for_deserializing = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path,
)
cleanup()
ray.shutdown()

loaded_vllm_model = vllm_runner(
model_ref,
tensor_parallel_size=2,
load_format="tensorizer",
disable_custom_all_reduce=True,
enforce_eager=True,
distributed_executor_backend="ray",
model_loader_extra_config=config_for_deserializing)
model_loader_extra_config=tensorizer_config)

deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

46 changes: 25 additions & 21 deletions vllm/model_executor/model_loader/tensorizer.py
@@ -1,10 +1,12 @@
import argparse
import dataclasses
import io
import os
import re
import time
from dataclasses import dataclass
from functools import partial
from typing import Generator, Optional, Tuple, Type, Union
from typing import BinaryIO, Generator, Optional, Tuple, Type, Union

import torch
from torch import nn
@@ -14,7 +16,6 @@
from vllm.config import ModelConfig, ParallelConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.distributed_gpu_executor import DistributedGPUExecutor
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -59,14 +60,13 @@ class TensorizerConfig:
model_class: Optional[Type[torch.nn.Module]] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Optional[Union[str, torch.dtype]] = None
_is_sharded: bool = False

def __post_init__(self):
if not isinstance(self.tensorizer_uri, str):
try:
self.tensorizer_uri = str(self.tensorizer_uri)
except Exception as exc:
raise ValueError(
"tensorizer_uri must be convertible to a string") from exc
# check if the configuration is for a sharded model
if isinstance(self.tensorizer_uri, str) \
and re.search(r'%0\dd', self.tensorizer_uri):
self._is_sharded = True

def _construct_tensorizer_args(self) -> "TensorizerArgs":
tensorizer_args = {
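
The sharding check added in __post_init__ above boils down to a single
regular-expression test; here is a small standalone sketch of that detection
(the helper name is illustrative, not part of the PR):

```python
import re

def is_sharded_uri(uri: str) -> bool:
    # A URI counts as sharded when it embeds a zero-padded percent-format
    # specifier such as %02d or %03d, mirroring the check in __post_init__.
    return re.search(r"%0\dd", uri) is not None

assert is_sharded_uri("model-rank-%03d.tensors")
assert not is_sharded_uri("model.tensors")
```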
@@ -85,16 +85,14 @@ def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if (parallel_config.tensor_parallel_size == 1):
return

if (uri := self.tensorizer_uri) is not None:
rank_format_match = re.search(r'%0\dd', uri)
if not rank_format_match:
raise ValueError(
"For a sharded model, Tensorizer URI should include a"
" string format template like '%04d' to be formatted"
" with the rank of the shard")
# tensorizer_uri is used for a vLLM serialized model
if self.tensorizer_uri \
and parallel_config.tensor_parallel_size > 1 \
and not self._is_sharded:
raise ValueError(
"For a sharded model, tensorizer_uri should include a"
" string format template like '%04d' to be formatted"
" with the rank of the shard")

def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if (model_config.quantization is not None
@@ -112,7 +110,8 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig,

@dataclass
class TensorizerArgs:
tensorizer_uri: str
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
bytes, os.PathLike, int]
vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = None
@@ -426,7 +425,7 @@ def serialize_vllm_model(
encryption_params = EncryptionParams(key=key)

output_file = tensorizer_args.tensorizer_uri
if re.search(r'%0\dd', output_file):
if tensorizer_config._is_sharded:
from vllm.distributed import get_tensor_model_parallel_rank
output_file = output_file % get_tensor_model_parallel_rank()
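
For illustration, the template is rendered with plain old-style string
formatting; the filenames below assume the default naming used by the example
script:

```python
# Old-style %-formatting renders the worker's rank into the shard filename.
template = "model-rank-%03d.tensors"
print(template % 0)  # model-rank-000.tensors
print(template % 1)  # model-rank-001.tensors
```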

@@ -446,6 +445,11 @@ def tensorize_vllm_model(engine_args: EngineArgs,
Intended to be used separately from running a vLLM server since it
creates its own Engine instance.
"""
engine_config = engine_args.create_engine_config()
tensorizer_config.verify_with_model_config(engine_config.model_config)
tensorizer_config.verify_with_parallel_config(
engine_config.parallel_config)

# generate the encryption key before creating the engine to support sharding
if generate_keyfile and (keyfile :=
tensorizer_config.encryption_keyfile) is not None:
@@ -459,7 +463,7 @@
stream.write(encryption_params.key)
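
A hedged sketch of that key-generation step in isolation, assuming tensorizer's
EncryptionParams.random() constructor and a local keyfile path chosen purely
for illustration:

```python
from tensorizer import EncryptionParams

# Generate the key up front, before the engine (and its tensor-parallel
# workers) is created, so every rank encrypts its shard with the same key.
encryption_params = EncryptionParams.random()  # assumed API
with open("/tmp/model.key", "wb") as stream:
    stream.write(encryption_params.key)
```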

engine = LLMEngine.from_engine_args(engine_args)
if isinstance(engine.model_executor, DistributedGPUExecutor):
if tensorizer_config._is_sharded:
# if the engine is a distributed engine (for tensor parallel) then each
# worker shard needs to serialize its part of the model.
engine.model_executor._run_workers(