New weight loader without np copy (#52)
zhuohan123 authored May 3, 2023
1 parent 4858f3b commit 27f1410
Showing 12 changed files with 289 additions and 357 deletions.
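
This commit drops the intermediate numpy copy of model weights: instead of converting the downloaded HuggingFace *.bin shards into per-tensor .npy files under --model-path and reading them back with np.load, the models' load_weights() methods now stream tensors directly from the HuggingFace checkpoint via hf_model_weights_iterator. The --model-path flag is replaced by --cache-dir, and the numpy copy survives only as an opt-in --use-np-cache fast path. A minimal before/after sketch of the call at the model level (the surrounding worker code is not part of the visible diff, so the variable and model names here are only illustrative):

    # Before: convert the checkpoint to per-tensor .npy files, then load them.
    weights_path = model.get_weights("EleutherAI/gpt-neox-20b",
                                     path="~/.cacheflow/model_weights")
    model.load_weights(weights_path)

    # After: stream tensors straight from the HuggingFace checkpoint.
    model.load_weights("EleutherAI/gpt-neox-20b",
                       cache_dir=None,         # default HuggingFace cache
                       use_np_cache=False)     # no numpy copy on disk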
46 changes: 4 additions & 42 deletions benchmark/benchmark_latency.py
@@ -6,53 +6,15 @@
import numpy as np
import torch

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
process_server_arguments,
initialize_cluster)
from cacheflow.master.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
server, frontend = init_local_server_and_frontend_with_arguments(args)

(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))

# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
)

# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
sampling_params_dict = {
'n': args.n,
'temperature': 0.0 if args.use_beam_search else 1.0,
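The setup code that both benchmarks used to duplicate (cluster initialization, Server construction, SimpleFrontend construction) now lives in init_local_server_and_frontend_with_arguments, defined later in this commit in cacheflow/master/server.py. A minimal sketch of how a script wires things up after this change, using only the functions shown in this diff:

    import argparse

    from cacheflow.master.server import (
        add_server_arguments, process_server_arguments,
        init_local_server_and_frontend_with_arguments)

    parser = add_server_arguments(argparse.ArgumentParser())
    args = process_server_arguments(parser.parse_args())
    server, frontend = init_local_server_and_frontend_with_arguments(args)
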
47 changes: 4 additions & 43 deletions benchmark/benchmark_text_completion.py
@@ -9,57 +9,18 @@
from transformers import AutoConfig

from benchmark.trace import generate_text_completion_requests
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
process_server_arguments,
initialize_cluster)
from cacheflow.master.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


logger = logging.getLogger(__name__)


def main(args: argparse.Namespace):
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
server, frontend = init_local_server_and_frontend_with_arguments(args)

(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))

# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
collect_stats=True,
do_memory_analysis=args.do_memory_analysis,
)

# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
# Generate requests.
requests = generate_text_completion_requests(
args.dataset,
17 changes: 10 additions & 7 deletions cacheflow/http_frontend/fastapi_frontend.py
@@ -1,7 +1,7 @@
import argparse
import asyncio
import time
from typing import List, Dict
from typing import List, Dict, Optional
import json

import ray
@@ -22,11 +22,12 @@
app = FastAPI()


class FastAPIFrontend:
class FastAPIServer:
def __init__(
self,
model: str,
model_path: str,
cache_dir: Optional[str],
use_np_cache: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
Expand All @@ -52,8 +53,9 @@ def __init__(
remote_server_class = ray.remote(num_gpus=1)(Server)
self.server = remote_server_class.remote(
model=model,
model_path=model_path,
cache_dir=cache_dir,
use_dummy_weights=False,
use_np_cache=use_np_cache,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
@@ -148,7 +150,7 @@ async def generate(self, request_dict: Dict):
@app.post("/generate")
async def generate_stream(request: Request):
request_dict = await request.json()
return StreamingResponse(frontend.generate(request_dict))
return StreamingResponse(server.generate(request_dict))


if __name__ == "__main__":
@@ -170,9 +172,10 @@ async def generate_stream(request: Request):
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))

frontend = FastAPIFrontend(
server = FastAPIServer(
model=args.model,
model_path=args.model_path,
cache_dir=args.cache_dir,
use_np_cache=args.use_np_cache,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
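With the rename from FastAPIFrontend to FastAPIServer, the module-level /generate handler streams whatever server.generate() yields back to the client. A hedged client-side example, not part of the diff: the host, port, and payload fields are assumptions, since the request schema sits outside the visible hunks.

    import requests

    # Stream a completion from the /generate endpoint.
    response = requests.post(
        "http://localhost:8000/generate",              # assumed host/port
        json={"prompt": "The capital of France is"},   # payload fields are assumptions
        stream=True,
    )
    for chunk in response.iter_content(chunk_size=None):
        if chunk:
            print(chunk.decode("utf-8", errors="replace"), end="")
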
65 changes: 58 additions & 7 deletions cacheflow/master/server.py
@@ -9,18 +9,21 @@
ray = None

from cacheflow.master.scheduler import Scheduler
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.models import get_memory_analyzer
from cacheflow.worker.controller import Controller, DeviceID
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


class Server:
def __init__(
self,
model: str,
model_path: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
@@ -78,8 +81,9 @@ def __init__(
num_cpu_blocks=self.num_cpu_blocks,
dtype=dtype,
seed=seed,
model_path=model_path,
cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
use_ray=use_ray,
)
@@ -203,25 +207,72 @@ def initialize_cluster(
def add_server_arguments(parser: argparse.ArgumentParser):
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
help='model path to download and load the weights')
parser.add_argument('--cache-dir', type=str, default=None,
help='cache dir to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument('--use-np-cache', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
return parser


def process_server_arguments(args: argparse.Namespace):
if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
args.use_ray = True
return args


def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')

(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))

# Create a server.
server = Server(
model=args.model,
cache_dir=args.cache_dir,
use_dummy_weights=args.use_dummy_weights,
use_np_cache=args.use_np_cache,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
)

# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
return server, frontend
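
For reference, the weight-loading options as they now appear on the command line: --cache-dir points at (or defaults to) the HuggingFace cache, --use-np-cache opts back in to the numpy copy for faster repeat loads, and --model-path is gone. A small sketch of parsing them with the helpers above; the model name and cache path are only examples:

    import argparse

    from cacheflow.master.server import add_server_arguments, process_server_arguments

    parser = add_server_arguments(argparse.ArgumentParser())
    args = process_server_arguments(parser.parse_args([
        '--model', 'EleutherAI/gpt-neox-20b',
        '--cache-dir', '/data/hf-cache',   # omit to use the default HF cache
        '--use-np-cache',                  # keep a numpy copy for faster reloads
        '--tensor-parallel-size', '2',
    ]))
    assert args.use_ray  # set by process_server_arguments since tp * pp > 1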
74 changes: 15 additions & 59 deletions cacheflow/models/gpt_neox.py
@@ -1,18 +1,14 @@
"""1D GPT-NeoX model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from huggingface_hub import snapshot_download

from cacheflow.models import InputMetadata
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@@ -196,17 +192,22 @@ def forward(
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

def load_weights(self, weights_path: str):
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
param = state_dict[name]
if "query_key_value" in name:
# NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
# [num_heads * 3 * head_size, num_heads * head_size], while the
# required shape is [3 * num_heads * head_size, num_heads * head_size].
# Thus, we need weight conversion.
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
@@ -223,55 +224,10 @@ def load_weights(self, weights_path: str):
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1).contiguous()
else:
assert False
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break

assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)

@staticmethod
def get_weights(model_name: str, path: str):
path = os.path.join(path, f"{model_name}-np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)

with lock:
test_weight_path = os.path.join(
path, "gpt_neox.embed_in.weight")
if os.path.exists(test_weight_path):
return path

folder = snapshot_download(model_name, allow_patterns="*.bin",
cache_dir=os.path.join(path, "cache"))
bin_files = glob.glob(os.path.join(folder, "*.bin"))

for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())

return path
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights)

def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
(Diffs for the remaining 7 changed files are not shown in this view.)
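
The helpers that the new load_weights relies on, hf_model_weights_iterator and load_tensor_parallel_weights from cacheflow/models/utils.py, are among the files collapsed above. A minimal sketch of what an iterator over (name, tensor) pairs could look like when use_np_cache is off; this is an assumption for illustration, not the code from the commit. load_tensor_parallel_weights then applies the same column/row shard slicing that was previously inlined in each model's load_weights:

    import glob
    import os
    from typing import Iterator, Optional, Tuple

    import torch
    from huggingface_hub import snapshot_download

    def hf_weights_iterator_sketch(
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        """Yield (name, tensor) pairs straight from the HF checkpoint shards."""
        if os.path.isdir(model_name_or_path):
            folder = model_name_or_path            # already a local checkpoint
        else:
            folder = snapshot_download(model_name_or_path,
                                       allow_patterns="*.bin",
                                       cache_dir=cache_dir)
        for bin_file in sorted(glob.glob(os.path.join(folder, "*.bin"))):
            # No .npy conversion step: tensors are loaded shard by shard.
            state = torch.load(bin_file, map_location="cpu")
            for name, tensor in state.items():
                yield name, tensor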
