ggml : update mul_mat_id to use the same tensor for all the experts #6387

Merged
merged 36 commits on Apr 3, 2024
Changes from 6 commits
36 commits
0c7e21d
ggml : update mul_mat_id to use the same tensor for all the experts
slaren Mar 29, 2024
9c9fe60
update cuda
slaren Mar 29, 2024
2479900
minor
slaren Mar 29, 2024
93db37e
update metal
slaren Mar 29, 2024
325e5ef
update test-backend-ops
slaren Mar 29, 2024
26c09ad
fix cuda
slaren Mar 29, 2024
2abb6c7
Update ggml-metal.m
slaren Mar 30, 2024
6203d72
update convert.py
slaren Mar 30, 2024
4a5d50e
update convert-hf-to-gguf.py
slaren Mar 31, 2024
3b3298a
update convert.py for mixtral hf models
slaren Mar 31, 2024
8c2f7b8
Update convert-hf-to-gguf.py
slaren Mar 31, 2024
4531b02
cuda : support non-pow-2 number of experts
slaren Apr 1, 2024
6886fdb
allow quantize to work for split and merged experts models in the sam…
slaren Apr 1, 2024
deea200
cleanup + disable mmap automatically with split tensors models
slaren Apr 1, 2024
b4a6206
update imatrix
slaren Apr 2, 2024
8f84ca3
test-backend-ops : test qwen argsort
slaren Apr 2, 2024
5de4a5d
update grok model loading
slaren Apr 2, 2024
6875369
llama : add merged experts tensors to the grok tensor map
slaren Apr 2, 2024
6f33852
minor
slaren Apr 2, 2024
68d21de
gguf : bump version
slaren Apr 2, 2024
f27cbf3
fix quantizing of merged experts
slaren Apr 2, 2024
d08a1f4
convert-hf-to-gguf.py : update grok (untested)
slaren Apr 2, 2024
9530398
make linter happy
slaren Apr 2, 2024
f421b32
cuda/argsort : use shared memory instead of pool memory
slaren Apr 2, 2024
c704c77
convert : fix grok tensor names
ggerganov Apr 2, 2024
fe62909
metal : add support for non-pow-2 argsort
slaren Apr 2, 2024
31adc93
llama : more loader cleanup, better error checking
slaren Apr 2, 2024
86f3666
cuda : fix warning
slaren Apr 2, 2024
a1343ae
llama : still use mmap for loading old models, but copy the data to a…
slaren Apr 2, 2024
19dafaf
add review note
slaren Apr 3, 2024
3779b98
llama : remove ffn tensor counting + add sanity check
ggerganov Apr 3, 2024
e810899
convert : fix handling of n_experts == None
ggerganov Apr 3, 2024
fc719b6
imatrix : fix ncall counters
ggerganov Apr 3, 2024
822caa4
llama : produce error if imatrix size does not match
ggerganov Apr 3, 2024
a054283
quantize : terminate on errors + trace logs
ggerganov Apr 3, 2024
716e960
metal : pad shared memory to 16 bytes
ggerganov Apr 3, 2024
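
The core change tracked by these commits: the per-expert weight matrices, previously stored as separate 2D tensors and passed to mul_mat_id as extra sources (the old `dst->src[row_id + 2]` lookup in the CUDA code), are merged into a single 3D tensor whose third dimension indexes the expert, with the router ids supplied as their own tensor. The sketch below is a minimal CPU-side illustration of the indexing that layout implies; it is written under those assumptions with hypothetical helper names, not the ggml implementation. It mirrors the offset the CUDA code in this diff computes as `src0_original + row_id*src0->nb[2]`.

```cpp
// Minimal CPU sketch (not the ggml implementation) of how a merged expert
// tensor is indexed. Assumption: experts are stored contiguously along the
// third dimension of `as`, so expert e starts at element offset e*n_in*n_out --
// the analogue of the byte offset src0_original + row_id*src0->nb[2] in the
// CUDA code of this PR.
#include <cassert>
#include <cstdint>
#include <vector>

// as : merged expert weights, n_expert slices of [n_out x n_in], row-major per expert
// x  : one input row of length n_in
// returns the output row of length n_out produced by the selected expert
static std::vector<float> mul_mat_id_row(
        const std::vector<float> & as, int64_t n_in, int64_t n_out, int64_t n_expert,
        const std::vector<float> & x, int32_t expert_id) {
    assert(expert_id >= 0 && expert_id < n_expert); // mirrors GGML_ASSERT(row_id >= 0 && row_id < n_as)
    const int64_t slice = n_in*n_out;               // elements per expert slice (nb[2] in elements)
    const float * w = as.data() + expert_id*slice;  // select this expert's weight matrix
    std::vector<float> y(n_out, 0.0f);
    for (int64_t o = 0; o < n_out; ++o) {
        for (int64_t i = 0; i < n_in; ++i) {
            y[o] += w[o*n_in + i]*x[i];
        }
    }
    return y;
}
```

Before this change the equivalent lookup selected one of several separate expert tensors (`dst->src[row_id + 2]` in the removed code below) instead of offsetting into a single buffer.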
212 changes: 16 additions & 196 deletions ggml-cuda.cu
@@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t
GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

if (tensor->view_src != NULL && tensor->view_offs == 0) {
if (tensor->view_src != NULL) {
assert(tensor->view_src->buffer->buft == buffer->buft);
tensor->backend = tensor->view_src->backend;
tensor->extra = tensor->view_src->extra;
return;
}

@@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
}
}

#if 0
template<typename ... Srcs>
static __global__ void k_compute_batched_ptrs_id(
const void ** ptrs_src, void ** ptrs_dst,
int ne12, int ne13,
int ne23,
int nb02, int nb03,
int nb12, int nb13,
int nb2, int nb3,
int r2, int r3,
ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
const half * src1_f16, half * dst_f16,
const int32_t * ids, const int id,
Srcs... src0s) {

int i = ids[id];

half * src0_f16;
const void * srcs_ar[] = { (const half *) src0s... };
if (src0_type == GGML_TYPE_F16) {
src0_f16 = (half *) srcs_ar[i];
} else {
src0_f16 = src0_as_f16;
if (threadIdx.x == 0 && threadIdx.y == 0) {
const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
}
}

int i13 = blockIdx.x * blockDim.x + threadIdx.x;
int i12 = blockIdx.y * blockDim.y + threadIdx.y;

if (i13 >= ne13 || i12 >= ne12) {
return;
}

int i03 = i13 / r3;
int i02 = i12 / r2;

ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
}

static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
const struct ggml_tensor * ids = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src00 = dst->src[2];

const int id = dst->op_params[0];

GGML_ASSERT(!ggml_is_transposed(src00));
GGML_ASSERT(!ggml_is_transposed(src1));

GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->type == GGML_TYPE_F32);

const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
const int64_t ne01 = src00->ne[1];
const int64_t ne02 = src00->ne[2];
const int64_t ne03 = src00->ne[3];

//const int64_t nb01 = src00->nb[1];
const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);

const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];

//const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);

const int64_t ne1 = ggml_nelements(src1);
const int64_t ne = ggml_nelements(dst);

ggml_cuda_set_device(g_main_device);
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));

//ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
//void * src0_ddq = src0_extra->data_device[g_main_device];
//half * src0_as_f16 = (half *) src0_ddq;

ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];

ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

// convert src1 to fp16
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);

size_t src1_as = 0;
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);

// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

// use cublasGemmBatchedEx
const int ne23 = ne12*ne13;

const void ** ptrs_src = nullptr;
void ** ptrs_dst = nullptr;

size_t ptrs_src_s = 0;
size_t ptrs_dst_s = 0;

ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);

int64_t src0_ne = ggml_nelements(src00);
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
if (src00->type != GGML_TYPE_F16) {
src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
}

static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
ptrs_src, ptrs_dst,
ne12, ne13,
ne23,
ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
nb12, nb13,
dst->nb[2], dst->nb[3],
r2, r3,
src00->type, src0_as_f16, src0_ne,
src1_as_f16, dst_f16,
(const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
);
CUDA_CHECK(cudaGetLastError());

CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
ne23,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f16, src0_as);
}
if (ptrs_src_s != 0) {
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
}
if (ptrs_dst_s != 0) {
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
}

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);

ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}
#endif

static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
#if 0
ggml_cuda_mul_mat_id_cublas(dst);
// TODO: mmq/mmv support
#endif
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];

GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");

cudaStream_t stream = ctx.stream();

const size_t nb11 = src1->nb[1];
const size_t nb1 = dst->nb[1];

const struct ggml_tensor * ids = src0;
const int32_t id = ((int32_t *) dst->op_params)[0];
const int32_t n_as = ((int32_t *) dst->op_params)[1];
const int32_t n_as = src0->ne[2];

std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));

ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;

char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;

src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[3] = src0->nb[2];

if (src1->ne[1] == 1) {
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

GGML_ASSERT(row_id >= 0 && row_id < n_as);

const struct ggml_tensor * src0_row = dst->src[row_id + 2];

src0_row.data = src0_original + row_id*src0->nb[2];
src1_row.data = src1_original + i01*src1->nb[1];
dst_row.data = dst_original + i01*dst->nb[1];

ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
} else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.data = dst_contiguous.get();

for (int32_t row_id = 0; row_id < n_as; ++row_id) {
const struct ggml_tensor * src0_row = dst->src[row_id + 2];

int64_t num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
continue;
}

src0_row.data = src0_original + row_id*src0->nb[2];

src1_row.ne[1] = num_src1_rows;
dst_row.ne[1] = num_src1_rows;

@@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;

ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);

num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
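
For readers skimming the truncated hunks above: in the batched branch (`src1->ne[1] > 1`), the CUDA code gathers the src1 rows routed to each expert into a contiguous buffer (`src1_contiguous`), runs one matrix multiplication against that expert's slice of the merged tensor, and copies the results back to their original positions in dst (via `dst_contiguous`). The sketch below is a simplified CPU analogue of that gather/compute/scatter pattern, assuming a single expert id per row for brevity (the real op selects the id-th expert per row from a 2D ids tensor); the function names are hypothetical and this is not the CUDA implementation.

```cpp
// CPU analogue of the batched mul_mat_id path. Assumptions: one expert id per
// input row, experts stored as contiguous [n_out x n_in] slices of `as`. The
// CUDA code follows the same pattern with device buffers
// (src1_contiguous/dst_contiguous) and copy kernels instead of std::copy.
#include <algorithm>
#include <cstdint>
#include <vector>

static void matmul(const float * w, const float * x, float * y,
                   int64_t n_in, int64_t n_out, int64_t n_rows) {
    // y[r] = w * x[r] for each gathered row r
    for (int64_t r = 0; r < n_rows; ++r) {
        for (int64_t o = 0; o < n_out; ++o) {
            float acc = 0.0f;
            for (int64_t i = 0; i < n_in; ++i) {
                acc += w[o*n_in + i]*x[r*n_in + i];
            }
            y[r*n_out + o] = acc;
        }
    }
}

// as : [n_out x n_in] x n_expert, x : [n_rows x n_in], ids : expert per row, y : [n_rows x n_out]
static void mul_mat_id_batched(const float * as, const float * x, const int32_t * ids, float * y,
                               int64_t n_in, int64_t n_out, int64_t n_expert, int64_t n_rows) {
    std::vector<float>   x_contig(n_rows*n_in);  // gathered inputs  (cf. src1_contiguous)
    std::vector<float>   y_contig(n_rows*n_out); // gathered outputs (cf. dst_contiguous)
    std::vector<int64_t> row_map(n_rows);        // original position of each gathered row

    for (int64_t e = 0; e < n_expert; ++e) {
        // gather the rows routed to expert e into a contiguous block
        int64_t n_e = 0;
        for (int64_t r = 0; r < n_rows; ++r) {
            if (ids[r] == e) {
                std::copy(x + r*n_in, x + (r + 1)*n_in, x_contig.begin() + n_e*n_in);
                row_map[n_e++] = r;
            }
        }
        if (n_e == 0) {
            continue; // no rows use this expert (cf. skipping when num_src1_rows == 0)
        }
        // one matmul against expert e's slice of the merged tensor
        matmul(as + e*n_in*n_out, x_contig.data(), y_contig.data(), n_in, n_out, n_e);
        // scatter the results back to their original row positions in y
        for (int64_t k = 0; k < n_e; ++k) {
            std::copy(y_contig.begin() + k*n_out, y_contig.begin() + (k + 1)*n_out,
                      y + row_map[k]*n_out);
        }
    }
}
```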