Introduce HFLoader for loading PyTorch and SafeTensor weights
LeshengJin committed Oct 23, 2023
1 parent e5927ce commit a8b0241
Showing 5 changed files with 225 additions and 165 deletions.
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/parameter/__init__.py
@@ -2,5 +2,5 @@
A subpackage of the compiler that represents mapping between external parameters, quantized
parameters and parameters in MLC-defined models.
"""
from .hf_torch_loader import HFTorchLoader
from .hf_loader import HFLoader
from .mapping import ExternMapping, QuantizeMapping
python/mlc_chat/compiler/parameter/{hf_torch_loader.py → hf_loader.py}
@@ -1,96 +1,43 @@
"""A weight loader for HuggingFace's PyTorch format"""
import dataclasses

import gc
import json
import logging
import time
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterator, List, Set, Tuple
from typing import Dict, Iterator, List, Tuple

import numpy as np
from tqdm import tqdm
from tvm.runtime import NDArray
from tvm.runtime.ndarray import array as as_ndarray

from .mapping import ExternMapping
from .utils import (
Stats,
_check_parameter_usage,
_load_safetensor_shard,
_load_torch_shard,
)

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class Stats:
"""Statistics of the loading process of HuggingFace PyTorch loader.
Attributes
----------
load_time_sec : float
Time used in loading the parameters.
map_time_sec : float
Time used in applying the mapping function, i.e. `ExternMapping.map_func`.
quant_time_sec : float
Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`.
current_memory_gb : float
The current RAM usage in GB.
total_memory_gb : float
The total size data loaded from disk in GB.
max_memory_gb : float
The maximum RAM usage in GB.
"""

load_time_sec: float = 0.0
map_time_sec: float = 0.0
quant_time_sec: float = 0.0

current_memory_gb: float = 0.0
total_memory_gb: float = 0.0
max_memory_gb: float = 0.0

def timer(self, attr):
"""A context manager to time the scope and add the time to the attribute."""

@contextmanager
def timed_scope():
start_time = time.time()
yield
elapsed_time = time.time() - start_time
setattr(self, attr, getattr(self, attr) + elapsed_time)

return timed_scope()

def mem_add(self, nbytes: int):
"""Add the memory usage by the given number of bytes."""
mem_gb = float(nbytes) / float(1024**3)
self.current_memory_gb += mem_gb
self.total_memory_gb += mem_gb
self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb)

def mem_rm(self, nbytes: int):
"""Remove the memory usage by the given number of bytes."""
mem_gb = float(nbytes) / float(1024**3)
self.current_memory_gb -= mem_gb


class HFTorchLoader: # pylint: disable=too-few-public-methods
"""A loader loading HuggingFace's PyTorch format and converts them to MLC's parameters.
class HFLoader: # pylint: disable=too-few-public-methods
"""A loader loading HuggingFace's PyTorch/SafeTensor format and converts them
to MLC's parameters.
Attributes
----------
stats : Stats
Statistics of the loading process.
extern_param_map : ExternMapping
The parameter mapping from MLC to HuggingFace PyTorch.
The parameter mapping from MLC to HuggingFace PyTorch/SafeTensor.
torch_to_path : Dict[str, Path]
A mapping from PyTorch parameter name to the path of the file containing it, or the path
meaning all parameters are stored in a single file.
A mapping from PyTorch/SafeTensor parameter name to the path of the file containing it;
when all parameters are stored in a single file, every name maps to that file's path.
cached_files : Dict[Path, Dict[str, np.ndarray]]
A cache of the loaded files. The key is the path of the file, and the value is a mapping
@@ -113,20 +60,23 @@ def __init__(
----------
path : pathlib.Path
Path to either a JSON indexing file, or a PyTorch bin file.
1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` in the repo,
which contains a `weight_map` that maps each PyTorch parameter to the file containing
the weight. 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo,
1) For JSON indexing file, it is usually `pytorch_model.bin.index.json`
or `model.safetensors.index.json` in the repo, which contains a `weight_map` that
maps each PyTorch parameter to the file containing the weight.
2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo,
which contains all the parameters.
3) For safetensor file, it is usually `model.safetensors` in the repo,
which contains all the parameters.
extern_param_map : ExternMapping
Maps an MLC parameter to a list of PyTorch parameters.
Maps an MLC parameter to a list of PyTorch/SafeTensor parameters.
"""
assert path.is_file()
self.stats = Stats()
self.extern_param_map = extern_param_map
self.cached_files = {}
self.torch_to_path = {}
if path.suffix == ".bin":
if path.suffix == ".bin" or path.suffix == ".safetensors":
self._load_file(path)
for name in self.cached_files[path].keys():
self.torch_to_path[name] = path
@@ -148,21 +98,8 @@ def load(self) -> Iterator[Tuple[str, NDArray]]:
cached_files = list(self.cached_files.keys())
for path in cached_files:
self._unload_file(path)

logger.info(
"Time used: "
"PyTorch loading: %.3f sec; "
"Pre-quantization mapping: %.3f sec; "
"Quantization: %.3f sec",
self.stats.load_time_sec,
self.stats.map_time_sec,
self.stats.quant_time_sec,
)
logger.info(
"Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB",
self.stats.total_memory_gb,
self.stats.max_memory_gb,
)
self.stats.log_time_info("HF")
self.stats.log_mem_usage()

def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
torch_names = self.extern_param_map.param_map[mlc_name]
@@ -190,53 +127,24 @@ def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
return param

def _load_file(self, path: Path) -> None:
logger.info("Loading PyTorch parameters from: %s", path)
logger.info("Loading HF parameters from: %s", path)
load_func = _load_safetensor_shard if path.suffix == ".safetensors" else _load_torch_shard
with self.stats.timer("load_time_sec"):
result = {}
for name, param in _load_torch_shard(path):
for name, param in load_func(path):
result[name] = param
self.stats.mem_add(param.nbytes)
self.cached_files[path] = result

def _unload_file(self, path: Path) -> None:
logger.info("Unloading PyTorch weight file: %s", path)
logger.info("Unloading HF weight file: %s", path)
with self.stats.timer("load_time_sec"):
for _, param in self.cached_files[path].items():
self.stats.mem_rm(param.nbytes)
del self.cached_files[path]
gc.collect()


def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]):
used_torch_names = set(sum(param_map.param_map.values(), ()))
# Check 1. All PyTorch parameters in the weight files are used unless explicitly specified
unused_torch_names = torch_weights - used_torch_names - param_map.unused_params
if unused_torch_names:
logger.warning(
"Unused torch parameters: %s",
", ".join(sorted(unused_torch_names)),
)
# Check 2. All PyTorch parameters required are stored in the weight files
nonexistent_torch_names = used_torch_names - torch_weights
if nonexistent_torch_names:
raise ValueError(
"The following torch parameters do not exist in the weight files:\n "
+ "\n ".join(sorted(nonexistent_torch_names)),
)


def _load_torch_shard(path: Path):
import torch # pylint: disable=import-outside-toplevel

for name, param in torch.load(path, map_location=torch.device("cpu")).items():
param = param.detach().cpu()
dtype = str(param.dtype)
if dtype == "torch.bfloat16":
param = param.float()
param = param.numpy()
yield name, param


def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> List[str]:
# Step 1. Build a map from path to torch parameters
path_to_torch: Dict[Path, List[str]] = defaultdict(list)
@@ -257,4 +165,4 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) ->
return list(order.keys())


__all__ = ["HFTorchLoader"]
__all__ = ["HFLoader"]
129 changes: 129 additions & 0 deletions python/mlc_chat/compiler/parameter/utils.py
@@ -0,0 +1,129 @@
"""A weight loader for HuggingFace's PyTorch format"""
import dataclasses
import logging
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Set

from .mapping import ExternMapping

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class Stats:
"""Statistics of the loading process of HuggingFace PyTorch loader.
Attributes
----------
load_time_sec : float
Time used in loading the parameters.
map_time_sec : float
Time used in applying the mapping function, i.e. `ExternMapping.map_func`.
quant_time_sec : float
Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`.
current_memory_gb : float
The current RAM usage in GB.
total_memory_gb : float
The total size data loaded from disk in GB.
max_memory_gb : float
The maximum RAM usage in GB.
"""

load_time_sec: float = 0.0
map_time_sec: float = 0.0
quant_time_sec: float = 0.0

current_memory_gb: float = 0.0
total_memory_gb: float = 0.0
max_memory_gb: float = 0.0

def timer(self, attr):
"""A context manager to time the scope and add the time to the attribute."""

@contextmanager
def timed_scope():
start_time = time.time()
yield
elapsed_time = time.time() - start_time
setattr(self, attr, getattr(self, attr) + elapsed_time)

return timed_scope()

def mem_add(self, nbytes: int):
"""Add the memory usage by the given number of bytes."""
mem_gb = float(nbytes) / float(1024**3)
self.current_memory_gb += mem_gb
self.total_memory_gb += mem_gb
self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb)

def mem_rm(self, nbytes: int):
"""Remove the memory usage by the given number of bytes."""
mem_gb = float(nbytes) / float(1024**3)
self.current_memory_gb -= mem_gb

def log_time_info(self, weight_format: str):
"""Log the time used in loading, pre-quantization and quantization."""
logger.info(
"Time used: "
"%s loading: %.3f sec; "
"Pre-quantization mapping: %.3f sec; "
"Quantization: %.3f sec",
weight_format,
self.load_time_sec,
self.map_time_sec,
self.quant_time_sec,
)

def log_mem_usage(self):
"""Log the Memory usage information."""
logger.info(
"Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB",
self.total_memory_gb,
self.max_memory_gb,
)
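
A quick illustration of how the Stats helpers compose, as a sketch only (the shard size is made up and logging must be configured for the messages to appear):

# Illustrative sketch of the Stats helpers above; the byte count is made up.
stats = Stats()
with stats.timer("load_time_sec"):
    payload = bytes(4 * 1024**2)   # pretend a 4 MB shard was read from disk
    stats.mem_add(len(payload))    # current/total/peak memory all tick up
stats.mem_rm(len(payload))         # shard released; current memory ticks down
stats.log_time_info("HF")          # "Time used: HF loading: ... sec; ..."
stats.log_mem_usage()              # "Memory usage: Total size ...; Peak ..."
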


def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]):
used_torch_names = set(sum(param_map.param_map.values(), []))
# Check 1. All PyTorch parameters in the weight files are used unless explicitly specified
unused_torch_names = torch_weights - used_torch_names - param_map.unused_params
if unused_torch_names:
logger.warning(
"Unused torch parameters: %s",
", ".join(sorted(unused_torch_names)),
)
# Check 2. All PyTorch parameters required are stored in the weight files
nonexistent_torch_names = used_torch_names - torch_weights
if nonexistent_torch_names:
raise ValueError(
"The following torch parameters do not exist in the weight files:\n "
+ "\n ".join(sorted(nonexistent_torch_names)),
)
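
To make the two checks concrete, a toy sketch (a SimpleNamespace stands in for ExternMapping so its constructor is not assumed; all parameter names are made up):

# Toy illustration of _check_parameter_usage; names are made up and a simple
# stand-in replaces ExternMapping so no constructor is assumed.
from types import SimpleNamespace

toy_map = SimpleNamespace(
    param_map={"mlc.w": ["hf.w"]},  # one MLC param backed by one HF param
    unused_params=set(),
)

# "hf.extra" exists on disk but is never mapped -> Check 1 logs a warning.
_check_parameter_usage(toy_map, {"hf.w", "hf.extra"})

# "hf.w" is required but missing from the files -> Check 2 raises ValueError.
_check_parameter_usage(toy_map, {"hf.extra"})
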


def _load_torch_shard(path: Path):
import torch # pylint: disable=import-outside-toplevel

for name, param in torch.load(path, map_location=torch.device("cpu")).items():
param = param.detach().cpu()
dtype = str(param.dtype)
if dtype == "torch.bfloat16":
param = param.float()
param = param.numpy()
yield name, param


def _load_safetensor_shard(path: Path):
import safetensors # pylint: disable=import-outside-toplevel

with safetensors.safe_open(path, framework="numpy", device="cpu") as in_file:
for name in in_file.keys():
param = in_file.get_tensor(name)
yield name, param
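
And a small driver sketch showing how the two shard readers are typically selected and consumed (the file names are placeholders; torch and safetensors must be installed):

# Illustrative driver for the shard readers above; file names are placeholders.
from pathlib import Path

for shard in [Path("pytorch_model-00001-of-00002.bin"), Path("model.safetensors")]:
    read_shard = _load_safetensor_shard if shard.suffix == ".safetensors" else _load_torch_shard
    for name, param in read_shard(shard):
        # Both readers yield (name, numpy array) pairs; bfloat16 torch tensors
        # are upcast to float32 before conversion to numpy.
        print(name, param.shape, param.dtype)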