Introducing 1-bit quantization for Llama in torchchat (pytorch#911)
Summary:
Pull Request resolved: pytorch#911

Pull Request resolved: pytorch#910

### THIS DIFF
I introduce the ability to use 1-bit quantization in torchao (i.e. pack and unpack bytes that each carry only 1 bit of information).

For example, given 8 bytes of 1-bit quantized values:
`00000001` `00000001` `00000000` `00000001` `00000000` `00000000` `00000000` `00000001`

We can pack them into a single byte, `11010001`, and vice versa.
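
For reference, here is a minimal scalar sketch of this byte-level round trip (standalone C++; the function names below are illustrative placeholders rather than the NEON helpers added in `uint1.h`, and the actual kernels may use a different bit order):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Pack 8 bytes, each holding a single 0/1 value, into one byte.
// The first value goes into the most significant bit, matching the
// example above (illustrative bit order only).
uint8_t pack_8_one_bit_values(const uint8_t* unpacked) {
  uint8_t packed = 0;
  for (int i = 0; i < 8; i++) {
    assert(unpacked[i] <= 1);
    packed |= static_cast<uint8_t>(unpacked[i] << (7 - i));
  }
  return packed;
}

// Inverse: expand one packed byte back into 8 bytes of 0/1 values.
void unpack_8_one_bit_values(uint8_t packed, uint8_t* unpacked) {
  for (int i = 0; i < 8; i++) {
    unpacked[i] = (packed >> (7 - i)) & 1;
  }
}

int main() {
  uint8_t values[8] = {1, 1, 0, 1, 0, 0, 0, 1};
  uint8_t packed = pack_8_one_bit_values(values);
  std::printf("packed = 0x%02x\n", static_cast<unsigned>(packed)); // 0xd1 == 0b11010001

  uint8_t roundtrip[8] = {0};
  unpack_8_one_bit_values(packed, roundtrip);
  for (int i = 0; i < 8; i++) {
    assert(roundtrip[i] == values[i]); // lossless round trip
  }
  return 0;
}
```

Eight bytes per step shrink to one, i.e. an 8x reduction in storage relative to the byte-per-value representation; the NEON helpers in this diff perform the same transformation 64 or 128 values at a time.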
Main changes:
- added `uint1.h`, which contains the internal helper functions to pack/unpack 8, 64, and 128 `uint1` values (one unpacked value per byte).
- modified `bitpack.h` to add case statements for 1-bit quantization to the general functions that perform vectorized packing/unpacking on ARM NEON vectors (32, 64, and 128 values).

### CONTEXT

Refer to the previous diffs introducing 2- to 5-bit quantization (e.g., 2-bit: D62133659).

### Optional:
I noticed that the individual tests in `test_bitpacking.cpp` for 1, 3, and 5 bits are structurally identical and could potentially be factored into a shared group. Maybe for a future diff?

Reviewed By: metascroy

Differential Revision: D63052325
Vaishnavi Gupta authored and facebook-github-bot committed Sep 20, 2024
1 parent 23321fb commit 275541d
Showing 7 changed files with 505 additions and 22 deletions.
@@ -8,6 +8,7 @@
#include <benchmark/benchmark.h>

#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
@@ -16,6 +17,128 @@

namespace {

// Benchmark utility to compare variants of uint1 packing
void pack_uint1_values(
uint8_t* packed,
uint8_t* unpacked,
int packed_size,
int unpacked_size,
int variant) {
constexpr int nbit = 1;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(unpacked_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
uint8x16_t unpacked2;
uint8x16_t unpacked3;
uint8x16_t unpacked4;
uint8x16_t unpacked5;
uint8x16_t unpacked6;
uint8x16_t unpacked7;

switch (variant) {
case 8:
for (int i = 0; i < unpacked_size; i += 8) {
torchao::bitpacking::internal::pack_8_uint1_values(
packed + ((i * nbit) / bitsPerByte), unpacked + i);
}
break;
case 64:
for (int i = 0; i < unpacked_size; i += 64) {
torchao::bitpacking::internal::vec_load_64_uint8_values(
unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
torchao::bitpacking::internal::vec_pack_64_uint1_values(
packed + ((i * nbit) / bitsPerByte),
unpacked0,
unpacked1,
unpacked2,
unpacked3);
}
break;
case 128:
for (int i = 0; i < unpacked_size; i += 128) {
torchao::bitpacking::internal::vec_load_64_uint8_values(
unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
torchao::bitpacking::internal::vec_load_64_uint8_values(
unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64);
torchao::bitpacking::internal::vec_pack_128_uint1_values(
packed + ((i * nbit) / bitsPerByte),
unpacked0,
unpacked1,
unpacked2,
unpacked3,
unpacked4,
unpacked5,
unpacked6,
unpacked7);
}
break;
}
}

// Benchmark utility to compare variants of uint1 packing
void unpack_uint1_values(
uint8_t* unpacked,
uint8_t* packed,
int unpacked_size,
int packed_size,
int variant) {
constexpr int nbit = 1;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(unpacked_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
uint8x16_t unpacked2;
uint8x16_t unpacked3;
uint8x16_t unpacked4;
uint8x16_t unpacked5;
uint8x16_t unpacked6;
uint8x16_t unpacked7;

switch (variant) {
case 8:
for (int i = 0; i < unpacked_size; i += 8) {
torchao::bitpacking::internal::unpack_8_uint1_values(
unpacked + i, packed + ((i * nbit) / bitsPerByte));
}
break;
case 64:
for (int i = 0; i < unpacked_size; i += 64) {
torchao::bitpacking::internal::vec_unpack_64_uint1_values(
unpacked0,
unpacked1,
unpacked2,
unpacked3,
packed + ((i * nbit) / bitsPerByte));
torchao::bitpacking::internal::vec_store_64_uint8_values(
unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
}
break;
case 128:
for (int i = 0; i < unpacked_size; i += 128) {
torchao::bitpacking::internal::vec_unpack_128_uint1_values(
unpacked0,
unpacked1,
unpacked2,
unpacked3,
unpacked4,
unpacked5,
unpacked6,
unpacked7,
packed + ((i * nbit) / bitsPerByte));
torchao::bitpacking::internal::vec_store_64_uint8_values(
unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
torchao::bitpacking::internal::vec_store_64_uint8_values(
unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7);
}
break;
}
}

// Benchmark utility to compare variants of uint2 packing
void pack_uint2_values(
uint8_t* packed,
@@ -470,6 +593,44 @@ void unpack_uint5_values(

} // namespace

static void benchmark_pack_uint1_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
int nbit = 1;

assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint1_values(
packed.data(), unpacked.data(), packed_size, unpacked_size, variant);
}
}

static void benchmark_unpack_uint1_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
int nbit = 1;

assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = torchao::get_random_lowbit_vector(packed_size, 8);
auto unpacked = std::vector<uint8_t>(unpacked_size, 0);

for (auto _ : state) {
unpack_uint1_values(
unpacked.data(),
packed.data(),
unpacked.size(),
packed.size(),
variant);
}
}

static void benchmark_pack_uint2_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
@@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint2_values(
@@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint3_values(
@@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint4_values(
@@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint5_values(
@@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) {
}
}

BENCHMARK(benchmark_pack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_unpack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
@@ -228,6 +228,8 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
false>) \
->ArgsProduct(BENCHMARK_PARAMS)

BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
@@ -236,6 +238,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
5);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
@@ -244,6 +248,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
5);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(