From a5a381a6ff9ca56256ce960abd67a7da1d6505b2 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 12 Feb 2025 11:26:08 -0500
Subject: [PATCH 1/2] support dtype

Signed-off-by: Kyle Sayers
---
 src/llmcompressor/pipelines/cache.py        | 21 +++++++----------
 tests/e2e/vLLM/test_vllm.py                 |  1 -
 tests/llmcompressor/pipelines/test_cache.py | 26 +++++++++++++++++++++
 3 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py
index 57c9b1486..090e1450d 100644
--- a/src/llmcompressor/pipelines/cache.py
+++ b/src/llmcompressor/pipelines/cache.py
@@ -1,5 +1,6 @@
 import warnings
 from dataclasses import dataclass, fields, is_dataclass
+from types import NoneType
 from typing import Any, Dict, List, Optional, Union
 
 import torch
@@ -134,21 +135,17 @@ def _onload_value(self, intermediate: IntermediateValue) -> Any:
         if isinstance(value, torch.Tensor):
             return value.to(device=device)
 
-        elif is_dataclass(value):
+        if is_dataclass(value):
             for field in fields(value):  # `asdict` is recursive, not applicable here
                 v = getattr(value, field.name)
                 setattr(value, field.name, self._onload_value(v))
             return value
 
-        elif isinstance(value, tuple):
+        if isinstance(value, tuple):
             return tuple(self._onload_value(v) for v in value)
 
-        elif isinstance(value, (int, str, float, bool)) or value is None:
-            return value
-
-        else:
-            return value
+        return value
 
     def _offload_value(self, value: Any) -> IntermediateValue:
         if isinstance(value, torch.Tensor):
             return IntermediateValue(
@@ -155,8 +152,8 @@ def _offload_value(self, value: Any) -> IntermediateValue:
                 value=value.to(device=self.offload_device), device=value.device
             )
 
-        elif is_dataclass(value):
+        if is_dataclass(value):
             for field in fields(value):  # `asdict` is recursive, not applicable here
                 v = getattr(value, field.name)
                 setattr(value, field.name, self._offload_value(v))
             return value
@@ -168,12 +165,10 @@ def _offload_value(self, value: Any) -> IntermediateValue:
                 value=tuple(self._offload_value(v) for v in value), device=None
             )
 
-        if isinstance(value, (int, str, float, bool)) or value is None:
-            return IntermediateValue(value=value, device=None)
-
-        else:
+        if not isinstance(value, (int, str, float, bool, NoneType, torch.dtype)):
             warnings.warn(f"Offloading not implemented for type {type(value)}.")
-            return IntermediateValue(value=value, device=None)
+
+        return IntermediateValue(value=value, device=None)
 
     @staticmethod
     def _mask_padding(
diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py
index 24c0a060d..6c42f82df 100644
--- a/tests/e2e/vLLM/test_vllm.py
+++ b/tests/e2e/vLLM/test_vllm.py
@@ -130,7 +130,6 @@ def test_vllm(self):
         session.reset()
 
         if SKIP_HF_UPLOAD.lower() != "yes":
-            logger.info("================= UPLOADING TO HUB ======================")
 
             stub = f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e"
 
diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py
index 71a72eb25..eda040d52 100644
--- a/tests/llmcompressor/pipelines/test_cache.py
+++ b/tests/llmcompressor/pipelines/test_cache.py
@@ -28,6 +28,7 @@ def sample_cache(sample_dataloader):
     )
 
 
+@pytest.mark.unit
 def test_initialization(sample_dataloader):
     cache = IntermediatesCache.from_dataloader(
         dataloader=sample_dataloader,
@@ -40,6 +41,7 @@ def test_initialization(sample_dataloader):
     assert isinstance(cache.batch_intermediates[0], dict)
 
 
+@pytest.mark.unit
 def test_fetch_inputs(sample_cache):
     fetched = sample_cache.fetch(0, ["input_ids", "attention_mask"])
 
@@ -50,6 +52,7 @@ def test_fetch_inputs(sample_cache):
     assert isinstance(fetched["attention_mask"], torch.Tensor)
 
 
+@pytest.mark.unit
 def test_update_intermediates(sample_cache):
     new_outputs = {
         "hidden_states": torch.randn(2, 4, 768),
@@ -63,6 +66,7 @@ def test_update_intermediates(sample_cache):
     assert "logits" in sample_cache.batch_intermediates[0]
 
 
+@pytest.mark.unit
 def test_delete_intermediates(sample_cache):
     # First add some intermediates
     new_outputs = {
@@ -78,6 +82,7 @@ def test_delete_intermediates(sample_cache):
     assert "logits" in sample_cache.batch_intermediates[0]
 
 
+@pytest.mark.unit
 def test_mask_padding():
     input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]])
     attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]])
@@ -89,6 +94,7 @@ def test_mask_padding():
     assert torch.equal(masked, expected)
 
 
+@pytest.mark.unit
 def test_offload_and_onload_tensor():
     cache = IntermediatesCache([], torch.device("cpu"))
 
@@ -111,6 +117,7 @@ class SampleDataclass:
     b: int
 
 
+@pytest.mark.unit
 def test_offload_and_onload_dataclass():
     cache = IntermediatesCache([], torch.device("cpu"))
 
@@ -129,6 +136,24 @@ def test_offload_and_onload_dataclass():
     assert onloaded == sample_data
 
 
+@pytest.mark.unit
+def test_offload_and_onload_dtype():
+    cache = IntermediatesCache([], torch.device("cpu"))
+
+    # Create a sample dtype value
+    sample_data = torch.float32
+
+    # Test dtype offloading
+    offloaded = cache._offload_value(sample_data)
+    assert isinstance(offloaded, IntermediateValue)
+    assert isinstance(offloaded.value, torch.dtype)
+
+    # Test dtype onloading
+    onloaded = cache._onload_value(offloaded)
+    assert onloaded == sample_data
+
+
+@pytest.mark.unit
 def test_4d_attention_mask():
     input_ids = torch.tensor([[1, 2, 3, 0]])
     attention_mask = torch.ones(1, 1, 1, 4)  # 4D attention mask
@@ -140,6 +165,7 @@ def test_4d_attention_mask():
     assert torch.equal(masked, expected)
 
 
+@pytest.mark.unit
 def test_device_handling(sample_dataloader):
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")

From 272181cffc7cbe25c1177e105252ee3e07f1c7f3 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 19 Feb 2025 18:07:12 -0500
Subject: [PATCH 2/2] fix import for prev python versions

Signed-off-by: Kyle Sayers
---
 src/llmcompressor/pipelines/cache.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py
index 090e1450d..75a6431d5 100644
--- a/src/llmcompressor/pipelines/cache.py
+++ b/src/llmcompressor/pipelines/cache.py
@@ -1,6 +1,5 @@
 import warnings
 from dataclasses import dataclass, fields, is_dataclass
-from types import NoneType
 from typing import Any, Dict, List, Optional, Union
 
 import torch
@@ -165,7 +164,7 @@ def _offload_value(self, value: Any) -> IntermediateValue:
                 value=tuple(self._offload_value(v) for v in value), device=None
            )
 
-        if not isinstance(value, (int, str, float, bool, NoneType, torch.dtype)):
+        if not isinstance(value, (int, str, float, bool, torch.dtype, type(None))):
             warnings.warn(f"Offloading not implemented for type {type(value)}.")
 
         return IntermediateValue(value=value, device=None)
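
Note on [PATCH 2/2]: `types.NoneType` exists only on Python 3.10 and newer (it was re-added to the `types` module in 3.10), so `from types import NoneType` raises ImportError at module import time on earlier interpreters, while `type(None)` evaluates to the same class object on every version. A minimal sketch of the portable pattern, illustrative only and not part of the patches above:

    # Hypothetical compatibility shim: prefer the Python 3.10+ name,
    # fall back to the universally available spelling otherwise.
    try:
        from types import NoneType  # available starting with Python 3.10
    except ImportError:
        NoneType = type(None)  # works on every Python 3 version

    # Both spellings refer to the same singleton class.
    assert isinstance(None, NoneType)
    assert NoneType is type(None)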