Use dtype from model config & Add Dolly V2 #63

Merged: 3 commits, May 4, 2023
cacheflow/master/server.py: 5 additions & 1 deletion
@@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# NOTE(woosuk): FlashAttention does not support float32.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'))
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
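
The help text above describes how the new "default" choice is resolved; the actual implementation is the _get_dtype helper added to cacheflow/models/model_utils.py below. As a quick illustration only, here is a minimal standalone sketch of that resolution rule (the function name resolve_dtype is hypothetical, and the dtype-mismatch check from _get_dtype is omitted):

import torch

def resolve_dtype(config_dtype: torch.dtype, dtype: str) -> torch.dtype:
    # Hypothetical helper mirroring the rule in the --dtype help text:
    # 'default' keeps the checkpoint dtype, except FP32 checkpoints run in FP16.
    if dtype == 'default':
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    return {'half': torch.float16, 'bfloat16': torch.bfloat16}[dtype]

# FP32 and FP16 checkpoints resolve to FP16; BF16 checkpoints stay BF16.
assert resolve_dtype(torch.float32, 'default') == torch.float16
assert resolve_dtype(torch.float16, 'default') == torch.float16
assert resolve_dtype(torch.bfloat16, 'default') == torch.bfloat16
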
cacheflow/models/model_utils.py: 28 additions & 6 deletions
@@ -1,8 +1,9 @@
from typing import Union, Optional
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoConfig
from transformers import PretrainedConfig

from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
@@ -22,6 +23,7 @@
'opt': OPTForCausalLM,
'stablelm': GPTNeoXForCausalLM,
'pythia': GPTNeoXForCausalLM,
'dolly-v2': GPTNeoXForCausalLM,
}

_MEMORY_ANALYZERS = {
@@ -30,19 +32,38 @@
'opt': OPTMemoryAnalyzer,
'stablelm': GPTNeoXMemoryAnalyzer,
'pythia': GPTNeoXMemoryAnalyzer,
'dolly-v2': GPTNeoXMemoryAnalyzer,
}


def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
if dtype == 'default':
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
torch_dtype = get_torch_dtype(dtype)
if torch_dtype != config_dtype and config_dtype != torch.float32:
# TODO(woosuk): Allow using float16 for bfloat16 models and
# vice versa. Print a warning message and continue.
raise ValueError(
f'Cannot use {torch_dtype} for {config_dtype} model.')
return torch_dtype


def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
dtype: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
) -> nn.Module:
torch_dtype = get_torch_dtype(dtype)
torch.set_default_dtype(torch_dtype)
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
torch.set_default_dtype(torch_dtype)
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
if use_dummy_weights:
@@ -66,12 +87,13 @@ def get_model(
def get_memory_analyzer(
model_name: str,
block_size: int,
dtype: Union[torch.dtype, str],
dtype: str,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
torch_dtype = get_torch_dtype(dtype)
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
if model_class in model_name:
return memory_analyzer(
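
For context, a hedged usage sketch of the updated interface: dtype is now a plain string ('default', 'half', or 'bfloat16') and is resolved against the torch_dtype recorded in the model's Hugging Face config; get_memory_analyzer takes the same string. The model name and keyword values below are illustrative assumptions, not taken from this PR:

from cacheflow.models.model_utils import get_model

# Hypothetical call; 'databricks/dolly-v2-3b' and the argument values are
# assumptions for illustration. With dtype='default', a BF16 checkpoint is
# loaded in BF16, while FP32/FP16 checkpoints are loaded in FP16.
model = get_model(
    model_name='databricks/dolly-v2-3b',
    dtype='default',
    cache_dir=None,
    use_dummy_weights=False,
    use_np_cache=False,
)
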