sync master #7

Merged · 47 commits · May 28, 2024
Changes shown from 1 commit

Commits (47)
dacfceb
readme : add GPT-NeoX + Pythia to the list of supported models (#7491)
felladrin May 23, 2024
55ac3b7
ci : use Pythia models instead of OpenLlama (#7470)
ggerganov May 23, 2024
3015851
llama : add getters for n_threads/n_threads_batch (#7464)
danbev May 23, 2024
8b94e79
readme : add Bunny in supported models [no ci] (#7469)
criminact May 23, 2024
007489e
Fix phi3 chat template confusion with zephyr (#7449)
tristandruyen May 23, 2024
1debe72
ggml : silence UB sanitizer error during iq2_xxs quantization (#0)
ggerganov May 23, 2024
74f33ad
readme : remove trailing space (#7469)
ggerganov May 23, 2024
0df0aa8
add build shared lib in win release package (#7438)
arthw May 24, 2024
fbca2f2
Add support for ArcticForCausalLM (#7020)
fairydreaming May 24, 2024
27891f6
docker.yml: disable light-intel and server-intel test (#7515)
mofosyne May 24, 2024
d041d2c
flake.lock: Update (#7232)
ggerganov May 24, 2024
b83bab1
gguf-py : fix and simplify quantized shape round-trip (#7483)
compilade May 25, 2024
5768433
Make tokenize CLI tool have nicer command line arguments. (#6188)
Noeda May 25, 2024
902184d
fix missing slash in `fs_get_cache_directory()` (#7503)
ngxson May 25, 2024
9791f40
android : module (#7502)
eltonkola May 25, 2024
faa0e69
ggml: aarch64: SVE kernels for q8_0_q8_0, q4_0_q8_0 vector dot (#7433)
msy-kato May 25, 2024
00c6390
main : don't print special tokens with --grammar (#6923)
jart May 25, 2024
3cbd23e
labeler: added Apple Metal detector (+Kompute) (#7529)
mofosyne May 25, 2024
9588f19
train : change default FA argument (#7528)
ggerganov May 25, 2024
b9adcbb
SimpleChat Completion Mode flexibility and cleanup, Settings gMe, Opt…
hanishkvc May 26, 2024
9146d36
Readme: add akx/ggify to tools (#1484)
akx May 26, 2024
c429b33
llama : add Smaug 70B support (#7402)
bartowski1182 May 26, 2024
32a2821
Fix aya-23 conversion scripts (#7539)
Galunid May 26, 2024
d298382
main: replace --no-special with --special (#7534)
mofosyne May 26, 2024
dff451c
flake.lock: Update (#7540)
ggerganov May 26, 2024
d6ef0e7
github: add self sorted issue ticket forms (#7543)
mofosyne May 27, 2024
eaf6e03
llama : add comments about experimental flags (#7544)
ggerganov May 27, 2024
62bfef5
metal : disable FA kernel for HS=256 (#7556)
ggerganov May 27, 2024
1d8fca7
metal : add GGML_OP_REPEAT kernels (#7557)
ggerganov May 27, 2024
5487593
Add freq factors (#7495)
AidanBeltonS May 27, 2024
95f84d5
Fix q_xxs using mul_mat_q (#7459)
AidanBeltonS May 27, 2024
197c006
Allow multiple copy function pointers for CUDA graph kernel param upd…
agray3 May 27, 2024
10b1e45
make: add --device-debug to NVCC debug flags (#7542)
JohannesGaessler May 27, 2024
0136966
adding in x64 targets to cmake presets (#7574)
kunnis May 27, 2024
852aafb
update HIP_UMA #7399 (#7414)
Djip007 May 27, 2024
74b239b
llava : update clip.h (#7580)
eltociear May 28, 2024
c417671
Markdownish code block fix (#7571)
nathan-sixnines May 28, 2024
9335b96
server: do not remove whitespace at the start of a completion chunk (…
mgroeber9110 May 28, 2024
0548a41
ggml : generalize GGML_OP_CONCAT (#7563)
ggerganov May 28, 2024
e2b0650
[SYCL]fix ggml_sycl_mul_mat_id() to match the change of api (#7436)
arthw May 28, 2024
271ff3f
github: add refactor to issue template (#7561)
mofosyne May 28, 2024
8b99e2a
llama : handle unknown utf8 bytes (#7588)
ggerganov May 28, 2024
edc2943
tests : fix test-tokenizer-0.sh
ggerganov May 28, 2024
ee3dff6
Add support for DeepseekV2ForCausalLM (#7519)
fairydreaming May 28, 2024
2b737ca
rpc : resource management rework (#7562)
rgerganov May 28, 2024
56411a9
vulkan: properly initialize vulkan devices for LLAMA_SPLIT_MODE_NONE …
Adriankhl May 28, 2024
8767ce2
Merge branch 'prepare-PR-of-minicpm-v2.5' into prepare-PR
tc-mb May 28, 2024
metal : add GGML_OP_REPEAT kernels (ggml-org#7557)
ggml-ci
ggerganov authored May 27, 2024
commit 1d8fca72ae9154eec0e1c0a75cfaac3c50f08e4a
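
For context: GGML_OP_REPEAT is the broadcast/tile operation created by ggml_repeat(), which expands tensor a to the shape of tensor b. This commit adds native Metal kernels for it (F32, F16, I32, I16) and registers the op in ggml_metal_supports_op, as the diff below shows. A minimal graph-side sketch, assuming the standard ggml C API; the shapes here are illustrative, not taken from the commit:

#include "ggml.h"

// Sketch: build a GGML_OP_REPEAT node that broadcasts `a` to the shape of `b`.
// With this commit, the Metal backend can execute the resulting node directly.
static struct ggml_tensor * make_repeat(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 1); // one row of 4 floats
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8); // target shape: 8 rows
    return ggml_repeat(ctx, a, b); // dst[i1][i0] = a[i1 % 1][i0 % 4]
}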
53 changes: 49 additions & 4 deletions ggml-metal.m
@@ -35,6 +35,10 @@
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -485,6 +489,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
         case GGML_OP_SQR:
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
         switch (dst->op) {
             case GGML_OP_CONCAT:
                 {
-                    const int64_t nb = ne00;
-
                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;

                    [encoder setComputePipelineState:pipeline];
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-                   [encoder setBytes:&nb length:sizeof(nb) atIndex:27];

                    const int nth = MIN(1024, ne0);

@@ -1021,11 +1027,14 @@
            case GGML_OP_MUL:
            case GGML_OP_DIV:
                {
+                   GGML_ASSERT(src0t == GGML_TYPE_F32);
+                   GGML_ASSERT(src1t == GGML_TYPE_F32);
+
                    const size_t offs = 0;

                    bool bcast_row = false;

-                   int64_t nb = ne00;
+                   int64_t nb = ne00; // used by the "row" kernels

                    id<MTLComputePipelineState> pipeline = nil;

@@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    }
                } break;
+           case GGML_OP_REPEAT:
+               {
+                   id<MTLComputePipelineState> pipeline;
+
+                   switch (src0t) {
+                       case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                       case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                       case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                       case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                       default: GGML_ASSERT(false);
+                   }
+
+                   [encoder setComputePipelineState:pipeline];
+                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                   [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                   [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                   [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                   [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                   [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                   [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                   [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                   [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                   [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                   [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+                   [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+                   [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+                   [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+                   [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+                   [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                   [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+                   [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+                   const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                   [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+               } break;
            case GGML_OP_ACC:
                {
                    GGML_ASSERT(src0t == GGML_TYPE_F32);
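A note on the dispatch geometry in the GGML_OP_REPEAT case above: the grid is MTLSizeMake(ne1, ne2, ne3), i.e. one threadgroup per output row, and within a group up to maxTotalThreadsPerThreadgroup threads stride along dimension 0. A plain-C sketch of the loop nest being parallelized (not ggml's actual CPU path; it assumes the ne*/nb* extents and strides and the src0/dst byte pointers from the surrounding code, plus 4-byte elements):

// Broadcast = take each output index modulo the corresponding source extent,
// the same index arithmetic as kernel_repeat in ggml-metal.metal below.
for (int64_t i3 = 0; i3 < ne3; i3++) {             // maps to tgpig.z
    for (int64_t i2 = 0; i2 < ne2; i2++) {         // maps to tgpig.y
        for (int64_t i1 = 0; i1 < ne1; i1++) {     // maps to tgpig.x
            for (int64_t i0 = 0; i0 < ne0; i0++) { // threads stride this loop
                const char * src_ptr = src0 + (i3 % ne03)*nb03 + (i2 % ne02)*nb02
                                            + (i1 % ne01)*nb01 + (i0 % ne00)*nb00;
                char       * dst_ptr = dst  + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
                memcpy(dst_ptr, src_ptr, 4); // element size: 4 for F32/I32, 2 for F16/I16
            }
        }
    }
}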
47 changes: 47 additions & 0 deletions ggml-metal.metal
@@ -168,6 +168,53 @@ kernel void kernel_div(
     }
 }

+template<typename T>
+kernel void kernel_repeat(
+        device const char * src0,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
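Design note on the template instantiations at the end of the ggml-metal.metal diff: the [[host_name("...")]] attribute gives each kernel_repeat<T> instantiation its own entry-point name in the compiled Metal library, which is what the host-side GGML_METAL_ADD_KERNEL(..., repeat_f32, ...) registrations resolve against. A hedged Objective-C sketch using the plain Metal API (ggml wraps this in its macro; error handling omitted):

// Fetch one named instantiation and build a compute pipeline for it.
NSError * error = nil;
id<MTLFunction> fn = [library newFunctionWithName:@"kernel_repeat_f32"];
id<MTLComputePipelineState> pso =
    [device newComputePipelineStateWithFunction:fn error:&error];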