[Slim-LM] Introduce HFLoad for loading Pytorch and SafeTensor weights #1113

Merged · 2 commits · Oct 23, 2023
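For context, a minimal usage sketch of the loader introduced in this PR (not part of the PR itself; the weight path and the `ExternMapping` are assumed to be supplied by the caller): construct `HFLoader` with a `.bin`, `.safetensors`, or `*.index.json` file plus an `ExternMapping`, then iterate `load()` to receive `(MLC parameter name, NDArray)` pairs.

```python
# Hypothetical usage sketch, not part of this PR. Assumes `extern_param_map`
# is an ExternMapping built elsewhere for the target model.
from pathlib import Path

from mlc_chat.compiler.parameter import HFLoader

def dump_parameters(weight_path: Path, extern_param_map) -> None:
    """Print every MLC parameter name and shape produced by the loader."""
    loader = HFLoader(path=weight_path, extern_param_map=extern_param_map)
    for mlc_name, nd_array in loader.load():  # yields (str, tvm.runtime.NDArray)
        print(mlc_name, nd_array.shape)
```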
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/parameter/__init__.py
@@ -2,5 +2,5 @@
A subpackage of the compiler that represents mapping between external parameters, quantized
parameters and parameters in MLC-defined models.
"""
from .hf_torch_loader import HFTorchLoader
from .hf_loader import HFLoader
from .mapping import ExternMapping, QuantizeMapping
python/mlc_chat/compiler/parameter/{hf_torch_loader.py → hf_loader.py}
@@ -1,96 +1,39 @@
"""A weight loader for HuggingFace's PyTorch format"""
import dataclasses

import gc
import json
import logging
import time
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterator, List, Set, Tuple
from typing import Dict, Iterator, List, Tuple

import numpy as np
from tqdm import tqdm
from tvm.runtime import NDArray
from tvm.runtime.ndarray import array as as_ndarray

from .mapping import ExternMapping
from .stats import Stats
from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class Stats:
"""Statistics of the loading process of HuggingFace PyTorch loader.

Attributes
----------
load_time_sec : float
Time used in loading the parameters.

map_time_sec : float
Time used in applying the mapping function, i.e. `ExternMapping.map_func`.

quant_time_sec : float
Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`.

current_memory_gb : float
The current RAM usage in GB.

total_memory_gb : float
The total size data loaded from disk in GB.

max_memory_gb : float
The maximum RAM usage in GB.
"""

load_time_sec: float = 0.0
map_time_sec: float = 0.0
quant_time_sec: float = 0.0

current_memory_gb: float = 0.0
total_memory_gb: float = 0.0
max_memory_gb: float = 0.0

def timer(self, attr):
"""A context manager to time the scope and add the time to the attribute."""

@contextmanager
def timed_scope():
start_time = time.time()
yield
elapsed_time = time.time() - start_time
setattr(self, attr, getattr(self, attr) + elapsed_time)

return timed_scope()

def mem_add(self, nbytes: int):
"""Add the memory usage by the given number of bytes."""
mem_gb = float(nbytes) / float(1024**3)
self.current_memory_gb += mem_gb
self.total_memory_gb += mem_gb
self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb)

def mem_rm(self, nbytes: int):
"""Remove the memory usage by the given number of bytes."""
mem_gb = float(nbytes) / float(1024**3)
self.current_memory_gb -= mem_gb


class HFTorchLoader: # pylint: disable=too-few-public-methods
"""A loader loading HuggingFace's PyTorch format and converts them to MLC's parameters.
class HFLoader: # pylint: disable=too-few-public-methods
"""A loader loading HuggingFace's PyTorch/SafeTensor format and converts them
to MLC's parameters.

Attributes
----------
stats : Stats
Statistics of the loading process.

extern_param_map : ExternMapping
The parameter mapping from MLC to HuggingFace PyTorch.
The parameter mapping from MLC to HuggingFace PyTorch/SafeTensor.

torch_to_path : Dict[str, Path]
A mapping from PyTorch parameter name to the path of the file containing it, or the path
meaning all parameters are stored in a single file.
A mapping from PyTorch/SafeTensor parameter name to the path of the file containing it,
or the path meaning all parameters are stored in a single file.

cached_files : Dict[Path, Dict[str, np.ndarray]]
A cache of the loaded files. The key is the path of the file, and the value is a mapping
@@ -113,20 +56,23 @@ def __init__(
----------
path : pathlib.Path
Path to either a JSON indexing file, or a PyTorch bin file.
1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` in the repo,
which contains a `weight_map` that maps each PyTorch parameter to the file containing
the weight. 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo,
1) For JSON indexing file, it is usually `pytorch_model.bin.index.json`
or `model.safetensors.index.json` in the repo, which contains a `weight_map` that
maps each PyTorch parameter to the file containing the weight.
2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo,
which contains all the parameters.
3) For safetensor file, it is usually `model.safetensors` in the repo,
which contains all the parameters.

extern_param_map : ExternMapping
Maps an MLC parameter to a list of PyTorch parameters.
Maps an MLC parameter to a list of PyTorch/SafeTensor parameters.
"""
assert path.is_file()
self.stats = Stats()
self.extern_param_map = extern_param_map
self.cached_files = {}
self.torch_to_path = {}
if path.suffix == ".bin":
if path.suffix in (".bin", ".safetensors"):
self._load_file(path)
for name in self.cached_files[path].keys():
self.torch_to_path[name] = path
@@ -137,7 +83,7 @@ def __init__(
self.torch_to_path[torch_name] = path.parent / path_str
else:
raise FileNotFoundError(f"Unknown file suffix: {path}")
_check_parameter_usage(extern_param_map, set(self.torch_to_path.keys()))
check_parameter_usage(extern_param_map, set(self.torch_to_path.keys()))

def load(self) -> Iterator[Tuple[str, NDArray]]:
"""Load the parameters and yield the MLC parameter and its value."""
@@ -148,21 +94,8 @@ def load(self) -> Iterator[Tuple[str, NDArray]]:
cached_files = list(self.cached_files.keys())
for path in cached_files:
self._unload_file(path)

logger.info(
"Time used: "
"PyTorch loading: %.3f sec; "
"Pre-quantization mapping: %.3f sec; "
"Quantization: %.3f sec",
self.stats.load_time_sec,
self.stats.map_time_sec,
self.stats.quant_time_sec,
)
logger.info(
"Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB",
self.stats.total_memory_gb,
self.stats.max_memory_gb,
)
self.stats.log_time_info("HF")
self.stats.log_mem_usage()

def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
torch_names = self.extern_param_map.param_map[mlc_name]
@@ -190,53 +123,24 @@ def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
return param

def _load_file(self, path: Path) -> None:
logger.info("Loading PyTorch parameters from: %s", path)
logger.info("Loading HF parameters from: %s", path)
load_func = load_safetensor_shard if path.suffix == ".safetensors" else load_torch_shard
with self.stats.timer("load_time_sec"):
result = {}
for name, param in _load_torch_shard(path):
for name, param in load_func(path):
result[name] = param
self.stats.mem_add(param.nbytes)
self.cached_files[path] = result

def _unload_file(self, path: Path) -> None:
logger.info("Unloading PyTorch weight file: %s", path)
logger.info("Unloading HF weight file: %s", path)
with self.stats.timer("load_time_sec"):
for _, param in self.cached_files[path].items():
self.stats.mem_rm(param.nbytes)
del self.cached_files[path]
gc.collect()


def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]):
used_torch_names = set(sum(param_map.param_map.values(), ()))
# Check 1. All PyTorch parameters in the weight files are used unless explicitly specified
unused_torch_names = torch_weights - used_torch_names - param_map.unused_params
if unused_torch_names:
logger.warning(
"Unused torch parameters: %s",
", ".join(sorted(unused_torch_names)),
)
# Check 2. All PyTorch parameters required are stored in the weight files
nonexistent_torch_names = used_torch_names - torch_weights
if nonexistent_torch_names:
raise ValueError(
"The following torch parameters do not exist in the weight files:\n "
+ "\n ".join(sorted(nonexistent_torch_names)),
)


def _load_torch_shard(path: Path):
import torch # pylint: disable=import-outside-toplevel

for name, param in torch.load(path, map_location=torch.device("cpu")).items():
param = param.detach().cpu()
dtype = str(param.dtype)
if dtype == "torch.bfloat16":
param = param.float()
param = param.numpy()
yield name, param


def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> List[str]:
# Step 1. Build a map from path to torch parameters
path_to_torch: Dict[Path, List[str]] = defaultdict(list)
@@ -257,4 +161,4 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) ->
return list(order.keys())


__all__ = ["HFTorchLoader"]
__all__ = ["HFLoader"]
86 changes: 86 additions & 0 deletions python/mlc_chat/compiler/parameter/stats.py
@@ -0,0 +1,86 @@
"""Statistics of the loading process of parameter loaders"""
import dataclasses
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class Stats:
"""Statistics of the loading process of parameter loaders.

Attributes
----------
load_time_sec : float
Time used in loading the parameters.

map_time_sec : float
Time used in applying the mapping function, i.e. `ExternMapping.map_func`.

quant_time_sec : float
Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`.

current_memory_gb : float
The current RAM usage in GB.

total_memory_gb : float
The total size data loaded from disk in GB.

max_memory_gb : float
The maximum RAM usage in GB.
"""

load_time_sec: float = 0.0
map_time_sec: float = 0.0
quant_time_sec: float = 0.0

current_memory_gb: float = 0.0
total_memory_gb: float = 0.0
max_memory_gb: float = 0.0

def timer(self, attr):
"""A context manager to time the scope and add the time to the attribute."""

@contextmanager
def timed_scope():
start_time = time.time()
yield
elapsed_time = time.time() - start_time
setattr(self, attr, getattr(self, attr) + elapsed_time)

return timed_scope()

def mem_add(self, nbytes: int):
"""Add the memory usage by the given number of bytes."""
mem_gb = float(nbytes) / float(1024**3)
self.current_memory_gb += mem_gb
self.total_memory_gb += mem_gb
self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb)

def mem_rm(self, nbytes: int):
"""Remove the memory usage by the given number of bytes."""
mem_gb = float(nbytes) / float(1024**3)
self.current_memory_gb -= mem_gb

def log_time_info(self, weight_format: str):
"""Log the time used in loading, pre-quantization and quantization."""
logger.info(
"Time used: "
"%s loading: %.3f sec; "
"Pre-quantization mapping: %.3f sec; "
"Quantization: %.3f sec",
weight_format,
self.load_time_sec,
self.map_time_sec,
self.quant_time_sec,
)

def log_mem_usage(self):
"""Log the Memory usage information."""
logger.info(
"Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB",
self.total_memory_gb,
self.max_memory_gb,
)
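A short sketch of how a loader could drive the `Stats` helpers defined above (illustrative only, not taken from the PR; the direct module import is an assumption, since `Stats` is not re-exported from the package `__init__`).

```python
# Illustrative only: record timing and memory with the Stats helpers above.
import numpy as np

from mlc_chat.compiler.parameter.stats import Stats

stats = Stats()
with stats.timer("load_time_sec"):        # elapsed time is added to load_time_sec
    param = np.zeros((1024, 1024), dtype="float32")
    stats.mem_add(param.nbytes)           # count the shard towards current/total/max memory
stats.mem_rm(param.nbytes)                # drop it again once the shard is unloaded
stats.log_time_info("HF")                 # logs "Time used: HF loading: ... sec; ..."
stats.log_mem_usage()
```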
52 changes: 52 additions & 0 deletions python/mlc_chat/compiler/parameter/utils.py
@@ -0,0 +1,52 @@
"""Common utilities for loading parameters"""
import logging
from pathlib import Path
from typing import Iterator, Set, Tuple

import numpy as np

from .mapping import ExternMapping

logger = logging.getLogger(__name__)


def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]):
"""Check that all external parameters have been used and are stored in the weights file."""
used_extern_names = set(sum(param_map.param_map.values(), []))
# Check 1. All extern parameters in the weight files are used unless explicitly specified
unused_extern_names = extern_weights - used_extern_names - param_map.unused_params
if unused_extern_names:
logger.warning(
"Unused extern parameters: %s",
", ".join(sorted(unused_extern_names)),
)
# Check 2. All extern parameters required are stored in the weight files
nonexistent_extern_names = used_extern_names - extern_weights
if nonexistent_extern_names:
raise ValueError(
"The following extern parameters do not exist in the weight files:\n "
+ "\n ".join(sorted(nonexistent_extern_names)),
)


def load_torch_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]:
"""Load and yield PyTorch format parameters."""
import torch # pylint: disable=import-outside-toplevel

for name, param in torch.load(path, map_location=torch.device("cpu")).items():
param = param.detach().cpu()
dtype = str(param.dtype)
if dtype == "torch.bfloat16":
param = param.float()
param = param.numpy()
yield name, param


def load_safetensor_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]:
"""Load and yield SafeTensor format parameters."""
import safetensors # pylint: disable=import-outside-toplevel,import-error

with safetensors.safe_open(path, framework="numpy", device="cpu") as in_file:
for name in in_file.keys():
param = in_file.get_tensor(name)
yield name, param
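Taken together, a caller could dispatch on the file suffix the same way `HFLoader._load_file` does above. A minimal sketch, assuming both helper functions are importable from the new `utils` module as shown:

```python
# Sketch only: pick the right shard reader by suffix, as HFLoader._load_file does.
from pathlib import Path

from mlc_chat.compiler.parameter.utils import load_safetensor_shard, load_torch_shard

def iter_shard(path: Path):
    load_func = load_safetensor_shard if path.suffix == ".safetensors" else load_torch_shard
    for name, param in load_func(path):  # each item is (str, np.ndarray)
        yield name, param
```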