From 9eb83fe65c35788ec1c4de67a1b493e3dc0d1f49 Mon Sep 17 00:00:00 2001
From: alexm-nm <59768536+alexm-nm@users.noreply.github.com>
Date: Fri, 23 Feb 2024 15:28:20 -0500
Subject: [PATCH] Add empty tensor initialization to LazyCompressedParameter
 (#53)

---
 .../layers/parameters/lazy_compressed.py     | 33 +++++++++++++++----
 .../sparsity/sparse_w16a16_linear_method.py  |  6 +++-
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/parameters/lazy_compressed.py b/vllm/model_executor/layers/parameters/lazy_compressed.py
index a22f718197e10..52a2597955fbc 100644
--- a/vllm/model_executor/layers/parameters/lazy_compressed.py
+++ b/vllm/model_executor/layers/parameters/lazy_compressed.py
@@ -17,6 +17,7 @@ class LazyCompressedParameter(torch.Tensor):
     @staticmethod
     def __new__(cls,
                 uncompressed_data: torch.Tensor,
+                is_empty: bool = False,
                 storage_format_cls: Type[
                     CompressedStorageFormat] = SparseBitmaskStorageFormat,
                 compress_transposed: bool = False):
@@ -30,12 +31,16 @@ def __new__(cls,
             cls,
             size=uncompressed_data.shape,
             dtype=uncompressed_data.dtype,
+            device=uncompressed_data.device,
             requires_grad=False)
+        self._is_param = True
+
         self.storage_format_cls = storage_format_cls
-        self.compressed_data = None
-        self.uncompressed_data = uncompressed_data
         self.compress_transposed = compress_transposed
-        self._is_param = True
+        self.compressed_data = None
+
+        self.is_empty = is_empty
+        self.uncompressed_data = None if self.is_empty else uncompressed_data
 
         return self
 
@@ -45,7 +50,10 @@ def has_compressed_data(self) -> bool:
 
     @property
     def has_uncompressed_data(self) -> bool:
-        return (self.uncompressed_data is not None)
+        if self.is_empty:
+            raise ValueError(
+                "has_uncompressed_data() was called with empty data")
+        return self.uncompressed_data is not None
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
@@ -56,8 +64,16 @@ def unwrap(e):
             if isinstance(e, LazyCompressedParameter):
                 assert ret_storage_format_cls is None or ret_storage_format_cls == e.storage_format_cls
                 ret_storage_format_cls = e.storage_format_cls
-            return e.uncompressed_data if isinstance(
-                e, LazyCompressedParameter) else e
+
+                if e.is_empty:
+                    e.is_empty = False
+                    e.uncompressed_data = torch.empty(size=e.size(),
+                                                      dtype=e.dtype,
+                                                      device=e.device)
+
+                return e.uncompressed_data
+            else:
+                return e
 
         rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
 
@@ -65,7 +81,10 @@ def wrap(e):
             if isinstance(e,
                           torch.Tensor) and ret_storage_format_cls is not None:
                 return LazyCompressedParameter(
-                    e, storage_format_cls=ret_storage_format_cls)
+                    e,
+                    # Here, "e" is the output of "func", so it is real data we store
+                    is_empty=False,
+                    storage_format_cls=ret_storage_format_cls)
             return e
 
         rs = tree_map(wrap, rs)
diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
index 7a3b8d30beabd..11ac8390205d9 100644
--- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
+++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
@@ -35,8 +35,12 @@ def create_weights(self, input_size_per_partition: int,
         weight = LazyCompressedParameter(
             torch.empty((output_size_per_partition, input_size_per_partition),
                         dtype=params_dtype),
+            # For create_weights(..), we initialize an empty tensor to
+            # save GPU memory. When the parameter is loaded from disk,
+            # it will be copied into this tensor.
+            is_empty=True,
             storage_format_cls=self.storage_format_cls,
-            # if we don't support F.linear or something analogous,
+            # If we don't support F.linear or something analogous,
             # transpose when we compress so we can use a basic matmul
             compress_transposed=not supports_linear)
 
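
Usage sketch (not part of the patch): a minimal illustration of the lifecycle
this change enables. LazyCompressedParameter, is_empty, and the module path
come from the diff above; the shapes and the copy_()-based "load" are
illustrative assumptions, and running it presumes a vLLM build with this
patch applied.

    import torch

    from vllm.model_executor.layers.parameters.lazy_compressed import (
        LazyCompressedParameter)

    # With is_empty=True, only tensor metadata (shape, dtype, device) is
    # kept; no uncompressed data buffer is retained, so creating the
    # weight does not consume memory for parameter data.
    weight = LazyCompressedParameter(
        torch.empty((256, 512), dtype=torch.float16),
        is_empty=True)
    assert weight.is_empty and weight.uncompressed_data is None

    # The first op dispatched on the parameter (here, copying in a tensor
    # that stands in for checkpoint data) goes through __torch_dispatch__,
    # which materializes an uncompressed buffer via torch.empty() and
    # clears the is_empty flag before the copy runs.
    checkpoint_data = torch.randn(256, 512, dtype=torch.float16)
    weight.copy_(checkpoint_data)
    assert not weight.is_empty and weight.has_uncompressed_data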