Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offload Cache Support torch.dtype #1141

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions src/llmcompressor/pipelines/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,29 +134,25 @@ def _onload_value(self, intermediate: IntermediateValue) -> Any:
if isinstance(value, torch.Tensor):
return value.to(device=device)

elif is_dataclass(value):
if is_dataclass(value):
for field in fields(value): # `asdict` is recursive, not applicable here
v = getattr(value, field.name)
setattr(value, field.name, self._onload_value(v))

return value

elif isinstance(value, tuple):
if isinstance(value, tuple):
return tuple(self._onload_value(v) for v in value)

elif isinstance(value, (int, str, float, bool)) or value is None:
return value

else:
return value
return value

def _offload_value(self, value: Any) -> IntermediateValue:
if isinstance(value, torch.Tensor):
return IntermediateValue(
value=value.to(device=self.offload_device), device=value.device
)

elif is_dataclass(value):
if is_dataclass(value):
for field in fields(value): # `asdict` is recursive, not applicable here
v = getattr(value, field.name)
setattr(value, field.name, self._offload_value(v))
Expand All @@ -168,12 +164,10 @@ def _offload_value(self, value: Any) -> IntermediateValue:
value=tuple(self._offload_value(v) for v in value), device=None
)

if isinstance(value, (int, str, float, bool)) or value is None:
return IntermediateValue(value=value, device=None)

else:
if not isinstance(value, (int, str, float, bool, torch.dtype, type(None))):
warnings.warn(f"Offloading not implemented for type {type(value)}.")
return IntermediateValue(value=value, device=None)

return IntermediateValue(value=value, device=None)

@staticmethod
def _mask_padding(
Expand Down
1 change: 0 additions & 1 deletion tests/e2e/vLLM/test_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ def test_vllm(self):
session.reset()

if SKIP_HF_UPLOAD.lower() != "yes":

logger.info("================= UPLOADING TO HUB ======================")

stub = f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e"
Expand Down
26 changes: 26 additions & 0 deletions tests/llmcompressor/pipelines/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def sample_cache(sample_dataloader):
)


@pytest.mark.unit
def test_initialization(sample_dataloader):
cache = IntermediatesCache.from_dataloader(
dataloader=sample_dataloader,
Expand All @@ -40,6 +41,7 @@ def test_initialization(sample_dataloader):
assert isinstance(cache.batch_intermediates[0], dict)


@pytest.mark.unit
def test_fetch_inputs(sample_cache):
fetched = sample_cache.fetch(0, ["input_ids", "attention_mask"])

Expand All @@ -50,6 +52,7 @@ def test_fetch_inputs(sample_cache):
assert isinstance(fetched["attention_mask"], torch.Tensor)


@pytest.mark.unit
def test_update_intermediates(sample_cache):
new_outputs = {
"hidden_states": torch.randn(2, 4, 768),
Expand All @@ -63,6 +66,7 @@ def test_update_intermediates(sample_cache):
assert "logits" in sample_cache.batch_intermediates[0]


@pytest.mark.unit
def test_delete_intermediates(sample_cache):
# First add some intermediates
new_outputs = {
Expand All @@ -78,6 +82,7 @@ def test_delete_intermediates(sample_cache):
assert "logits" in sample_cache.batch_intermediates[0]


@pytest.mark.unit
def test_mask_padding():
input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]])
Expand All @@ -89,6 +94,7 @@ def test_mask_padding():
assert torch.equal(masked, expected)


@pytest.mark.unit
def test_offload_and_onload_tensor():
cache = IntermediatesCache([], torch.device("cpu"))

Expand All @@ -111,6 +117,7 @@ class SampleDataclass:
b: int


@pytest.mark.unit
def test_offload_and_onload_dataclass():
cache = IntermediatesCache([], torch.device("cpu"))

Expand All @@ -129,6 +136,24 @@ def test_offload_and_onload_dataclass():
assert onloaded == sample_data


@pytest.mark.unit
def test_offload_and_onload_dtype():
cache = IntermediatesCache([], torch.device("cpu"))

# Create a sample dataclass instance
sample_data = torch.float32

# Test dataclass offloading
offloaded = cache._offload_value(sample_data)
assert isinstance(offloaded, IntermediateValue)
assert isinstance(offloaded.value, torch.dtype)

# Test dataclass onloading
onloaded = cache._onload_value(offloaded)
assert onloaded == sample_data


@pytest.mark.unit
def test_4d_attention_mask():
input_ids = torch.tensor([[1, 2, 3, 0]])
attention_mask = torch.ones(1, 1, 1, 4) # 4D attention mask
Expand All @@ -140,6 +165,7 @@ def test_4d_attention_mask():
assert torch.equal(masked, expected)


@pytest.mark.unit
def test_device_handling(sample_dataloader):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
Expand Down