diff --git a/hivemind/compression/base.py b/hivemind/compression/base.py
index 8e6f273bd..f910aa51a 100644
--- a/hivemind/compression/base.py
+++ b/hivemind/compression/base.py
@@ -1,4 +1,5 @@
 import dataclasses
+import os
 import warnings
 from abc import ABC, abstractmethod
 from enum import Enum, auto
@@ -13,6 +14,7 @@
 # While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
 warnings.filterwarnings("ignore", message="The given NumPy array is not writable", category=UserWarning)
 
+USE_LEGACY_BFLOAT16 = bool(int(os.environ.get("USE_LEGACY_BFLOAT16", 1)))
 
 Key = Any
 
@@ -81,26 +83,39 @@ class NoCompression(CompressionBase):
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
         tensor = tensor.detach()
+        shape = tensor.shape
         dtype_name = str(tensor.dtype).lstrip("torch.")
+        raw_data = tensor
         if tensor.dtype == torch.bfloat16:
-            tensor = tensor.to(torch.float32)
+            if USE_LEGACY_BFLOAT16:
+                raw_data = tensor.to(torch.float32)
+            else:
+                typed_storage = tensor.storage()
+                storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
+                raw_data = torch.tensor(storage, dtype=torch.int8)
 
         return runtime_pb2.Tensor(
             compression=self.compression_type,
-            buffer=tensor.numpy().tobytes(),
-            size=tensor.shape,
+            buffer=raw_data.numpy().tobytes(),
+            size=shape,
             dtype=dtype_name,
             requires_grad=tensor.requires_grad,
         )
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        shape = torch.Size(serialized_tensor.size)
         if serialized_tensor.dtype == "bfloat16":
-            array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
-            tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+            if len(serialized_tensor.buffer) // shape.numel() == 4:  # legacy mode: convert to fp32
+                array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
+                tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+            else:  # efficient mode: send bfloat16 data directly
+                storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
+                storage = storage_type.from_buffer(serialized_tensor.buffer, byte_order="little", dtype=torch.bfloat16)
+                tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
         else:
             array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
             tensor = torch.as_tensor(array)
-        return tensor.reshape(tuple(serialized_tensor.size))
+        return tensor.reshape(shape)
 
     def estimate_compression_ratio(self, info: CompressionInfo) -> float:
         return 1.0
diff --git a/tests/test_compression.py b/tests/test_compression.py
index 6f868c387..9f1c78d86 100644
--- a/tests/test_compression.py
+++ b/tests/test_compression.py
@@ -68,8 +68,10 @@ def test_serialize_tensor():
     _check(torch.tensor(1.0), CompressionType.FLOAT16)
 
 
+@pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
 @pytest.mark.forked
-def test_serialize_bfloat16():
+def test_serialize_bfloat16(use_legacy_bfloat16: bool):
+    hivemind.compression.base.USE_LEGACY_BFLOAT16 = use_legacy_bfloat16
     tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
     _check(tensor, CompressionType.NONE)
     _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
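
For reference, below is a self-contained sketch of the round-trip this diff implements, stripped of hivemind's protobuf plumbing. serialize_bf16 and deserialize_bf16 are hypothetical helpers, not part of hivemind's API; they mirror the compress/extract logic above. Legacy mode ships 4 bytes per element (bfloat16 upcast to float32), the efficient mode ships the raw 2-byte bfloat16 payload, and the receiver distinguishes the two by bytes-per-element, exactly as extract does. The storage calls follow the diff and assume a PyTorch version that exposes TypedStorage and untyped storages.

import numpy as np
import torch

# Note: torch.as_tensor over a read-only NumPy view emits a harmless UserWarning;
# the real module silences it globally (see the filterwarnings call in base.py).

def serialize_bf16(tensor: torch.Tensor, legacy: bool) -> bytes:
    tensor = tensor.detach()
    if legacy:
        # Legacy wire format: upcast to float32 (4 bytes per element).
        return tensor.to(torch.float32).numpy().tobytes()
    # Efficient wire format: reinterpret the bfloat16 storage as raw bytes
    # (2 bytes per element) so .numpy() can export it without conversion.
    typed_storage = tensor.storage()
    storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
    return torch.tensor(storage, dtype=torch.int8).numpy().tobytes()

def deserialize_bf16(buffer: bytes, shape: torch.Size) -> torch.Tensor:
    if len(buffer) // shape.numel() == 4:  # legacy mode: float32 on the wire
        array = np.frombuffer(buffer, dtype=np.float32)
        tensor = torch.as_tensor(array, dtype=torch.bfloat16)
    else:  # efficient mode: raw bfloat16 bytes on the wire
        storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
        storage = storage_type.from_buffer(buffer, byte_order="little", dtype=torch.bfloat16)
        tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
    return tensor.reshape(shape)

x = torch.randn(4096, 16, dtype=torch.bfloat16)
for legacy in (True, False):
    restored = deserialize_bf16(serialize_bf16(x, legacy), x.shape)
    assert torch.equal(restored, x)  # both formats are lossless for bfloat16 inputs

Both paths round-trip exactly, since a bfloat16 value upcast to float32 converts back without loss. The efficient path halves the transfer size, while dispatching on buffer length (rather than a new wire field) keeps tensors serialized by older peers readable.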