
Commit
Add model config for known models
GregoryComer committed Apr 18, 2024
1 parent b7f7512 commit da23171
Showing 9 changed files with 151 additions and 63 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/pull.yml
@@ -351,6 +351,11 @@ jobs:
           cat ./output_eager2
           echo "Tests complete."
+      - name: Test download
+        run: |
+          python torchchat.py generate stories15M

   test-tinystories-eager:
     strategy:
       matrix:
38 changes: 15 additions & 23 deletions build/builder.py
@@ -18,7 +18,7 @@

 from sentencepiece import SentencePieceProcessor

-from build.model import model_aliases, Transformer
+from build.model import resolve_model_config, Transformer


@@ -67,12 +67,14 @@ def __post_init__(self):

     @classmethod
     def from_args(cls, args):  # -> BuilderArgs:
-        model = resolve_model_name(args.model) if args.model else None
-        checkpoint_path = (
-            Path(args.model_directory) / model / "model.pth"
-            if model and not args.checkpoint_path
-            else args.checkpoint_path
-        )
+        checkpoint_path = args.checkpoint_path
+
+        if args.model:  # Using a named, well-known model
+            model_config, model_name = resolve_model_config(args.model)
+
+            checkpoint_path = (
+                Path(args.model_directory) / model_name / model_config.checkpoint_file
+            )

         is_chat_model = False
         if args.is_chat_model:
@@ -130,13 +132,12 @@ class TokenizerArgs:
     def from_args(cls, args):  # -> TokenizerArgs:
         is_SentencePiece = True
         is_TikToken = False
+        checkpoint_dir = args.checkpoint_dir

-        model = resolve_model_name(args.model) if args.model else None
-        checkpoint_dir = (
-            Path(args.model_directory) / model
-            if not args.checkpoint_dir and args.model
-            else args.checkpoint_dir
-        )
+        if args.model:  # Using a named, well-known model
+            _, model_name = resolve_model_config(args.model)
+
+            checkpoint_dir = Path(args.model_directory) / model_name

         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
@@ -202,6 +203,7 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
     if is_et:
         builder_args.gguf_kwargs["load_as_quantized"] = False

+
 def _unset_gguf_kwargs(builder_args):
     builder_args.gguf_kwargs = None

@@ -224,8 +226,6 @@ def _load_model_default(builder_args):
         model = Transformer.from_params(builder_args.params_path)
     elif builder_args.params_table:
         model = Transformer.from_table(builder_args.params_path)
-    elif builder_args.checkpoint_dir:
-        model = Transformer.from_name(builder_args.checkpoint_dir.name)
     else:
         model = Transformer.from_name(builder_args.checkpoint_path.parent.name)

@@ -354,11 +354,3 @@ def _initialize_model(
     model.to(dtype=builder_args.precision)

     return model
-
-
-def resolve_model_name(model: str) -> str:
-    # If the provided model name is an alias, retrieve the full path.
-    if model in model_aliases:
-        return model_aliases[model]
-    else:
-        return model
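
Note: with this change, a named model short-circuits checkpoint resolution in BuilderArgs.from_args. A minimal sketch of the new lookup, assuming a models directory supplied via --model-directory (the directory value below is illustrative, not part of this diff):

    from pathlib import Path

    from build.model import resolve_model_config

    models_dir = Path(".model-artifacts")  # hypothetical --model-directory value
    model_config, model_name = resolve_model_config("llama2")  # alias -> full name

    # Mirrors the new from_args logic: a named model overrides checkpoint_path.
    checkpoint_path = models_dir / model_name / model_config.checkpoint_file
    # -> .model-artifacts/meta-llama/Llama-2-7b-chat-hf/model.pth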
3 changes: 2 additions & 1 deletion build/convert_hf_checkpoint.py
@@ -99,8 +99,9 @@ def permute(w, n_heads):
             del final_result[key]
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
-    print(f"Saving checkpoint to {model_dir / 'model.pth'}")
+    print(f"Saving checkpoint to {model_dir / 'model.pth'}...")
     torch.save(final_result, model_dir / "model.pth")
+    print("Done.")

     if remove_bin_files:
         for file in bin_files:
6 changes: 5 additions & 1 deletion build/gguf_loader.py
@@ -138,7 +138,11 @@ def load_model(gguf_file: str) -> torch.nn.Module:


 def load_model_and_state_dict(
-    gguf_file: str, load_as_quantized: bool, *, inner_k_tiles=8
+    gguf_file: str,
+    *,
+    load_state_dict: bool = True,
+    load_as_quantized: bool = True,
+    inner_k_tiles=8
 ) -> torch.nn.Module:
     """
     Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
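
Call sites should note that all flags on load_model_and_state_dict are now keyword-only. A sketch of both styles under the new signature (the file path is a placeholder; the empty-state_dict behavior for load_state_dict=False is inferred from the from_gguf change below, not verified here):

    from build.gguf_loader import load_model_and_state_dict

    # Defaults now load a quantized state_dict.
    model, state_dict = load_model_and_state_dict("path/to/model.gguf")

    # Presumably returns an empty state_dict when loading is skipped.
    model, state_dict = load_model_and_state_dict(
        "path/to/model.gguf", load_state_dict=False
    )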
83 changes: 71 additions & 12 deletions build/model.py
@@ -4,8 +4,10 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import json
-from dataclasses import dataclass
-from typing import Dict, Optional
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Dict, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
@@ -21,6 +23,11 @@ def find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)


+class ModelDistributionChannel(Enum):
+    HuggingFaceSnapshot = 1
+    DirectDownload = 2
+
+
 @dataclass
 class ModelArgs:
     block_size: int = 2048
@@ -94,14 +101,67 @@ def from_name(cls, name: str):
         return cls(**transformer_configs[config[0]])


-# Aliases for well-known models. Maps a short name to a HuggingFace path. These
-# can be used from the CLI in-place of the full model path.
-model_aliases = {
-    "llama2": "meta-llama/Llama-2-7b-chat-hf",
-    "llama2-7": "meta-llama/Llama-2-7b-chat-hf",
-    "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.2",
+# Model configuration for known models.
+@dataclass
+class ModelConfig:
+    aliases: Sequence[str] = field(default_factory=list)
+    distribution_path: Union[str, Sequence[str]] = field(default="")
+    distribution_channel: ModelDistributionChannel = field(
+        default=ModelDistributionChannel.HuggingFaceSnapshot
+    )
+    checkpoint_file: str = field(default="model.pth")
+
+    @classmethod
+    def from_name(cls, name: str):
+        if name in model_configs:
+            return cls(**model_configs[name])
+
+        raise ValueError(f"Unknown model {name}.")
+
+
+model_configs = {
+    "meta-llama/Llama-2-7b-chat-hf": {
+        "aliases": ["llama2", "llama2-7b"],
+        "distribution_channel": ModelDistributionChannel.HuggingFaceSnapshot,
+        "distribution_path": "meta-llama/Llama-2-7b-chat-hf",
+    },
+    "mistralai/Mistral-7B-Instruct-v0.2": {
+        "aliases": ["mistral-7b-instruct"],
+        "distribution_channel": ModelDistributionChannel.HuggingFaceSnapshot,
+        "distribution_path": "mistralai/Mistral-7B-Instruct-v0.2",
+    },
+    "stories15M": {
+        "distribution_channel": ModelDistributionChannel.DirectDownload,
+        "distribution_path": [
+            "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt",
+            "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model",
+        ],
+        "checkpoint_file": "stories15M.pt",
+    },
+    "stories110M": {
+        "distribution_channel": ModelDistributionChannel.DirectDownload,
+        "distribution_path": [
+            "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt",
+            "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model",
+        ],
+        "checkpoint_file": "stories110M.pt",
+    },
 }

+model_aliases: Dict[str, str] = {}
+for name, config in model_configs.items():
+    if "aliases" in config:
+        for alias in config["aliases"]:
+            model_aliases[alias] = name
+
+
+def resolve_model_config(model: str) -> Tuple[ModelConfig, str]:
+    if model in model_aliases:
+        model = model_aliases[model]
+
+    return ModelConfig.from_name(model), model
+
+
 transformer_configs = {
     "CodeLlama-7b-Python-hf": {
         "block_size": 16384,
@@ -256,11 +316,10 @@ def from_params(cls, params_path: str):
     @classmethod
     def from_gguf(cls, gguf_path: str, **kwargs):
         from build.gguf_loader import load_model_and_state_dict
+        model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)

-        model, state_dict = load_model_and_state_dict(
-            gguf_path, load_as_quantized=True, inner_k_tiles=8
-        )
-        model.load_state_dict(state_dict, assign=True)
+        if state_dict != {}:
+            model.load_state_dict(state_dict, assign=True)
         return model
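
Based on the definitions above, resolve_model_config accepts either a full name or an alias and returns the config together with the canonical model name. A sketch of the expected behavior, read off this diff rather than captured from a test run:

    from build.model import ModelDistributionChannel, resolve_model_config

    config, name = resolve_model_config("llama2")  # alias
    # name == "meta-llama/Llama-2-7b-chat-hf"
    # config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot
    # config.checkpoint_file == "model.pth" (the default)

    config, name = resolve_model_config("stories15M")  # full name, no alias
    # config.distribution_channel == ModelDistributionChannel.DirectDownload
    # config.checkpoint_file == "stories15M.pt"

    resolve_model_config("no-such-model")  # raises ValueError("Unknown model ...")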
1 change: 1 addition & 0 deletions cli.py
@@ -223,6 +223,7 @@ def _add_arguments_common(parser):
         help="The directory to store downloaded model artifacts",
     )

+
 def arg_init(args):

     if Path(args.quantize).is_file():
62 changes: 44 additions & 18 deletions download.py
@@ -4,26 +4,24 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+import urllib.request
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Sequence

 from build.convert_hf_checkpoint import convert_hf_checkpoint
-from build.model import model_aliases
+from build.model import ModelConfig, ModelDistributionChannel, resolve_model_config

 from requests.exceptions import HTTPError


-def download_and_convert(
-    model: str, models_dir: Path, hf_token: Optional[str] = None
-) -> None:
-    from huggingface_hub import snapshot_download
-
-    if model in model_aliases:
-        model = model_aliases[model]
-
+def _download_and_convert_hf_snapshot(
+    model: str, models_dir: Path, hf_token: Optional[str]
+):
     model_dir = models_dir / model
     os.makedirs(model_dir, exist_ok=True)

+    from huggingface_hub import snapshot_download
+
     # Download and store the HF model artifacts.
     print(f"Downloading {model} from HuggingFace...")
     try:
@@ -44,18 +42,46 @@

     # Convert the model to the torchchat format.
     print(f"Converting {model} to torchchat format...")
-    convert_hf_checkpoint(
-        model_dir=model_dir, model_name=Path(model), remove_bin_files=True
-    )
+    convert_hf_checkpoint(model_dir=model_dir, model_name=model, remove_bin_files=True)


-def is_model_downloaded(model: str, models_dir: Path) -> bool:
-    if model in model_aliases:
-        model = model_aliases[model]
+def _download_direct(
+    model: str,
+    urls: Sequence[str],
+    models_dir: Path,
+):
+    model_dir = models_dir / model
+    os.makedirs(model_dir, exist_ok=True)
+
+    for url in urls:
+        filename = url.split("/")[-1]
+        local_path = model_dir / filename
+        print(f"Downloading {url}...")
+        urllib.request.urlretrieve(url, str(local_path.absolute()))
+
+
+def download_and_convert(
+    model: str, models_dir: Path, hf_token: Optional[str] = None
+) -> None:
+    model_config, model_name = resolve_model_config(model)
+
+    if (
+        model_config.distribution_channel
+        == ModelDistributionChannel.HuggingFaceSnapshot
+    ):
+        _download_and_convert_hf_snapshot(model_name, models_dir, hf_token)
+    elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
+        _download_direct(model_name, model_config.distribution_path, models_dir)
+    else:
+        raise RuntimeError(
+            f"Unknown distribution channel {model_config.distribution_channel}."
+        )
+
+
+def is_model_downloaded(model: str, models_dir: Path) -> bool:
+    _, model_name = resolve_model_config(model)

-    model_dir = models_dir / model
+    # TODO Can we be more thorough here?
+    model_dir = models_dir / model_name
     return os.path.isdir(model_dir)
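
The new download_and_convert dispatches on the distribution channel, so a direct-download model such as stories15M fetches its .pt checkpoint and tokenizer straight into the models directory with no HuggingFace conversion step. A usage sketch (the models_dir value is illustrative, not a documented default):

    from pathlib import Path

    from download import download_and_convert, is_model_downloaded

    models_dir = Path(".model-artifacts")  # hypothetical
    if not is_model_downloaded("stories15M", models_dir):
        download_and_convert("stories15M", models_dir)
    # Fetches stories15M.pt and tokenizer.model into models_dir / "stories15M".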
14 changes: 7 additions & 7 deletions generate.py
@@ -110,12 +110,12 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):


 def prefill(
-        model: Transformer,
-        x: torch.Tensor,
-        input_pos: torch.Tensor,
-        *,
-        sequential_prefill = True,
-        **sampling_kwargs
+    model: Transformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    *,
+    sequential_prefill=True,
+    **sampling_kwargs,
 ) -> torch.Tensor:
     logging.debug(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
@@ -349,7 +349,7 @@ def _main(
     is_speculative = speculative_builder_args.checkpoint_path is not None

     if generator_args.chat_mode and not builder_args.is_chat_model:
-        # This is not a log message, it's a dangerous condition message
+        # This is not a log message, it's a dangerous condition message
         # that we must ensure is displayed
         print(
             """
2 changes: 1 addition & 1 deletion torchchat.py
@@ -65,5 +65,5 @@
         export_main(args)
     else:
         raise RuntimeError(
-            "Must specify a valid subcommand: download, chat, generate, export, or eval."
+            "Must specify a valid subcommand: download, generate, export, or eval."
         )
