Add support for quantization with bitsandbytes #490

Status: merged on Sep 10, 2022 (30 commits)

Commits
fc657d5  Add support for quantization with bitsandbytes (mryab, Jun 27, 2022)
794b966  Extend the compression benchmark (mryab, Jun 27, 2022)
540aa08  Fix formatting and imports (mryab, Jun 27, 2022)
d68a559  Build a cpuonly version of bitsandbytes (mryab, Jul 26, 2022)
6a857f7  Build a cpuonly version of bitsandbytes (mryab, Jul 26, 2022)
8bb5f3b  Build a cpuonly version of bitsandbytes (mryab, Jul 26, 2022)
e730cff  Revert changes (dbaranchuk, Jul 28, 2022)
6b5cf2a  Add a test for blockwise compression (mryab, Aug 22, 2022)
442b932  Replace building bitsandbytes from source with pip installation (mryab, Aug 22, 2022)
0d63242  Add a note to README about bitsandbytes (mryab, Aug 22, 2022)
0f7bb72  Revert to cpuonly build (mryab, Aug 22, 2022)
98538a0  Revert to cpuonly build (mryab, Aug 22, 2022)
ebc73ad  Revert to cpuonly build (mryab, Aug 22, 2022)
44060b5  Revert to cpuonly build (mryab, Aug 22, 2022)
2b82251  Update the docs (mryab, Aug 22, 2022)
2aef651  Install bitsandbytes in tests as well (mryab, Aug 22, 2022)
93c3ea5  Install bitsandbytes in tests as well (mryab, Aug 22, 2022)
300a153  Skip bitsandbytes warnings about cpu-only versions (mryab, Aug 22, 2022)
3a22e4d  Skip bitsandbytes warnings about cpu-only versions (mryab, Aug 22, 2022)
552e1d2  Replace bitsandbytes with a newer pypi version (mryab, Aug 24, 2022)
4bd9ca1  Update README.md (mryab, Aug 24, 2022)
939e6f8  Use hivemind[bitsandbytes] for README (mryab, Aug 24, 2022)
8a7d89e  Use hivemind[bitsandbytes] for README (mryab, Aug 24, 2022)
9735432  Use hivemind[bitsandbytes] for README (mryab, Aug 24, 2022)
f534fdc  Make bitsandbytes error parsing more explicit (mryab, Aug 24, 2022)
d5b9265  Verify outputs consistently in test_moe.py (mryab, Aug 24, 2022)
8120564  Pass device="cpu" in test_background_server_identity_path (mryab, Aug 24, 2022)
48f4d0b  Filter bitsandbytes warnings (mryab, Aug 24, 2022)
4c0fef6  Bump the version of bitsandbytes (mryab, Sep 9, 2022)
f311943  Reduce diff (mryab, Sep 10, 2022)
.github/workflows/run-benchmarks.yml (3 additions, 0 deletions)

@@ -26,6 +26,9 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
         run: |
           pip install .
.github/workflows/run-tests.yml (6 additions, 0 deletions)

@@ -29,6 +29,9 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
         run: |
           pip install .
@@ -88,6 +91,9 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
         run: |
           pip install -e . --no-use-pep517
README.md (4 additions, 0 deletions)

@@ -53,6 +53,10 @@ If your versions of Python and PyTorch match the requirements, you can install hivemind
 pip install hivemind
 ```
 
+Also, if you want to use blockwise 8-bit compression from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+during data transfer, you can install it with `pip install hivemind[bitsandbytes]`.
+After that, you can use the `BlockwiseQuantization` class in [hivemind.compression](./hivemind/compression).
+
 ### From source
 
 To install hivemind from source, simply run the following:
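As a quick orientation for the rest of the diff, here is a minimal round-trip sketch of the new codec (not part of this PR; it assumes `bitsandbytes` is installed and that `serialize_torch_tensor`/`deserialize_torch_tensor` are exported from `hivemind.compression`, as the serialization changes below suggest):

```python
import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

x = torch.randn(128, 128)  # any CPU float32 tensor; shape and values are illustrative
serialized = serialize_torch_tensor(x, CompressionType.BLOCKWISE_8BIT)
restored = deserialize_torch_tensor(serialized)
print((restored - x).square().mean())  # small quantization error, see tests/test_compression.py
```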
benchmarks/benchmark_tensor_compression.py (20 additions, 9 deletions)

@@ -11,26 +11,37 @@
 logger = get_logger(__name__)
 
 
-def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
+def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> Tuple[float, float, int]:
     t = time.time()
-    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
-    return time.time() - t
+    serialized = serialize_torch_tensor(tensor, compression_type)
+    result = deserialize_torch_tensor(serialized)
+    return time.time() - t, (tensor - result).square().mean(), serialized.ByteSize()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--size", type=int, default=10000000, required=False)
+    parser.add_argument("--size", type=int, default=10_000_000, required=False)
     parser.add_argument("--seed", type=int, default=7348, required=False)
     parser.add_argument("--num_iters", type=int, default=30, required=False)
 
     args = parser.parse_args()
 
     torch.manual_seed(args.seed)
-    X = torch.randn(args.size)
+    X = torch.randn(args.size, dtype=torch.float32)
 
     for name, compression_type in CompressionType.items():
-        tm = 0
+        total_time = 0
+        compression_error = 0
+        total_size = 0
         for i in range(args.num_iters):
-            tm += benchmark_compression(X, compression_type)
-        tm /= args.num_iters
-        logger.info(f"Compression type: {name}, time: {tm}")
+            iter_time, iter_distortion, size = benchmark_compression(X, compression_type)
+            total_time += iter_time
+            compression_error += iter_distortion
+            total_size += size
+        total_time /= args.num_iters
+        compression_error /= args.num_iters
+        total_size /= args.num_iters
+        logger.info(
+            f"Compression type: {name}, time: {total_time:.5f}, compression error: {compression_error:.5f}, "
+            f"size: {int(total_size):d}"
+        )
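To put the three returned metrics in context, here is a hypothetical one-off measurement in the same script's namespace (`benchmark_compression`, `torch`, and `CompressionType` as imported above; the factor of 4 assumes the benchmark's float32 input):

```python
# time in seconds, mean squared error, and serialized protobuf size in bytes
time_s, mse, nbytes = benchmark_compression(torch.randn(1_000_000), CompressionType.BLOCKWISE_8BIT)
ratio = 1_000_000 * 4 / nbytes  # raw float32 bytes vs. serialized message size
print(f"time: {time_s:.4f}s, mse: {float(mse):.6f}, compression ratio: ~{ratio:.1f}x")
```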
hivemind/compression/__init__.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.compression.serialization import (
     deserialize_tensor_stream,
     deserialize_torch_tensor,
hivemind/compression/quantization.py (63 additions, 0 deletions)

@@ -1,12 +1,18 @@
+import importlib.util
 import math
 import os
+import warnings
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from typing import Tuple
 
 import numpy as np
 import torch
 
+if importlib.util.find_spec("bitsandbytes") is not None:
+    warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
+    from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
 
@@ -112,3 +118,60 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
     for job in jobs:
         job.result()
     return np.quantile(partition_quantiles, quantiles)
+
+
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
+Please install it with `pip install bitsandbytes`
+or using the instruction from https://github.com/TimDettmers/bitsandbytes."""
+
+
+class BlockwiseQuantization(Quantization):
+    compression_type = runtime_pb2.BLOCKWISE_8BIT
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    def quantize(
+        self, tensor: torch.Tensor, allow_inplace: bool = False
+    ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        try:
+            quantized, (absmax, codebook) = quantize_blockwise(tensor)
+        except NameError:
+            raise ImportError(BNB_MISSING_MESSAGE)
+        return quantized.numpy(), (absmax.numpy(), codebook.numpy())
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+
+        serialized_data = (
+            np.int64(len(absmax)).tobytes(),
+            np.int64(len(codebook)).tobytes(),
+            absmax.tobytes(),
+            codebook.tobytes(),
+            quantized.tobytes(),
+        )
+
+        return runtime_pb2.Tensor(
+            buffer=b"".join(serialized_data),
+            size=tensor.shape,
+            requires_grad=tensor.requires_grad,
+            dtype=tensor.numpy().dtype.name,
+            compression=self.compression_type,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
+        absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
+        codebook = np.frombuffer(
+            serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size, dtype=self.codebook_dtype
+        )
+        quantized = np.frombuffer(
+            serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes, dtype=self.indices_dtype
+        )
+
+        absmax = torch.as_tensor(absmax)
+        codebook = torch.as_tensor(codebook)
+        quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
+        try:
+            return dequantize_blockwise(quantized, (absmax, codebook))
+        except NameError:
+            raise ImportError(BNB_MISSING_MESSAGE)
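For reference, the buffer written by `compress` is two native-endian int64 length headers followed by three payloads. A standalone sketch of decoding it with plain numpy, mirroring `extract` above (the per-block comment assumes bitsandbytes' default blocksize of 4096 elements, which is an assumption, not something this diff pins down):

```python
import numpy as np

def parse_blockwise_buffer(buffer: bytes):
    """Split a BLOCKWISE_8BIT payload into its three components."""
    # header: two int64 values holding the element counts of absmax and codebook
    absmax_size = int(np.frombuffer(buffer, count=1, dtype=np.int64))
    codebook_size = int(np.frombuffer(buffer, offset=8, count=1, dtype=np.int64))
    # float32 per-block scales (one per quantization block), then the float32 codebook
    absmax = np.frombuffer(buffer, offset=16, count=absmax_size, dtype=np.float32)
    offset = 16 + absmax.nbytes
    codebook = np.frombuffer(buffer, offset=offset, count=codebook_size, dtype=np.float32)
    # the remainder holds one uint8 index per tensor element
    quantized = np.frombuffer(buffer, offset=offset + codebook.nbytes, dtype=np.uint8)
    return absmax, codebook, quantized
```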
hivemind/compression/serialization.py (7 additions, 6 deletions)

@@ -6,21 +6,22 @@
 
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.proto import runtime_pb2
 from hivemind.utils.streaming import combine_from_streaming
 
-BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+_BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
     NONE=NoCompression(),
     FLOAT16=Float16Compression(),
     MEANSTD_16BIT=ScaledFloat16Compression(),
     QUANTILE_8BIT=Quantile8BitQuantization(),
     UNIFORM_8BIT=Uniform8BitQuantization(),
+    BLOCKWISE_8BIT=BlockwiseQuantization(),
 )
 
 for key in runtime_pb2.CompressionType.keys():
-    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
-    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert key in _BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
+    actual_compression_type = _BASE_COMPRESSION_TYPES[key].compression_type
     assert (
         runtime_pb2.CompressionType.Name(actual_compression_type) == key
     ), f"Compression strategy for {key} has inconsistent type"
@@ -35,14 +36,14 @@ def serialize_torch_tensor(
 ) -> runtime_pb2.Tensor:
     """Serialize a given tensor into a protobuf message using the specified compression strategy"""
     assert tensor.device == torch.device("cpu")
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
     info = info or CompressionInfo.from_tensor(tensor, **kwargs)
     return compression.compress(tensor, info, allow_inplace)
 
 
 def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
     """Restore a pytorch tensor from a protobuf message"""
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
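One point worth spelling out about the (now private) registry: the codec identifier travels inside the protobuf message itself, so the receiver needs no out-of-band information about which compression was applied. A small sketch, assuming bitsandbytes is available for the blockwise codec:

```python
import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto import runtime_pb2

msg = serialize_torch_tensor(torch.zeros(5, 5), runtime_pb2.CompressionType.BLOCKWISE_8BIT)
assert runtime_pb2.CompressionType.Name(msg.compression) == "BLOCKWISE_8BIT"
restored = deserialize_torch_tensor(msg)  # codec is looked up from msg.compression
```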
hivemind/proto/runtime.proto (1 addition, 0 deletions)

@@ -26,6 +26,7 @@ enum CompressionType{
   FLOAT16 = 2;
   QUANTILE_8BIT = 3;
   UNIFORM_8BIT = 4;
+  BLOCKWISE_8BIT = 5;
 }
 
 message Tensor {
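Because protobuf enum numbers are part of the wire contract, `BLOCKWISE_8BIT` is appended as value 5 rather than renumbering existing codecs. The generated Python bindings expose the usual name/value helpers:

```python
from hivemind.proto import runtime_pb2

assert runtime_pb2.CompressionType.Name(5) == "BLOCKWISE_8BIT"
assert runtime_pb2.CompressionType.Value("BLOCKWISE_8BIT") == 5
```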
setup.py (3 additions, 2 deletions)

@@ -23,7 +23,6 @@
     "p2pd": "1252a2a2095040cef8e317f5801df8b8c93559711783a2496a0aff2f3e177e39",
 }
 
-
 here = os.path.abspath(os.path.dirname(__file__))
 
@@ -140,7 +139,9 @@ def run(self):
 with open("requirements-docs.txt") as docs_requirements_file:
     extras["docs"] = list(map(str, parse_requirements(docs_requirements_file)))
 
-extras["all"] = extras["dev"] + extras["docs"]
+extras["bitsandbytes"] = ["bitsandbytes==0.32.3"]
+
+extras["all"] = extras["dev"] + extras["docs"] + extras["bitsandbytes"]
 
 setup(
     name="hivemind",
tests/test_cli_scripts.py (1 addition, 1 deletion)

@@ -35,7 +35,7 @@ def test_dht_connection_successful():
     dht_client_proc.stderr.readline()
     first_report_msg = dht_client_proc.stderr.readline()
 
-    assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+    assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg
 
     # ensure we get the output of dht_proc after the start of dht_client_proc
     sleep(dht_refresh_period)
tests/test_compression.py (2 additions, 0 deletions)

@@ -38,6 +38,8 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     assert error.square().mean() < beta
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.BLOCKWISE_8BIT)) - X
+    assert error.square().mean() < beta
 
     zeros = torch.zeros(5, 5)
     for compression_type in CompressionType.values():
tests/test_moe.py (1 addition, 1 deletion)

@@ -162,7 +162,7 @@ def test_remote_module_call(hidden_dim=16):
 
     # check that the server is still alive after processing a malformed request
     out3_yet_again = real_expert(dummy_x[1:])
-    assert torch.allclose(out3_yet_again, out3[1:])
+    assert torch.allclose(out3_yet_again, out3[1:], atol=1e-5, rtol=0)
 
 
 @pytest.mark.forked
tests/test_start_server.py (15 additions, 7 deletions)

@@ -1,5 +1,6 @@
 import os
 import re
+from functools import partial
 from subprocess import PIPE, Popen
 from tempfile import TemporaryDirectory
 
@@ -10,10 +11,11 @@ def test_background_server_identity_path():
     with TemporaryDirectory() as tempdir:
         id_path = os.path.join(tempdir, "id")
 
-        with background_server(num_experts=1, identity_path=id_path) as server_info_1, background_server(
-            num_experts=1, identity_path=id_path
-        ) as server_info_2, background_server(num_experts=1, identity_path=None) as server_info_3:
+        server_runner = partial(background_server, num_experts=1, device="cpu", hidden_dim=1)
+
+        with server_runner(identity_path=id_path) as server_info_1, server_runner(
+            identity_path=id_path
+        ) as server_info_2, server_runner(identity_path=None) as server_info_3:
             assert server_info_1.peer_id == server_info_2.peer_id
             assert server_info_1.peer_id != server_info_3.peer_id
             assert server_info_3.peer_id == server_info_3.peer_id
@@ -33,9 +35,11 @@ def test_cli_run_server_identity_path():
     )
 
     # Skip line "Generating new identity (libp2p private key) in {path to file}"
-    server_1_proc.stderr.readline()
+    line = server_1_proc.stderr.readline()
     line = server_1_proc.stderr.readline()
-    addrs_1 = set(re.search(pattern, line).group(1).split(", "))
+    addrs_pattern_result = re.search(pattern, line)
+    assert addrs_pattern_result is not None, line
+    addrs_1 = set(addrs_pattern_result.group(1).split(", "))
     ids_1 = set(a.split("/")[-1] for a in addrs_1)
 
     assert len(ids_1) == 1
@@ -48,7 +52,9 @@ def test_cli_run_server_identity_path():
     )
 
     line = server_2_proc.stderr.readline()
-    addrs_2 = set(re.search(pattern, line).group(1).split(", "))
+    addrs_pattern_result = re.search(pattern, line)
+    assert addrs_pattern_result is not None, line
+    addrs_2 = set(addrs_pattern_result.group(1).split(", "))
    ids_2 = set(a.split("/")[-1] for a in addrs_2)
 
     assert len(ids_2) == 1
@@ -61,7 +67,9 @@ def test_cli_run_server_identity_path():
     )
 
     line = server_3_proc.stderr.readline()
-    addrs_3 = set(re.search(pattern, line).group(1).split(", "))
+    addrs_pattern_result = re.search(pattern, line)
+    assert addrs_pattern_result is not None, line
+    addrs_3 = set(addrs_pattern_result.group(1).split(", "))
     ids_3 = set(a.split("/")[-1] for a in addrs_3)
 
     assert len(ids_3) == 1