From 5078aa8b44fb4f7154489e14ee387c923e51c6b0 Mon Sep 17 00:00:00 2001
From: Nexesenex <124105151+Nexesenex@users.noreply.github.com>
Date: Tue, 27 Aug 2024 03:36:47 +0200
Subject: [PATCH] Revert "ggml : add SSM Metal kernels (#8546)"

This reverts commit fc18425b6a8ad03847383ce2b69d52edfd49b0ff.
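
In scope of the revert (as reflected in the diff below): the
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32 / GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32
kernel registrations and their graph-compute dispatch code are removed from
ggml-metal.m, the kernel_ssm_conv_f32 and kernel_ssm_scan_f32 kernels are
removed from ggml-metal.metal, GGML_OP_SSM_CONV and GGML_OP_SSM_SCAN are
dropped from ggml_metal_supports_op() so these ops are no longer claimed by
the Metal backend, the previous formatting of the y/s pointer setup in
ggml_compute_forward_ssm_scan_f32() is restored in ggml.c, and the
test_ssm_conv / test_ssm_scan cases are deleted from
tests/test-backend-ops.cpp.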
---
 ggml/src/ggml-metal.m      | 122 -------------------------------------
 ggml/src/ggml-metal.metal  | 121 ------------------------------------
 ggml/src/ggml.c            |   4 +-
 tests/test-backend-ops.cpp |  58 ------------------
 4 files changed, 2 insertions(+), 303 deletions(-)

diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 936751800518b2..3746bf5a44b2d4 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -82,8 +82,6 @@
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
-    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
-    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@@ -544,8 +542,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 ctx->support_simdgroup_reduction);
@@ -807,9 +803,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
                 return false;
             }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
-        case GGML_OP_SSM_CONV:
-        case GGML_OP_SSM_SCAN:
-            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -1545,121 +1538,6 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         }
                     } break;
-                case GGML_OP_SSM_CONV:
-                    {
-                        GGML_ASSERT(src0t == GGML_TYPE_F32);
-                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-                        GGML_ASSERT(ggml_is_contiguous(src1));
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                        [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                        [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                        [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5];
-                        [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6];
-                        [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
-                        [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
-                        [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9];
-                        [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10];
-                        [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
-                        [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
-                        [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                        [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                        [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15];
-                        [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16];
-                        [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17];
-                        [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_SSM_SCAN:
-                    {
-                        struct ggml_tensor * src3 = gf->nodes[i]->src[3];
-                        struct ggml_tensor * src4 = gf->nodes[i]->src[4];
-                        struct ggml_tensor * src5 = gf->nodes[i]->src[5];
-
-                        GGML_ASSERT(src3);
-                        GGML_ASSERT(src4);
-                        GGML_ASSERT(src5);
-
-                        size_t offs_src3 = 0;
-                        size_t offs_src4 = 0;
-                        size_t offs_src5 = 0;
-
-                        id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
-                        id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
-                        id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
-
-                        const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30);
-                        const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);
-
-                        const uint64_t nb30 = src3->nb[0];
-                        const uint64_t nb31 = src3->nb[1];
-
-                        const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-                        const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41);
-                        const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
-
-                        const uint64_t nb40 = src4->nb[0];
-                        const uint64_t nb41 = src4->nb[1];
-                        const uint64_t nb42 = src4->nb[2];
-
-                        const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
-                        const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
-                        const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
-
-                        const uint64_t nb50 = src5->nb[0];
-                        const uint64_t nb51 = src5->nb[1];
-                        const uint64_t nb52 = src5->nb[2];
-
-                        const int64_t d_state      = ne00;
-                        const int64_t d_inner      = ne01;
-                        const int64_t n_seq_tokens = ne11;
-                        const int64_t n_seqs       = ne02;
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-                        [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
-                        [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
-
-                        [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
-                        [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
-                        [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
-                        [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
-
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                        [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
-                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
-                        [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
-                        [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
-                        [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
-                        [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
-                        [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                        [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                        [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
-                        [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                        [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
                 case GGML_OP_MUL_MAT:
                     {
                         GGML_ASSERT(ne00 == ne10);
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 755970f31ce296..34194abc106679 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -667,127 +667,6 @@ kernel void kernel_diag_mask_inf_8(
     }
 }
 
-// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
-// TODO: optimize
-kernel void kernel_ssm_conv_f32(
-        device const  void * src0,
-        device const  void * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i2 = tgpig.y;
-    const int64_t i3 = tgpig.z;
-
-    const int64_t nc  = ne10;
-    const int64_t ncs = ne00;
-    const int64_t nr  = ne01;
-    const int64_t n_t = ne1;
-    const int64_t n_s = ne2;
-
-    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
-    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
-    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
-
-    float sumf = 0.0f;
-
-    for (int64_t i0 = 0; i0 < nc; ++i0) {
-        sumf += s[i0] * c[i0];
-    }
-
-    x[0] = sumf;
-}
-
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
-// TODO: optimize
-kernel void kernel_ssm_scan_f32(
-        device const void * src0,
-        device const void * src1,
-        device const void * src2,
-        device const void * src3,
-        device const void * src4,
-        device const void * src5,
-        device      float * dst,
-        constant  int64_t & d_state,
-        constant  int64_t & d_inner,
-        constant  int64_t & n_seq_tokens,
-        constant  int64_t & n_seqs,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant uint64_t & nb20,
-        constant uint64_t & nb21,
-        constant uint64_t & nb22,
-        constant uint64_t & nb30,
-        constant uint64_t & nb31,
-        constant uint64_t & nb40,
-        constant uint64_t & nb41,
-        constant uint64_t & nb42,
-        constant uint64_t & nb50,
-        constant uint64_t & nb51,
-        constant uint64_t & nb52,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i3 = tgpig.y;
-
-    const int64_t nc  = d_state;
-    const int64_t nr  = d_inner;
-    const int64_t n_t = n_seq_tokens;
-    const int64_t n_s = n_seqs;
-
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
-        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
-        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 +    nb13);
-
-        if (i2 > 0) {
-            s0 = s;
-        }
-
-        // i1 == 0
-        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;
-
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            int64_t i = i0;
-            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
-
-        y[0] = sumf;
-    }
-}
-
 kernel void kernel_norm(
         device const  void * src0,
         device       float * dst,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e52471ce3f861d..d0a99150314c92 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -15898,8 +15898,8 @@ static void ggml_compute_forward_ssm_scan_f32(
             const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
             const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
             const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                  float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                  float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
 
             // use the output as the source for the next token-wise iterations
             if (i2 > 0) { s0 = s; }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 3955ef3323f5ef..eab5f027f498bd 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -949,58 +949,6 @@ struct test_rms_norm : public test_case {
     }
 };
 
-// GGML_OP_SSM_CONV
-struct test_ssm_conv : public test_case {
-    const ggml_type type;
-    const std::array<int64_t, 4> ne_a;
-    const std::array<int64_t, 4> ne_b;
-
-    std::string vars() override {
-        return VARS_TO_STR3(type, ne_a, ne_b);
-    }
-
-    test_ssm_conv(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
-            std::array<int64_t, 4> ne_b = {3, 3, 1, 1})
-        : type(type), ne_a(ne_a), ne_b(ne_b) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a   = ggml_new_tensor(ctx, type, 4, ne_a.data());
-        ggml_tensor * b   = ggml_new_tensor(ctx, type, 4, ne_b.data());
-        ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
-        return out;
-    }
-};
-
-// GGML_OP_SSM_SCAN
-struct test_ssm_scan : public test_case {
-    const ggml_type type;
-
-    const int64_t d_state;
-    const int64_t d_inner;
-    const int64_t n_seq_tokens;
-    const int64_t n_seqs;
-
-    std::string vars() override {
-        return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
-    }
-
-    test_ssm_scan(ggml_type type = GGML_TYPE_F32,
-            int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
-        : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner,      n_seqs, 1 }.data());
-        ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner,      1     , 1 }.data());
-        ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
-        return out;
-    }
-};
-
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -2292,12 +2240,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
-    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
-    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
-    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
-
-    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
-
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {