Skip to content

Commit

Permalink
add unpacking support (pytorch#525)
Browse files Browse the repository at this point in the history
* add unpacking support

* fix typos and linter
  • Loading branch information
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 6732127 commit 8580712
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,61 @@
import logging
import os
from pathlib import Path
from typing import Dict, List

##########################################################################
### unpack packed weights ###

from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F


def unpack_packed_weights(
packed_weights: Dict[str, Any],
packed_linear: Callable,
input_dtype: torch.dtype,
unpacked_dims: Tuple,
) -> torch.Tensor:
"""Given a packed weight matrix `packed_weights`, a Callable
implementing a packed linear function for the packed format, and the
unpacked dimensions of the weights, recreate the unpacked weight
matrix. In addition to the packed weights, as a dictionary to specify
whatever arguments the packed routine expects, we also need the input
data type because packing may depend on input dtype, or only some
input dtypes may be supported. We also need the dimensions of the
unpacked matrix. At present, this does not handle padding, but that will
be straightforward to add. Similarly, the same approach can be used
for both linear and mm operators.
Args:
packed_weights: Dict[str, Any],
packed_linear: Callable,
input_dtype: torch.dtype,
unpacked_dims: Optional[Tuple]=None
Example usage:
packed_weights = {
"weight" : weight_int4pack,
"qGroupSize": groupsize,
"scales_and_zeros": scales_and_zeros
}
unpacked_weights = unpack_packed_weights(
_weight_int4pack_linear,
packed_weights,
torch.bfloat6,
(256, 1024),
)
"""
assert len(unpacked_dims) == 2, "unpacked_dims must be a tuple of length 2"
cols = unpacked_dims[1]

unpacked_weights = packed_linear(
torch.eye(cols, dtype=input_dtype), **packed_weights
).transpose(0, 1)
return unpacked_weights


##########################################################################
Expand Down

0 comments on commit 8580712

Please sign in to comment.