Fix bfloat16 serialization for tensors with zero elements (#560)
Follow-up to #553. Serializing a bfloat16 tensor with zero elements previously crashed: the legacy-mode check in extract() divides the buffer length by shape.numel(), which is zero for an empty tensor.
borzunov authored Mar 28, 2023
1 parent 98531ce commit 3164928
Showing 2 changed files with 6 additions and 4 deletions.
3 changes: 2 additions & 1 deletion hivemind/compression/base.py
@@ -105,7 +105,8 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         shape = torch.Size(serialized_tensor.size)
         if serialized_tensor.dtype == "bfloat16":
-            if len(serialized_tensor.buffer) // shape.numel() == 4:  # legacy mode: convert to fp32
+            numel = shape.numel()
+            if numel > 0 and len(serialized_tensor.buffer) // numel == 4:  # legacy mode: convert to fp32
                 array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
                 tensor = torch.as_tensor(array, dtype=torch.bfloat16)
             else:  # efficient mode: send bfloat16 data directly
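Why the guard matters: for a tensor with zero elements, shape.numel() returns 0, so the old legacy-mode heuristic performed an integer division by zero before any data could be decoded. A minimal sketch of the failure and the fix (illustrative only, not part of the commit):

import torch

shape = torch.Size((0, 0))  # zero-element tensor, as in the new test case
buffer = b""                # an empty tensor serializes to an empty buffer

# Old check: crashes because shape.numel() == 0.
# len(buffer) // shape.numel()  # ZeroDivisionError: integer division or modulo by zero

# New check: the numel > 0 guard short-circuits, so empty tensors fall
# through to the efficient branch and are decoded directly as bfloat16.
numel = shape.numel()
is_legacy = numel > 0 and len(buffer) // numel == 4
print(is_legacy)  # False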
7 changes: 4 additions & 3 deletions tests/test_compression.py
@@ -49,7 +49,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
 def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
     serialized_tensor = serialize_torch_tensor(tensor, compression)
     chunks = list(split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    assert len(chunks) == max((len(serialized_tensor.buffer) - 1) // chunk_size + 1, 1)
     restored = combine_from_streaming(chunks)
     result = deserialize_torch_tensor(restored)
     assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
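The max(..., 1) clamp covers the zero-element case: split_for_streaming still yields one (empty) chunk for an empty buffer, while the old formula evaluates to zero because Python floor division rounds (0 - 1) // chunk_size down to -1. A quick check of the arithmetic (illustrative only):

chunk_size = 30 * 1024
buffer_len = 0  # a serialized empty tensor has an empty buffer

old_expected = (buffer_len - 1) // chunk_size + 1          # (-1) // 30720 == -1, so this is 0
new_expected = max((buffer_len - 1) // chunk_size + 1, 1)  # clamped to 1

print(old_expected, new_expected)  # 0 1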
@@ -69,10 +69,11 @@ def test_serialize_tensor():
 
 
 @pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
+@pytest.mark.parametrize("tensor_size", [(4096, 16), (0, 0)])
 @pytest.mark.forked
-def test_serialize_bfloat16(use_legacy_bfloat16: bool):
+def test_serialize_bfloat16(use_legacy_bfloat16: bool, tensor_size: tuple):
     hivemind.compression.base.USE_LEGACY_BFLOAT16 = use_legacy_bfloat16
-    tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
+    tensor = torch.randn(tensor_size, dtype=torch.bfloat16)
     _check(tensor, CompressionType.NONE)
     _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
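To reproduce the fixed behavior outside the test suite, a round-trip sketch, assuming the same imports the test file uses (serialize_torch_tensor/deserialize_torch_tensor from hivemind.compression, CompressionType from hivemind.proto.runtime_pb2):

import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

# Round-trip an empty bfloat16 tensor; before this commit, deserialization
# hit a ZeroDivisionError in the legacy-mode detection.
tensor = torch.randn((0, 0), dtype=torch.bfloat16)
restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.NONE))
assert restored.shape == tensor.shape and restored.dtype == torch.bfloat16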
