Update q4 quantization format (#19)
ling0322 authored Dec 10, 2023
1 parent bcdd72b commit 72bd59f
Showing 33 changed files with 506 additions and 1,537 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -8,6 +8,8 @@ transformers
 __pycache__
 *.model
 *.bin
+*.1
+*.2
 *.egg-info
 *.so
 *.dll
5 changes: 4 additions & 1 deletion src/llm/chatglm2/chatglm2_model_for_generation.cc
@@ -23,6 +23,8 @@
 #include "llm/common/constants.h"
 
 using ly::Tensor;
+namespace F = ly::functional;
+
 
 namespace libllm {
 namespace chatglm2 {
@@ -64,7 +66,8 @@ ly::Tensor ChatGLM2ModelForGeneration::buildInput(
 }
 
 Tensor ChatGLM2ModelForGeneration::forward(ly::StateMap &past, Tensor input) const {
-  return _model->forward(past, input);
+  Tensor x = _model->forward(past, input);
+  return x;
 }
 
 Tensor ChatGLM2ModelForGeneration::forwardHidden(Tensor hidden) const {
18 changes: 16 additions & 2 deletions src/llm/cli/llm.cc
@@ -44,22 +44,36 @@ void printChatStat(const ChatOutput &chatOutput) {
 
 int main(int argc, char **argv) {
   std::string configPath;
+  std::string deviceType = "auto";
 
   const char *usage =
       "Command line interface for libllm.\n"
-      "Usage: llm --config <libllm-config-file>";
+      "Usage: llm -config <libllm-config-file> [-device (cpu|cuda|auto)]";
 
   lut::Flags flags(usage);
   flags.define("-config", &configPath, "filename of libllm config file.");
+  flags.define("-device", &deviceType, "device of the model. (cpu|cuda|auto)");
   flags.parse(argc, argv);
 
   if (configPath.empty()) {
     flags.printUsage();
     return 1;
   }
 
+  llm::DeviceType device = llm::DeviceType::AUTO;
+  if (deviceType == "auto")
+    device = llm::DeviceType::AUTO;
+  else if (deviceType == "cuda")
+    device = llm::DeviceType::CUDA;
+  else if (deviceType == "cpu")
+    device = llm::DeviceType::CPU;
+  else {
+    printf("invalid device\n");
+    return 1;
+  }
+
   llm::init();
-  std::shared_ptr<llm::Model> model = llm::Model::create(configPath);
+  std::shared_ptr<llm::Model> model = llm::Model::create(configPath, device);
   std::shared_ptr<PromptBulder> promptBuilder = PromptBulder::create(model->getName());
   DialogManager dialogManager(model, promptBuilder);
   for (; ; ) {
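Note on usage: with this change the CLI accepts an optional -device flag next to -config, e.g. (assuming a config file named chatglm2.config) llm -config chatglm2.config -device cuda. Unrecognized values exit with an error; the default auto leaves device selection to the library.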
20 changes: 9 additions & 11 deletions src/ly/dtype.cc
@@ -30,7 +30,7 @@ namespace ly {
 constexpr int16_t DType::kUnknown;
 constexpr int16_t DType::kFloat;
 constexpr int16_t DType::kLong;
-constexpr int16_t DType::kQInt4SymGroup32;
+constexpr int16_t DType::kUInt8;
 constexpr int16_t DType::kFloat16;
 constexpr int16_t DType::kQInt4Group32;
 constexpr int16_t DType::kInt8;
@@ -46,8 +46,8 @@ DType DType::getTypeImpl<int64_t>() {
   return DType::kLong;
 }
 template<>
-DType DType::getTypeImpl<QInt4SymGroup32>() {
-  return DType::kQInt4SymGroup32;
+DType DType::getTypeImpl<UInt8>() {
+  return DType::kUInt8;
 }
 template<>
 DType DType::getTypeImpl<Float16>() {
@@ -71,7 +71,7 @@ DType DType::getTypeImpl<half>() {
 
 template DType DType::getTypeImpl<float>();
 template DType DType::getTypeImpl<int64_t>();
-template DType DType::getTypeImpl<QInt4SymGroup32>();
+template DType DType::getTypeImpl<UInt8>();
 template DType DType::getTypeImpl<Float16>();
 template DType DType::getTypeImpl<QInt4Group32>();
 template DType DType::getTypeImpl<Int8>();
@@ -85,11 +85,11 @@ int64_t DType::getTotalSize(int64_t numel) const {
       return 2 * numel;
     case DType::kLong:
       return 8 * numel;
-    case DType::kQInt4SymGroup32:
     case DType::kQInt4Group32:
       CHECK(numel % 2 == 0);
       return numel / 2;
     case DType::kInt8:
+    case DType::kUInt8:
       return numel;
     default:
       NOT_IMPL();
@@ -102,7 +102,7 @@ bool DType::isValid() const {
     case DType::kFloat:
     case DType::kFloat16:
     case DType::kLong:
-    case DType::kQInt4SymGroup32:
+    case DType::kUInt8:
     case DType::kQInt4Group32:
     case DType::kInt8:
       return true;
@@ -113,7 +113,6 @@
 
 bool DType::isQuantized() const {
   switch (_dtype) {
-    case DType::kQInt4SymGroup32:
     case DType::kQInt4Group32:
       return true;
     default:
@@ -133,7 +132,6 @@ bool DType::isFloat() const {
 
 int DType::getGroupSize() const {
   switch (_dtype) {
-    case DType::kQInt4SymGroup32:
     case DType::kQInt4Group32:
       return 32;
     default:
@@ -149,10 +147,10 @@ std::string DType::toString() const {
       return "float32";
     case DType::kLong:
       return "int64";
-    case DType::kQInt4SymGroup32:
-      return "qint4symg32";
+    case DType::kUInt8:
+      return "uint8";
     case DType::kQInt4Group32:
-      return "qint4g32";
+      return "q4";
     case DType::kInt8:
       return "int8";
     default:
13 changes: 9 additions & 4 deletions src/ly/dtype.h
@@ -35,17 +35,22 @@ struct QInt4Group32 {
   static constexpr int GroupSize = 32;
   uint8_t int4x2;
 };
-static_assert(sizeof(QInt4Group32) == 1, "invalid size of QInt4SymGroup32");
+static_assert(sizeof(QInt4Group32) == 1, "invalid size of QInt4Group32");
 
 struct Float16 {
   uint16_t v;
 };
-static_assert(sizeof(Float16) == 2, "invalid size of QInt4SymGroup32");
+static_assert(sizeof(Float16) == 2, "invalid size of Float16");
 
 struct Int8 {
   int8_t v;
 };
-static_assert(sizeof(Int8) == 1, "invalid size of QInt4SymGroup32");
+static_assert(sizeof(Int8) == 1, "invalid size of Int8");
+
+struct UInt8 {
+  uint8_t v;
+};
+static_assert(sizeof(UInt8) == 1, "invalid size of UInt8");
 
 typedef int8_t Byte;
 typedef int64_t LongType;
@@ -56,7 +61,7 @@ class DType {
   static constexpr int16_t kUnknown = 0;
   static constexpr int16_t kFloat = 1;
   static constexpr int16_t kLong = 2;
-  static constexpr int16_t kQInt4SymGroup32 = 3;
+  static constexpr int16_t kUInt8 = 3;
   static constexpr int16_t kFloat16 = 4;
   static constexpr int16_t kQInt4Group32 = 5;
   static constexpr int16_t kInt8 = 6;
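A note on the packed layout these types imply: each QInt4Group32 byte (int4x2) carries two 4-bit codes, and the dequantization kernel later in this commit reads the low nibble as the even element and the high nibble as the odd one. A minimal illustrative sketch of that packing (these helpers are not part of libllm):

#include <cstdint>

// Illustrative int4x2 packing: element 2k in the low nibble,
// element 2k+1 in the high nibble, codes in [0, 15].
inline uint8_t packInt4x2(uint8_t lo, uint8_t hi) {
  return static_cast<uint8_t>((lo & 0xf) | ((hi & 0xf) << 4));
}

inline uint8_t unpackInt4(uint8_t byte, int idx) {
  return (idx % 2) ? static_cast<uint8_t>(byte >> 4) : static_cast<uint8_t>(byte & 0xf);
}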
8 changes: 3 additions & 5 deletions src/ly/internal/tensor_data.cc
@@ -38,15 +38,13 @@ void TensorData::throwIfInvalid() {
       throw lut::AbortedError("invalid tensor (dtype=unknown).");
       break;
     case DType::kQInt4Group32:
+      if (getSlot(1)->getDType() != DType::kFloat16 || getSlot(2)->getDType() != DType::kUInt8)
+        throw lut::AbortedError("invalid q4 tensor data type.");
       if (getNumEl() / getDType().getGroupSize() != getSlot(1)->getNumEl())
         throw lut::AbortedError("tensor data and scale size mismatch.");
-      if (getNumEl() / getDType().getGroupSize() != getSlot(2)->getNumEl())
+      if ((getNumEl() / getDType().getGroupSize() + 1) / 2 != getSlot(2)->getNumEl())
        throw lut::AbortedError("tensor data and zero-point size mismatch.");
       break;
-    case DType::kQInt4SymGroup32:
-      if (getNumEl() / getDType().getGroupSize() != getSlot(1)->getNumEl())
-        throw lut::AbortedError("tensor data and scale size mismatch.");
-      break;
   }
 }
 
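The checks above pin down the three-slot layout of the new q4 format: slot 0 holds the packed int4 data, slot 1 one fp16 scale per 32-element group, and slot 2 the zero-points, packed two 4-bit values per byte. A minimal sketch of the implied byte counts, assuming numel is a multiple of the group size (illustrative, not library code):

#include <cstdint>

// Illustrative: bytes per slot for a kQInt4Group32 tensor with numel
// elements, mirroring DType::getTotalSize() and throwIfInvalid().
struct Q4SlotBytes {
  int64_t data;   // slot 0: two 4-bit codes per byte
  int64_t scale;  // slot 1: one 2-byte fp16 scale per group
  int64_t zero;   // slot 2: two 4-bit zero-points per byte, rounded up
};

inline Q4SlotBytes q4SlotBytes(int64_t numel) {
  int64_t numGroup = numel / 32;  // QInt4Group32::GroupSize
  return Q4SlotBytes{numel / 2, numGroup * 2, (numGroup + 1) / 2};
}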
2 changes: 0 additions & 2 deletions src/ly/operators/cpu/cpu_operators.cc
@@ -323,8 +323,6 @@ Tensor CPUOperators::matmul(Tensor A, Tensor B) {
   DType typeB = B.getDType();
   if (typeA == DType::kFloat && typeB == DType::kFloat) {
     return matmulFp32(A, B);
-  } else if (typeA == DType::kFloat && typeB == DType::kQInt4SymGroup32) {
-    return matmulFp32Q4SymFp32(A, B);
   } else if (typeA == DType::kFloat && typeB == DType::kQInt4Group32) {
     return matmulFp32Q4Fp32(A, B);
   } else {
2 changes: 1 addition & 1 deletion src/ly/operators/cpu/cpu_tensor_data.cc
@@ -101,7 +101,7 @@ std::shared_ptr<internal::TensorData> CpuTensorData::read(lut::ReadableFile *fp)
     throw lut::AbortedError("bad tensor data format.");
 
   int32_t numSlot = fp->readValue<int32_t>();
-  if (numSlot <= 0 && numSlot > 2)
+  if (numSlot <= 0 || numSlot > 3)
     throw lut::AbortedError("invalid num slot.");
 
   // slot 0
15 changes: 1 addition & 14 deletions src/ly/operators/cpu/lookup.cc
@@ -56,16 +56,6 @@ Tensor lookupFp32(Subtensor<const float> table, Subtensor<const LongType> indices,
 template<typename T>
 void applyDequant(int64_t offset, int n, const internal::TensorData *data, float *tgt);
 
-template<>
-void applyDequant<QInt4SymGroup32>(
-    int64_t offset, int n, const internal::TensorData *data, float *tgt) {
-  lymath_dequant_q4sym(
-      n,
-      (const lymath_q4x2_t *)data->getData<QInt4SymGroup32>(offset),
-      (const lymath_float16_t *)data->getSlot(1)->getData<Float16>(offset / QInt4SymGroup32::GroupSize),
-      tgt);
-}
-
 template<>
 void applyDequant<QInt4Group32>(
     int64_t offset, int n, const internal::TensorData *data, float *tgt) {
@@ -74,7 +64,7 @@ void applyDequant<QInt4Group32>(
       (const lymath_q4x2_t *)data->getData<QInt4Group32>(offset),
       (const lymath_float16_t *)data->getSlot(1)->getData<Float16>(
           offset / QInt4Group32::GroupSize),
-      (const int8_t *)data->getSlot(2)->getData<Int8>(offset / QInt4Group32::GroupSize),
+      (const uint8_t *)data->getSlot(2)->getData<UInt8>(offset / QInt4Group32::GroupSize / 2),
       tgt);
 }
 
@@ -117,9 +107,6 @@ Tensor lookup(const Tensor &table, const Tensor &indices) {
       return lookupFp32(
           Subtensor<const float>::fromTensor(table),
           Subtensor<const LongType>::fromTensor(indices));
-    case DType::kQInt4SymGroup32:
-      return lookupQuantized<QInt4SymGroup32>(
-          table, Subtensor<const LongType>::fromTensor(indices));
     case DType::kQInt4Group32:
       return lookupQuantized<QInt4Group32>(table, Subtensor<const LongType>::fromTensor(indices));
     default:
51 changes: 1 addition & 50 deletions src/ly/operators/cpu/matmul.cc
@@ -132,55 +132,6 @@ Tensor bmmFp32QInt4Fp32(const Tensor &A, const Tensor &B) {
   return Tensor();
 }
 
-// -- q4sym ----------
-
-Tensor gemmFp32Q4SymFp32(const Tensor &A, const Tensor &B) {
-  CHECK(A.getDim() == B.getDim() && A.getDim() == 2 && B.getDType() == DType::kQInt4SymGroup32);
-
-  Tensor C = cpu::tensor({A.getShape(0), B.getShape(1)}, DType::kFloat);
-  Subtensor<float> Cs = Subtensor<float>::fromTensor(C);
-  zerosFp32(Cs);
-
-  common::GEMMArgs gemmArgs = common::generateGemmArgs(A, B, C);
-  const internal::TensorData *dataObjectB = B.getDataObject();
-  lymath_qgemm_nqn_q4sym_omp(
-      gemmArgs.transA,
-      gemmArgs.transB,
-      gemmArgs.M,
-      gemmArgs.N,
-      gemmArgs.K,
-      A.getData<float>(),
-      gemmArgs.lda,
-      reinterpret_cast<const lymath_q4x2_t *>(dataObjectB->getData<QInt4SymGroup32>()),
-      reinterpret_cast<const lymath_float16_t *>(dataObjectB->getSlot(1)->getData<Float16>()),
-      Cs.data,
-      gemmArgs.ldc);
-
-  return C;
-}
-
-Tensor bmmNx2Fp32Q4SymFp32(const Tensor &A, const Tensor &B) {
-  std::vector<int> shape = A.getShape();
-
-  Tensor xA = A.view({-1, A.getShape(-1)});
-  Tensor xC = gemmFp32Q4SymFp32(xA, B);
-
-  shape.back() = B.getShape(1);
-  return xC.view(shape);
-}
-
-Tensor matmulFp32Q4SymFp32(const Tensor &A, const Tensor &B) {
-  if (A.getDim() == 2 && B.getDim() == 2) {
-    return gemmFp32Q4SymFp32(A, B);
-  } else if (A.getDim() > 2 && A.isContiguous() && B.getDim() == 2) {
-    return bmmNx2Fp32Q4SymFp32(A, B);
-  } else {
-    NOT_IMPL();
-  }
-
-  return Tensor();
-}
-
 // -- q4 ----------
 
 Tensor gemmFp32Q4Fp32(const Tensor &A, const Tensor &B) {
@@ -202,7 +153,7 @@ Tensor gemmFp32Q4Fp32(const Tensor &A, const Tensor &B) {
       gemmArgs.lda,
       reinterpret_cast<const lymath_q4x2_t *>(dataObjectB->getData<QInt4Group32>()),
       reinterpret_cast<const lymath_float16_t *>(dataObjectB->getSlot(1)->getData<Float16>()),
-      reinterpret_cast<const int8_t *>(dataObjectB->getSlot(2)->getData<Int8>()),
+      reinterpret_cast<const uint8_t *>(dataObjectB->getSlot(2)->getData<UInt8>()),
       Cs.data,
       gemmArgs.ldc);
 
5 changes: 0 additions & 5 deletions src/ly/operators/cpu/matmul.h
@@ -32,11 +32,6 @@ Tensor bmmFp32(const Tensor &A, const Tensor &B);
 Tensor bmmNx2Fp32(const Tensor &A, const Tensor &B);
 Tensor gemmFp32(const Tensor &A, const Tensor &B);
 
-// q4sym
-Tensor matmulFp32Q4SymFp32(const Tensor &A, const Tensor &B);
-Tensor gemmFp32Q4SymFp32(const Tensor &A, const Tensor &B);
-Tensor bmmNx2Fp32Q4SymFp32(const Tensor &A, const Tensor &B);
-
 // q4
 Tensor matmulFp32Q4Fp32(const Tensor &A, const Tensor &B);
 Tensor gemmFp32Q4Fp32(const Tensor &A, const Tensor &B);
2 changes: 1 addition & 1 deletion src/ly/operators/cuda/common.cc
@@ -38,7 +38,7 @@ PackedSubtensor2DQ4::PackedSubtensor2DQ4(const Tensor &tensor) {
 
   _data = (const uint8_t *)tensor.getDataObject()->getSlot(0)->getRawData();
   _scale = (const __half *)tensor.getDataObject()->getSlot(1)->getRawData();
-  _bias = (const int8_t *)tensor.getDataObject()->getSlot(2)->getRawData();
+  _zero = (const uint8_t *)tensor.getDataObject()->getSlot(2)->getRawData();
 }
 
 Tensor createCudaTensorHalf(lut::Span<const int> shape) {
15 changes: 11 additions & 4 deletions src/ly/operators/cuda/common.h
@@ -51,13 +51,20 @@ struct PackedSubtensor2DQ4 {
 
   const half *_scale;
   const uint8_t *_data;
-  const int8_t *_bias;
+  const uint8_t *_zero;
 
   __device__ int getNumRow() const { return _numRow; }
   __device__ int getNumCol() const { return _numCol; }
   __device__ const half *getScale() const { return _scale; }
-  __device__ const uint8_t *getData() const { return _data; }
-  __device__ const int8_t *getBias() const { return _bias; }
+  __device__ const uint8_t *getData(int groupIdx) const { return _data + groupIdx * 16; }
+
+  __device__ half getScaleValue(int groupIdx) const { return _scale[groupIdx]; }
+  __device__ uint8_t getZeroValue(int groupIdx) const {
+    uint8_t zero = _zero[groupIdx / 2];
+    if (groupIdx % 2) {
+      zero = zero >> 4;
+    }
+    return zero & 0xf;
+  }
 
   PackedSubtensor2DQ4(const Tensor &tensor);
 };
10 changes: 5 additions & 5 deletions src/ly/operators/cuda/dequant.cu
@@ -36,13 +36,13 @@ void dequantTensor2DQ4(PackedSubtensor2DQ4 qtensor,
 
   for (int elemOffset = idx; elemOffset < numQ4x2; elemOffset += gridDim.x * blockDim.x) {
     int elemGroup = elemOffset / 16;
-    uint8_t q4elem = qtensor.getData()[elemOffset];
-    half scale = qtensor.getScale()[elemGroup];
-    int8_t bias = qtensor.getBias()[elemGroup];
+    uint8_t q4elem = qtensor.getData(0)[elemOffset];
+    half scale = qtensor.getScaleValue(elemGroup);
+    int8_t zero = qtensor.getZeroValue(elemGroup);
 
-    destData[elemOffset * 2] = __hmul(scale, __int2half_rd(static_cast<int>(q4elem >> 4) - bias));
+    destData[elemOffset * 2] = __hmul(scale, __int2half_rd(static_cast<int>(q4elem & 0xf) - zero));
     destData[elemOffset * 2 + 1] = __hmul(
-        scale, __int2half_rd(static_cast<int>(q4elem & 0xf) - bias));
+        scale, __int2half_rd(static_cast<int>(q4elem >> 4) - zero));
   }
 }
 
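Putting the pieces of the new format together: each byte carries two 4-bit codes (low nibble = even element), each 32-element group has one scale and a 4-bit zero-point (two packed per byte), and value = scale * (code - zero). A hedged host-side reference of the dequantization the kernel above performs; this sketch is illustrative, not part of libllm, and uses float scales where the library uses fp16:

#include <cstdint>
#include <vector>

// Illustrative reference of the asymmetric q4 dequantization in
// dequantTensor2DQ4: out[i] = scale[g] * (code[i] - zero[g]), g = i / 32.
std::vector<float> dequantQ4Reference(
    const uint8_t *data,   // numel / 2 bytes, two 4-bit codes per byte
    const float *scale,    // one scale per 32-element group
    const uint8_t *zero,   // two 4-bit zero-points per byte
    int64_t numel) {
  std::vector<float> out(numel);
  for (int64_t i = 0; i < numel; ++i) {
    int64_t group = i / 32;
    uint8_t code = (i % 2) ? (data[i / 2] >> 4) : (data[i / 2] & 0xf);
    uint8_t z = (group % 2) ? ((zero[group / 2] >> 4) & 0xf) : (zero[group / 2] & 0xf);
    out[i] = scale[group] * (static_cast<int>(code) - static_cast<int>(z));
  }
  return out;
}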