
[Frontend] [Core] Support for sharded tensorized models (vllm-project#4990)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: Sanger Steel <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
3 people authored and robertgshaw2-redhat committed Jun 16, 2024
1 parent b465102 commit c20bc35
Showing 6 changed files with 263 additions and 109 deletions.
125 changes: 60 additions & 65 deletions examples/tensorize_vllm_model.py
@@ -3,18 +3,12 @@
import json
import os
import uuid
from functools import partial

from tensorizer import stream_io

from vllm import LLM
from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
serialize_vllm_model)
tensorize_vllm_model)

# yapf conflicts with isort for this docstring
# yapf: disable
@@ -61,6 +55,12 @@
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
To support distributed tensor-parallel models, each model shard will be
serialized to a separate file. The tensorizer_uri is then specified as a string
template with a format specifier such as '%03d' that will be rendered with the
shard's rank. Sharded models serialized with this script will be named as
model-rank-%03d.tensors
For more information on the available arguments for serializing, run
`python -m examples.tensorize_vllm_model serialize --help`.
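For illustration only (not part of the committed script): with the default naming described above, a serialization run with tensor_parallel_size=2 renders the '%03d' specifier once per shard rank. A minimal sketch, using a placeholder bucket and suffix:

# Sketch of the per-rank file naming described in the docstring above.
# The bucket and suffix are placeholders, not values from this commit.
base_path = "s3://my-bucket/vllm/facebook/opt-125m/v1"
template = f"{base_path}/model-rank-%03d.tensors"
shard_files = [template % rank for rank in range(2)]  # tensor_parallel_size == 2
print(shard_files)  # [.../model-rank-000.tensors, .../model-rank-001.tensors]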
@@ -168,77 +168,72 @@ def parse_args():
def deserialize():
llm = LLM(model=args.model,
load_format="tensorizer",
tensor_parallel_size=args.tensor_parallel_size,
model_loader_extra_config=tensorizer_config
)
return llm


if __name__ == '__main__':
args = parse_args()

args = parse_args()

s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None))
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
s3_endpoint = (getattr(args, 's3_endpoint', None)
or os.environ.get("S3_ENDPOINT_URL", None))

credentials = {
"s3_access_key_id": s3_access_key_id,
"s3_secret_access_key": s3_secret_access_key,
"s3_endpoint": s3_endpoint
}
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None))
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
s3_endpoint = (getattr(args, 's3_endpoint', None)
or os.environ.get("S3_ENDPOINT_URL", None))

_read_stream, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))
credentials = {
"s3_access_key_id": s3_access_key_id,
"s3_secret_access_key": s3_secret_access_key,
"s3_endpoint": s3_endpoint
}

model_ref = args.model
model_ref = args.model

model_name = model_ref.split("/")[1]
model_name = model_ref.split("/")[1]

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
keyfile = args.keyfile if args.keyfile else None

init_distributed_environment(world_size=1, rank=0, local_rank=0)
initialize_model_parallel()
if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config)
tensorizer_args = \
TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors
else:
tensorizer_args = None

keyfile = args.keyfile if args.keyfile else None
if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}

engine_args = EngineArgs.from_cli_args(
argparse.Namespace(**eng_args_dict)
)

if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config)
tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors
else:
tensorizer_args = None

if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}

engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
if engine_args.tensor_parallel_size > 1:
model_path = f"{base_path}/model-rank-%03d.tensors"
else:
model_path = f"{base_path}/model.tensors"

input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
**credentials)
serialize_vllm_model(engine, tensorizer_config, keyfile)
elif args.command == "deserialize":
if not tensorizer_args:
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
encryption_keyfile = keyfile,
**credentials
)
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")
tensorizer_uri=model_path,
encryption_keyfile=keyfile,
**credentials)

tensorize_vllm_model(engine_args, tensorizer_config)

elif args.command == "deserialize":
if not tensorizer_args:
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
encryption_keyfile = keyfile,
**credentials
)
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")
103 changes: 92 additions & 11 deletions tests/tensorizer_loader/test_tensorizer.py
@@ -1,22 +1,30 @@
import json
import os
import pathlib
import subprocess
from unittest.mock import MagicMock, patch

import openai
import pytest
import ray
import torch
from tensorizer import EncryptionParams

from tests.nm_utils.utils_skip import should_skip_test_group
from tests.utils import ServerRunner
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer,
is_vllm_tensorized,
load_with_tensorizer,
open_stream,
serialize_vllm_model)
serialize_vllm_model,
tensorize_vllm_model)

from ..conftest import VllmRunner, cleanup
from ..utils import ServerRunner

# yapf conflicts with isort for this docstring

@@ -46,6 +54,20 @@ def is_curl_installed():
except (subprocess.CalledProcessError, FileNotFoundError):
return False

def get_torch_model(vllm_runner: VllmRunner):
return vllm_runner \
.model \
.llm_engine \
.model_executor \
.driver_worker \
.model_runner \
.model

def write_keyfile(keyfile_path: str):
encryption_params = EncryptionParams.random()
pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True)
with open(keyfile_path, 'wb') as f:
f.write(encryption_params.key)

@pytest.fixture(autouse=True)
def tensorizer_config():
@@ -98,12 +120,17 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
write_keyfile(key_path)

outputs = vllm_model.generate(prompts, sampling_params)

config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
serialize_vllm_model(vllm_model.model.llm_engine,
config_for_serializing,
encryption_key_path=key_path)
config_for_serializing = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path
)
serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing)


config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path)
@@ -155,7 +182,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")

serialize_vllm_model(vllm_model.model.llm_engine,
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))

with vllm_runner(
@@ -190,7 +217,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")

serialize_vllm_model(vllm_model.model.llm_engine,
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))

model_loader_extra_config = {
@@ -234,9 +261,9 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))


@pytest.mark.skip("Failing in Automation due to "
"'NameError: name 'ncclGetVersion' is not defined'")
def test_tensorizer_with_tp(vllm_runner):
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
with pytest.raises(ValueError):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
@@ -250,8 +277,62 @@ def test_tensorizer_with_tp(vllm_runner):
s3_endpoint="object.ord1.coreweave.com",
),
tensor_parallel_size=2,
disable_custom_all_reduce=True,
)

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tmp_path):
model_ref = "EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model
base_model = vllm_runner(
model_ref,
disable_custom_all_reduce=True,
enforce_eager=True,
)
outputs = base_model.generate(prompts, sampling_params)

base_model.model.llm_engine.model_executor.shutdown()
del base_model
cleanup()
ray.shutdown()

# load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
key_path = tmp_path / (model_ref + ".key")

tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path,
)

tensorize_vllm_model(
engine_args=EngineArgs(
model=model_ref,
tensor_parallel_size=2,
disable_custom_all_reduce=True,
enforce_eager=True,
),
tensorizer_config=tensorizer_config,
)
assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
cleanup()
ray.shutdown()

loaded_vllm_model = vllm_runner(
model_ref,
tensor_parallel_size=2,
load_format="tensorizer",
disable_custom_all_reduce=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config)

deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

assert outputs == deserialized_outputs


def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
model_ref = "facebook/opt-125m"
Expand All @@ -260,7 +341,7 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):

with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params)
serialize_vllm_model(vllm_model.model.llm_engine, config)
serialize_vllm_model(get_torch_model(vllm_model), config)

assert is_vllm_tensorized(config)

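The tensor-parallel test above also exercises the loading side: the same TensorizerConfig, with its templated URI and key file, is passed back through model_loader_extra_config with load_format="tensorizer". A trimmed sketch of that round trip, assuming the shards and key were produced by a prior serialization run at these placeholder paths:

from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

# Placeholder paths; the per-rank .tensors files and the encryption key
# must already exist at these locations.
tensorizer_config = TensorizerConfig(
    tensorizer_uri="/tmp/EleutherAI/pythia-1.4b-%02d.tensors",
    encryption_keyfile="/tmp/EleutherAI/pythia-1.4b.key",
)

llm = LLM(
    model="EleutherAI/pythia-1.4b",
    load_format="tensorizer",
    tensor_parallel_size=2,
    model_loader_extra_config=tensorizer_config,
)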
18 changes: 17 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -24,7 +24,7 @@
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
tensorizer_weights_iterator)
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
# UPSTREAM SYNC: needed for sparsity
@@ -413,6 +413,12 @@ def load_model(self, *, model_config: ModelConfig,
cache_config: CacheConfig) -> nn.Module:
self._verify_config(model_config, parallel_config)

if parallel_config.tensor_parallel_size > 1:
from vllm.distributed import get_tensor_model_parallel_rank
self.tensorizer_config.tensorizer_uri = \
self.tensorizer_config.tensorizer_uri \
% get_tensor_model_parallel_rank()

if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config,
@@ -423,6 +429,16 @@ def load_model(self, *, model_config: ModelConfig,
vision_language_config,
cache_config)

@staticmethod
def save_model(
model: torch.nn.Module,
tensorizer_config: TensorizerConfig,
) -> None:
serialize_vllm_model(
model=model,
tensorizer_config=tensorizer_config,
)


class ShardedStateLoader(BaseModelLoader):
"""
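The substitution added to load_model above is plain printf-style string formatting: each worker renders the shared URI template with its own tensor-parallel rank before opening its stream. A standalone sketch of that branch (in the loader the rank comes from get_tensor_model_parallel_rank(); here it is a parameter):

def resolve_shard_uri(tensorizer_uri: str, tensor_parallel_size: int, rank: int) -> str:
    # Mirrors the branch added to load_model: templated URIs are rendered
    # per rank; a single-GPU load keeps the URI untouched.
    if tensor_parallel_size > 1:
        return tensorizer_uri % rank
    return tensorizer_uri

print(resolve_shard_uri("model-rank-%03d.tensors", 2, 1))  # model-rank-001.tensors
print(resolve_shard_uri("model.tensors", 1, 0))            # model.tensors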