From 858071268723e77e5d32e6320553207f498f3471 Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Mon, 29 Apr 2024 22:35:11 -0700
Subject: [PATCH] add unpacking support (#525)

* add unpacking support

* fix typos and linter
---
 build/utils.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 53 insertions(+), 1 deletion(-)

diff --git a/build/utils.py b/build/utils.py
index 299c0bd1c6..f04e33ed94 100644
--- a/build/utils.py
+++ b/build/utils.py
@@ -9,9 +9,61 @@
 import logging
 import os
 from pathlib import Path
-from typing import Dict, List
+
+##########################################################################
+### unpack packed weights ###
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
+
+
+def unpack_packed_weights(
+    packed_weights: Dict[str, Any],
+    packed_linear: Callable,
+    input_dtype: torch.dtype,
+    unpacked_dims: Tuple,
+) -> torch.Tensor:
+    """Given a packed weight matrix `packed_weights`, a Callable
+    implementing a packed linear function for the packed format, and the
+    unpacked dimensions of the weights, recreate the unpacked weight
+    matrix. In addition to the packed weights, as a dictionary to specify
+    whatever arguments the packed routine expects, we also need the input
+    data type because packing may depend on input dtype, or only some
+    input dtypes may be supported. We also need the dimensions of the
+    unpacked matrix. At present, this does not handle padding, but that will
+    be straightforward to add. Similarly, the same approach can be used
+    for both linear and mm operators.
+
+    Args:
+        packed_weights: keyword arguments forwarded to `packed_linear`.
+        packed_linear: packed linear op, called as f(x, **packed_weights).
+        input_dtype: dtype of the identity matrix fed to `packed_linear`.
+        unpacked_dims: 2-tuple (rows, cols); only cols is read here.
+
+    Example usage:
+        packed_weights = {
+            "weight" : weight_int4pack,
+            "qGroupSize": groupsize,
+            "scales_and_zeros": scales_and_zeros
+        }
+        unpacked_weights = unpack_packed_weights(
+            packed_weights,
+            _weight_int4pack_linear,
+            torch.bfloat16,
+            (256, 1024),
+        )
+
+
+    """
+    assert len(unpacked_dims) == 2, "unpacked_dims must be a tuple of length 2"
+    cols = unpacked_dims[1]
+
+    unpacked_weights = packed_linear(
+        torch.eye(cols, dtype=input_dtype), **packed_weights
+    ).transpose(0, 1)
+    return unpacked_weights
 
 
 ##########################################################################