From fc54ef0d1c138133a01933296d50a36a1ab64735 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen <thichthat@gmail.com>
Date: Wed, 21 Aug 2024 11:04:34 +0200
Subject: [PATCH 1/6] server : support reading arguments from environment
 variables (#9105)

* server : support reading arguments from environment variables

* add -fa and -dt

* readme : specify non-arg env var
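
All of the new helpers follow the same rule: if the variable is set, it wins; otherwise the value already in `gpt_params` (the CLI default or parsed argument) is left untouched, and since `server.cpp` reads the environment after argument parsing, these variables take precedence. A minimal, self-contained sketch of the boolean variant (condensed and renamed here; the real templated `get_env` overloads live in `common/common.cpp`):

```cpp
// sketch only: condensed restatement of the boolean env-var helper;
// "1" or "true" enables the flag, anything else (or an unset variable)
// leaves the existing value alone
#include <cstdio>
#include <cstdlib>
#include <string>

static void get_env_bool(const std::string & name, bool & target) {
    if (const char * value = std::getenv(name.c_str())) {
        const std::string val(value);
        target = (val == "1" || val == "true");
    }
}

int main() {
    bool flash_attn = false; // default, as if -fa was not passed
    get_env_bool("LLAMA_ARG_FLASH_ATTN", flash_attn);
    std::printf("flash_attn = %s\n", flash_attn ? "on" : "off");
    return 0;
}
```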
---
 common/common.cpp          | 64 +++++++++++++++++++++++++++++++++-----
 common/common.h            |  2 +-
 examples/server/README.md  | 19 +++++++++++
 examples/server/server.cpp |  3 ++
 4 files changed, 80 insertions(+), 8 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 382d585a5e6f9..59e8296604c9c 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -77,6 +77,41 @@
 
 using json = nlohmann::ordered_json;
 
+//
+// Environment variable utils
+//
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::string(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stoi(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_floating_point<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stof(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, bool>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    if (value) {
+        std::string val(value);
+        target = val == "1" || val == "true";
+    }
+}
+
 //
 // CPU utils
 //
@@ -220,12 +255,6 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //
 
-void gpt_params_handle_hf_token(gpt_params & params) {
-    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
-        params.hf_token = std::getenv("HF_TOKEN");
-    }
-}
-
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
@@ -273,7 +302,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
-    gpt_params_handle_hf_token(params);
+    if (params.hf_token.empty()) {
+        get_env("HF_TOKEN", params.hf_token);
+    }
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -293,6 +324,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
+void gpt_params_parse_from_env(gpt_params & params) {
+    // we only care about server-related params for now
+    get_env("LLAMA_ARG_MODEL",            params.model);
+    get_env("LLAMA_ARG_THREADS",          params.n_threads);
+    get_env("LLAMA_ARG_CTX_SIZE",         params.n_ctx);
+    get_env("LLAMA_ARG_N_PARALLEL",       params.n_parallel);
+    get_env("LLAMA_ARG_BATCH",            params.n_batch);
+    get_env("LLAMA_ARG_UBATCH",           params.n_ubatch);
+    get_env("LLAMA_ARG_N_GPU_LAYERS",     params.n_gpu_layers);
+    get_env("LLAMA_ARG_THREADS_HTTP",     params.n_threads_http);
+    get_env("LLAMA_ARG_CHAT_TEMPLATE",    params.chat_template);
+    get_env("LLAMA_ARG_N_PREDICT",        params.n_predict);
+    get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
+    get_env("LLAMA_ARG_ENDPOINT_SLOTS",   params.endpoint_slots);
+    get_env("LLAMA_ARG_EMBEDDINGS",       params.embedding);
+    get_env("LLAMA_ARG_FLASH_ATTN",       params.flash_attn);
+    get_env("LLAMA_ARG_DEFRAG_THOLD",     params.defrag_thold);
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     const auto params_org = params; // the example can modify the default params
 
diff --git a/common/common.h b/common/common.h
index df23460a50fe0..f603ba2be1d35 100644
--- a/common/common.h
+++ b/common/common.h
@@ -267,7 +267,7 @@ struct gpt_params {
     std::string lora_outfile = "ggml-lora-merged-f16.gguf";
 };
 
-void gpt_params_handle_hf_token(gpt_params & params);
+void gpt_params_parse_from_env(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex   (int argc, char ** argv, gpt_params & params);
diff --git a/examples/server/README.md b/examples/server/README.md
index 930ae15f64d8b..abe245271195b 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -247,6 +247,25 @@ logging:
          --log-append             Don't truncate the old log file.
 ```
 
+Available environment variables (if specified, these variables will override the corresponding arguments):
+
+- `LLAMA_CACHE` (cache directory, used by `--hf-repo`)
+- `HF_TOKEN` (Hugging Face access token, used when accessing a gated model with `--hf-repo`)
+- `LLAMA_ARG_MODEL`
+- `LLAMA_ARG_THREADS`
+- `LLAMA_ARG_CTX_SIZE`
+- `LLAMA_ARG_N_PARALLEL`
+- `LLAMA_ARG_BATCH`
+- `LLAMA_ARG_UBATCH`
+- `LLAMA_ARG_N_GPU_LAYERS`
+- `LLAMA_ARG_THREADS_HTTP`
+- `LLAMA_ARG_CHAT_TEMPLATE`
+- `LLAMA_ARG_N_PREDICT`
+- `LLAMA_ARG_ENDPOINT_METRICS`
+- `LLAMA_ARG_ENDPOINT_SLOTS`
+- `LLAMA_ARG_EMBEDDINGS`
+- `LLAMA_ARG_FLASH_ATTN`
+- `LLAMA_ARG_DEFRAG_THOLD`
 
 ## Build
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ce711eadd29ac..e79e7aa2cb846 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2507,6 +2507,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // parse arguments from environment variables
+    gpt_params_parse_from_env(params);
+
     // TODO: not great to use extern vars
     server_log_json = params.log_json;
     server_verbose = params.verbosity > 0;

From a1631e53f6763e17da522ba219b030d8932900bd Mon Sep 17 00:00:00 2001
From: compilade <git@compilade.net>
Date: Wed, 21 Aug 2024 17:58:11 -0400
Subject: [PATCH 2/6] llama : simplify Mamba with advanced batch splits (#8526)

* llama : advanced batch splits

This includes equal-sequence-length batch splits, which are useful
to simplify recurrent model operators.
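
As a toy illustration of the equal-length splitting (an assumed example, not the actual llama_sbatch implementation): each ubatch takes the same number of new tokens from every sequence it contains, so recurrent operators can assume a dense {d_inner, n_seq_tokens, n_seqs} layout.

```cpp
// toy sketch with assumed sequence lengths: carve a batch into ubatches
// in which every included sequence advances by the same number of tokens
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> remaining = {5, 3, 3}; // new tokens left per sequence
    for (int step = 1; ; ++step) {
        int n_seq_tokens = 0; // smallest non-zero remainder
        for (int r : remaining) {
            if (r > 0) n_seq_tokens = n_seq_tokens ? std::min(n_seq_tokens, r) : r;
        }
        if (n_seq_tokens == 0) break; // batch fully consumed
        int n_seqs = 0;
        for (int & r : remaining) {
            if (r > 0) { r -= n_seq_tokens; ++n_seqs; }
        }
        // prints "ubatch 1: n_seq_tokens=3 n_seqs=3" then "ubatch 2: n_seq_tokens=2 n_seqs=1"
        std::printf("ubatch %d: n_seq_tokens=%d n_seqs=%d\n", step, n_seq_tokens, n_seqs);
    }
    return 0;
}
```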

* llama : always make recurrent state slots contiguous

* ggml : simplify mamba operators

* llama : fix integer signedness mixing

* llama : logits_all has priority over batch->logits

Otherwise, the server embeddings tests failed.
This was likely an existing problem but was only detected here
because of an additional assertion.

* llama : apply suggestions

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : fix t5 segfault

* llama : fix Mamba session save and restore

* llama : minor cosmetic changes

* llama : rename llama_reorder_outputs to llama_output_reorder

Also move it closer to llama_output_reserve.

* llama : fix pooled embeddings when using batches with equal_seqs

* minor : add struct members for clarity

ggml-ci

* llama : fix T5 segfault again

* llama : fix Mamba pooled embeddings with multiple sequences

Until the pooled embeddings are refactored to allow splitting
across ubatches for causal embeddings,
recurrent models can only process a single sequence per ubatch
when calculating pooled embeddings.

* llama : add llama_model_is_recurrent to simplify figuring that out

This will make it easier to more cleanly support RWKV-v6 and Mamba-2.
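
A hedged caller-side sketch (the helper below is hypothetical; only the llama_model_is_recurrent query itself is added by this change):

```cpp
// hypothetical helper: branch on the new public query instead of
// hard-coding architecture checks such as LLM_ARCH_MAMBA
#include "llama.h"

static uint32_t pick_cache_cells(const struct llama_model * model, uint32_t n_ctx, uint32_t n_seq_max) {
    // for recurrent models a single cache cell can hold the state of a whole
    // sequence, so one cell per sequence is enough; otherwise one per token
    return llama_model_is_recurrent(model) ? n_seq_max : n_ctx;
}
```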

* llama : fix simple splits when the batch contains embeddings

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
 ggml/include/ggml.h |    9 +-
 ggml/src/ggml.c     |  277 +++-----
 include/llama.h     |    3 +
 src/llama.cpp       | 1526 +++++++++++++++++++++++++++++--------------
 4 files changed, 1137 insertions(+), 678 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 1d2a354024675..b8a21a2ccc3f0 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1777,10 +1777,8 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
-            struct ggml_tensor  * s,
-            struct ggml_tensor  * x,
-            struct ggml_tensor  * c,
-            struct ggml_tensor  * sq);
+            struct ggml_tensor  * sx,
+            struct ggml_tensor  * c);
 
     GGML_API struct ggml_tensor * ggml_ssm_scan(
             struct ggml_context * ctx,
@@ -1789,8 +1787,7 @@ extern "C" {
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
-            struct ggml_tensor  * C,
-            struct ggml_tensor  * sq);
+            struct ggml_tensor  * C);
 
     // partition into non-overlapping windows with padding if needed
     // example:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 88e4fb7325dd9..d63c917a5705a 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -7229,43 +7229,34 @@ struct ggml_tensor * ggml_flash_attn_back(
 
 struct ggml_tensor * ggml_ssm_conv(
         struct ggml_context * ctx,
-        struct ggml_tensor  * s,
-        struct ggml_tensor  * x,
-        struct ggml_tensor  * c,
-        struct ggml_tensor  * sq) {
-    GGML_ASSERT(ggml_is_3d(s));
-    GGML_ASSERT(ggml_is_matrix(x));
+        struct ggml_tensor  * sx,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_is_3d(sx));
     GGML_ASSERT(ggml_is_matrix(c));
-    GGML_ASSERT(ggml_is_matrix(sq));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);
 
-    const int64_t d_conv   = c->ne[0];
-    const int64_t d_inner  = c->ne[1];
-    const int64_t n_tokens = x->ne[1];
-    const int64_t n_kv     = s->ne[2];
+    const int64_t d_conv  = c->ne[0];
+    const int64_t d_inner = c->ne[1];
+    const int64_t n_t     = sx->ne[0] - d_conv + 1; // tokens per sequence
+    const int64_t n_s     = sx->ne[2];
 
-    GGML_ASSERT( s->ne[0] == d_conv - 1);
-    GGML_ASSERT( s->ne[1] == d_inner);
-    GGML_ASSERT( x->ne[0] == d_inner);
-    GGML_ASSERT(sq->ne[0] == n_kv);
-    GGML_ASSERT(sq->ne[1] == n_tokens);
+    // TODO: maybe support other strides than 1?
+    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
+    GGML_ASSERT(sx->ne[1] == d_inner);
+    GGML_ASSERT(n_t >= 0);
 
     bool is_node = false;
 
-    if (s->grad || x->grad || c->grad || sq->grad) {
+    if (sx->grad || c->grad) {
         GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }
 
-    // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
 
     result->op   = GGML_OP_SSM_CONV;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = s;
-    result->src[1] = x;
-    result->src[2] = c;
-    result->src[3] = sq;
+    result->src[0] = sx;
+    result->src[1] = c;
 
     return result;
 }
@@ -7279,39 +7270,42 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
-        struct ggml_tensor  * C,
-        struct ggml_tensor  * sq) {
+        struct ggml_tensor  * C) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(A));
+    GGML_ASSERT(ggml_is_3d(B));
+    GGML_ASSERT(ggml_is_3d(s));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
     GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
     GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_are_same_shape(B, C));
 
     {
-        const int64_t d_state  = s->ne[0];
-        const int64_t d_inner  = s->ne[1];
-        const int64_t n_tokens = x->ne[1];
+        const int64_t d_state      = s->ne[0];
+        const int64_t d_inner      = s->ne[1];
+        const int64_t n_seq_tokens = x->ne[1];
+        const int64_t n_seqs       = x->ne[2];
 
+        GGML_ASSERT(s->ne[2] == n_seqs);
         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tokens);
-        GGML_ASSERT(C->ne[0] == d_state);
-        GGML_ASSERT(C->ne[1] == n_tokens);
+        GGML_ASSERT(B->ne[1] == n_seq_tokens);
+        GGML_ASSERT(B->ne[2] == n_seqs);
     }
 
     bool is_node = false;
 
-    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
         GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }
 
-    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    // concatenated y + ssm_states
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
 
     result->op   = GGML_OP_SSM_SCAN;
@@ -7322,7 +7316,6 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
-    result->src[6] = sq;
 
     return result;
 }
@@ -10995,11 +10988,6 @@ static void ggml_compute_forward_concat_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
     const int32_t dim = ggml_get_op_params_i32(dst, 0);
 
     GGML_ASSERT(dim >= 0 && dim < 4);
@@ -15782,27 +15770,22 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
-    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc   = src2->ne[0]; // d_conv
-    const int nr   = src0->ne[1]; // d_inner
-    const int n_t  = src1->ne[1]; // n_tokens
-    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t =  dst->ne[1]; // tokens per sequence
+    const int n_s =  dst->ne[2]; // number of sequences in the batch
 
-    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT( dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // for use with the destination state offset between sequences
-    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -15812,74 +15795,27 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    if (n_kv > 1) {
-        // multiple sequences means it's hard to know when it's the first time a state is read,
-        // so copy them all over to the destination, just to be sure.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
-            // can't use memcpy because of d_conv vs d_conv - 1
-            for (int i1 = 0; i1 < ir; ++i1) {
-                for (int i0 = 0; i0 < nc - 1; ++i0) {
-                    // copy s0 to last (d_conv - 1) columns of s
-                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-                }
-            }
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src3->data +  i2*(src3->nb[1])); // {n_kv, n_tokens}
-        float *   x  = (float *)   ((char *)  dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
-        float *   s  = (float *)   ((char *)  dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
-        float *   s0; // {d_conv - 1, d_inner, n_kv}
-        float *   x0 = (float *)   ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   c  = (float *)   ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
-        int ne0s0;
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
 
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
-            ne0s0 = src0->ne[0];
-        } else {
-            // the source is the last (d_conv - 1) columns of the destination
-            s0 = s + 1;
-            ne0s0 = nc;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // shift state left
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
-            }
-            // insert x on the last column
-            s[(nc - 1) + i1*nc] = x0[i1];
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
-            }
-        }
+            // TODO: transpose the output for smaller strides for big batches?
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
 
-        // it seems a little faster when this is separate from the state shift
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // rowwise dot product
-            float sumf = 0.0f;
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                sumf += s[i] * c[i];
+                // d_conv
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                }
+                x[i1] = sumf;
             }
-            x[i1] = sumf;
         }
     }
 }
@@ -15910,15 +15846,14 @@ static void ggml_compute_forward_ssm_scan_f32(
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
-    const struct ggml_tensor * src6 = dst->src[6]; // sq
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc   = src0->ne[0]; // d_state
-    const int64_t nr   = src0->ne[1]; // d_inner
-    const int64_t n_t  = src1->ne[1]; // number of tokens in the batch
-    const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // d_inner
+    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
 
     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15927,12 +15862,12 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C, and when copying the states
+    // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
     // required for per-sequence offsets for states
     GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[2])
-    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -15942,64 +15877,36 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    if (n_kv > 1) {
-        // it's hard to know if the source states have already been copied
-        // when there are multiple, so copy them already.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
-            memcpy(s, s0, nc*ir*sizeof(float));
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src6->data +  i2*(src6->nb[1])); // {n_kv, n_tokens}
-        float *   y  = (float *)   ((char *)  dst->data + ir0*(src1->nb[0]) +    i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   s  = (float *)   ((char *)  dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
-        float *   s0;
-        float *   x  = (float *)   ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   dt = (float *)   ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
-        float *   A  = (float *)   ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float *   B  = (float *)   ((char *) src4->data +  i2*(src4->nb[1])); // {d_state, n_tokens}
-        float *   C  = (float *)   ((char *) src5->data +  i2*(src5->nb[1])); // {d_state, n_tokens}
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
-        } else {
-            // otherwise the source is the same as the destination
-            s0 = s;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-            float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-            float x_dt = x[i1] * dt_soft_plus;
-            float sumf = 0.0f;
-            // d_state
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                // state = prev_state * dA + dB * x
-                float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                // y = rowwise_dotprod(state, C)
-                sumf += state * C[i0];
-                s[i] = state;
-            }
-            y[i1] = sumf;
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
             }
         }
     }
diff --git a/include/llama.h b/include/llama.h
index 188ae76f8001e..6cca6320b347d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -511,6 +511,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
 
+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
diff --git a/src/llama.cpp b/src/llama.cpp
index fe3c0db6f2931..bd7f1508b2644 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2516,10 +2516,29 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_scale;
 };
 
+// very similar to llama_batch,
+// but has more metadata about sequences
+struct llama_ubatch {
+    bool equal_seqs;
+    // TODO: whole_seqs for embeddings?
+
+    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+    uint32_t n_seq_tokens; // tokens per sequence
+    uint32_t n_seqs;
+
+    llama_token  *  token;    // [n_tokens]
+    float        *  embd;     // [n_embd, n_tokens]
+    llama_pos    *  pos;      // [n_tokens]
+    int32_t      *  n_seq_id; // [n_seqs]
+    llama_seq_id ** seq_id;   // [n_seqs]
+    int8_t       *  output;   // [n_tokens]
+};
+
 struct llama_kv_cell {
     llama_pos pos   = -1;
     llama_pos delta = 0;
-    int32_t   src   = 0; // used by recurrent state models to copy states
+    int32_t   src   = -1; // used by recurrent state models to copy states
+    int32_t   tail  = -1;
 
     std::set<llama_seq_id> seq_id;
 
@@ -2540,7 +2559,6 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
-    bool do_copy   = false;
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
     bool v_trans   = true;  // the value tensor is transposed
 
@@ -2703,6 +2721,340 @@ struct llama_model {
     }
 };
 
+struct llama_sbatch_seq {
+    int32_t n_seq_id;
+    llama_seq_id * seq_id;
+    size_t offset;
+    size_t length;
+
+    // helper for smoother batch API transition -- can be deprecated in the future
+    llama_seq_id all_seq_id; // used if seq_id == NULL
+};
+
+// sequence-length-aware batch splitting
+struct llama_sbatch {
+    // tokens left in this batch
+    size_t n_tokens;
+
+    size_t n_embd;
+
+    bool logits_all; // TODO: remove once lctx.logits_all is removed too
+
+    // sorted indices into the batch
+    std::vector<size_t> ids;
+    // batch indices of the output
+    std::vector<size_t> out_ids;
+    std::vector<llama_sbatch_seq> seq;
+    const llama_batch * batch = nullptr;
+
+    // buffers for the ubatch
+    std::vector<llama_token>    ubatch_token;
+    std::vector<float>          ubatch_embd;
+    std::vector<llama_pos>      ubatch_pos;
+    std::vector<int32_t>        ubatch_n_seq_id;
+    std::vector<llama_seq_id *> ubatch_seq_id;
+    std::vector<int8_t>         ubatch_output;
+
+    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
+        // clear empty sequences
+        // the previous ubatch is assumed to be gone,
+        // so nothing should refer to values in these sequences anymore.
+        for (size_t i = seq.size(); i-- > 0;) {
+            if (seq[i].length == 0) {
+                seq.pop_back();
+            } else {
+                break;
+            }
+        }
+        ubatch_token.resize(!has_embd ? n_ubatch : 0);
+        ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
+        ubatch_pos.resize(n_ubatch);
+        ubatch_n_seq_id.resize(n_ubatch);
+        ubatch_seq_id.resize(n_ubatch);
+        ubatch_output.resize(n_ubatch);
+        llama_ubatch ubatch = {
+            /*equal_seqs   =*/ true,
+            /*n_tokens     =*/ 0,
+            /*n_seq_tokens =*/ 0,
+            /*n_seqs       =*/ 0,
+            /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
+            /*embd         =*/ has_embd  ? ubatch_embd.data()  : nullptr,
+            /*pos          =*/ ubatch_pos.data(),
+            /*n_seq_id     =*/ ubatch_n_seq_id.data(),
+            /*seq_id       =*/ ubatch_seq_id.data(),
+            /*output       =*/ ubatch_output.data(),
+        };
+        return ubatch;
+    }
+
+    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
+        GGML_ASSERT(batch != nullptr);
+        GGML_ASSERT(length <= seq.length);
+        // Can only add sequences of equal lengths to a batch,
+        // otherwise it isn't clear to which sequence a token belongs
+        GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
+        GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
+        // NOTE: loops are separated for cache-friendliness
+        if (batch->token) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.token = batch->token + seq.offset;
+            }
+        } else {
+            ubatch.token = nullptr;
+        }
+        if (batch->embd) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    memcpy(
+                        ubatch.embd + n_embd * (ubatch.n_tokens + i),
+                        batch->embd + n_embd * ids[seq.offset + i],
+                        n_embd * sizeof(float)
+                    );
+                }
+            } else {
+                // simple split
+                ubatch.embd = batch->embd + (n_embd * seq.offset);
+            }
+        } else {
+            ubatch.embd = nullptr;
+        }
+        // from here on, the else branches are deprecated;
+        // they are helpers for smoother batch API transition
+        if (batch->pos) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.pos = batch->pos + seq.offset;
+            }
+        } else {
+            for (size_t i = 0; i < length; ++i) {
+                llama_pos bi = ids[seq.offset + i];
+                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+            }
+        }
+        if (ubatch.equal_seqs) {
+            ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
+            if (seq.seq_id) {
+                ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
+            } else {
+                GGML_ASSERT(seq.n_seq_id == 1);
+                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
+            }
+        } else {
+            // simple split
+            if (batch->n_seq_id) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.n_seq_id = batch->n_seq_id + seq.offset;
+                }
+            } else {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
+                }
+            }
+            if (batch->seq_id) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.seq_id = batch->seq_id + seq.offset;
+                }
+            } else {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
+                }
+            }
+        }
+        if (logits_all) {
+            for (size_t i = 0; i < length; ++i) {
+                ubatch.output[ubatch.n_tokens + i] = 1;
+                out_ids.push_back(ids[seq.offset + i]);
+            }
+        } else if (batch->logits) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    size_t id = ids[seq.offset + i];
+                    int8_t is_output = batch->logits[id];
+                    ubatch.output[ubatch.n_tokens + i] = is_output;
+                    if (is_output) { out_ids.push_back(id); }
+                }
+            } else {
+                // simple split
+                ubatch.output = batch->logits + seq.offset;
+                for (size_t i = 0; i < length; ++i) {
+                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
+                }
+            }
+        } else {
+            // only get last output
+            for (size_t i = 0; i < length; ++i) {
+                size_t id = ids[seq.offset + i];
+                int8_t is_last = id == ids.size() - 1;
+                ubatch.output[ubatch.n_tokens + i] = is_last;
+                if (is_last) { out_ids.push_back(id); }
+            }
+        }
+        if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
+            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
+        }
+        ubatch.n_tokens += length;
+        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
+        seq.offset += length;
+        seq.length -= length;
+        n_tokens -= length;
+        GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
+    }
+
+    // simple split, unknown number of sequences of unequal lengths
+    llama_ubatch split_simple(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        ubatch.equal_seqs = false;
+        if (!seq.empty()) {
+            llama_sbatch_seq & s = seq[0];
+            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+            GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
+            add_seq_to_ubatch(ubatch, s, length);
+        }
+        return ubatch;
+    }
+
+    // make batches of equal-length sequences
+    llama_ubatch split_equal(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        if (!seq.empty()) {
+            size_t length = 0;
+            size_t n_tokens_in_ubatch = 0;
+            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
+            // smallest first, because it's easier to split this way;
+            // starting from the end to pop in constant time.
+            for (size_t i = seq.size(); i-- > 0;) {
+                llama_sbatch_seq & s = seq[i];
+                GGML_ASSERT(s.length > 0);
+                if (length == 0) {
+                    length = s.length < n_ubatch ? s.length : n_ubatch;
+                }
+                add_seq_to_ubatch(ubatch, s, length);
+                n_tokens_in_ubatch += length;
+                // shared prompts can't be mixed with any of their sequences,
+                // so it's safer to compute them in their own ubatch
+                if (s.n_seq_id > 1) { break; }
+                // stop when there isn't enough space for another sequence
+                if (length + n_tokens_in_ubatch > n_ubatch) { break; }
+            }
+        }
+        return ubatch;
+    }
+
+    // sequence-wise split
+    llama_ubatch split_seq(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        if (!seq.empty()) {
+            llama_sbatch_seq & s = seq[seq.size() - 1];
+            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
+            add_seq_to_ubatch(ubatch, s, length);
+        }
+        return ubatch;
+    }
+
+    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
+        GGML_ASSERT(batch.n_tokens >= 0);
+        this->batch = &batch;
+        this->n_embd = n_embd;
+        this->logits_all = logits_all;
+
+        n_tokens = batch.n_tokens;
+        ids.resize(n_tokens);
+        out_ids.clear();
+        // TODO: reserve out_ids and seq
+
+        for (size_t i = 0; i < n_tokens; ++i) {
+            ids[i] = i;
+        }
+        if (simple_split) {
+            seq.resize(1);
+            llama_sbatch_seq & s = seq[0];
+            s.n_seq_id = 0;
+            s.seq_id = nullptr;
+            s.offset = 0;
+            s.length = n_tokens;
+            s.all_seq_id = batch.all_seq_id;
+            return;
+        }
+        std::sort(ids.begin(), ids.end(),
+            [&batch](size_t a, size_t b) {
+                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
+                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
+                // sort by seq_id, then by pos
+                if (n_seq_a == n_seq_b) {
+                    if (batch.seq_id) {
+                        for (int32_t i = 0; i < n_seq_a; ++i) {
+                            llama_seq_id seq_id_a = batch.seq_id[a][i];
+                            llama_seq_id seq_id_b = batch.seq_id[b][i];
+                            // smaller seq_ids go first
+                            if (seq_id_a != seq_id_b) {
+                                return seq_id_a < seq_id_b;
+                            }
+                        }
+                    }
+                    // when all else is equal, sort by pos
+                    if (batch.pos) {
+                        return batch.pos[a] < batch.pos[b];
+                    }
+                    // no pos, sort by id (assuming batch.all_pos_1 is positive)
+                    return a < b;
+                }
+                // shared prompts go first
+                return n_seq_a > n_seq_b;
+            }
+        );
+        // init seq
+        llama_sbatch_seq * last_seq = nullptr;
+
+        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
+            for (size_t i = 0; i < n_tokens; ++i) {
+                const size_t bi = ids[i];
+                const int32_t n_seqs = batch.n_seq_id[bi];
+                llama_seq_id * seq_ids = batch.seq_id[bi];
+                if (last_seq != nullptr) {
+                    bool same = n_seqs == last_seq->n_seq_id;
+                    for (int32_t j = 0; same && j < n_seqs; ++j) {
+                        if (seq_ids[j] != last_seq->seq_id[j]) {
+                            same = false;
+                        }
+                    }
+                    if (same) {
+                        last_seq->length += 1;
+                        continue;
+                    }
+                }
+                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
+                seq.push_back(new_seq);
+                last_seq = &seq.back();
+            }
+        } else {
+            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+            seq.push_back(new_seq);
+        }
+        // put shared prompts at the end (splitting starts from the end, so they are processed first), then sort by length, descending.
+        std::sort(seq.begin(), seq.end(),
+            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
+                if (a.n_seq_id == b.n_seq_id) {
+                    return a.length > b.length;
+                }
+                return a.n_seq_id < b.n_seq_id;
+            }
+        );
+    }
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -2724,6 +3076,7 @@ struct llama_context {
 
     struct llama_cparams        cparams;
     struct llama_sampling       sampling;
+    struct llama_sbatch         sbatch;
     struct llama_kv_cache       kv_self;
     struct llama_control_vector cvec;
 
@@ -2984,8 +3337,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llama_model_is_recurrent(&model);
     cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -2998,13 +3350,6 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(kv_size);
 
-    if (cache.recurrent) {
-        // init state copy sources
-        for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].src = i;
-        }
-    }
-
     // count used buffer types
     std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
     if (offload) {
@@ -3072,46 +3417,162 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-        const struct llama_batch & batch) {
+       const struct llama_ubatch & batch) {
     const uint32_t n_tokens = batch.n_tokens;
+    const uint32_t n_seqs   = batch.n_seqs;
+    const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba),
-        // each KV cache cell can store the state for a whole sequence.
-
-        llama_seq_id min = cache.size - 1;
-        llama_seq_id max = 0;
-
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
-                llama_seq_id seq_id = batch.seq_id[i][j];
-                // make sure it's a valid seq_id
-                if ((uint32_t) seq_id < cache.size) {
-                    if (seq_id > max) {
-                        max = seq_id;
-                    }
-                    if (seq_id < min) {
-                        min = seq_id;
+        // each cache cell can store the state for a whole sequence.
+        // A slot should always be contiguous.
+
+        // can only process batches with an equal number of new tokens in each sequence
+        GGML_ASSERT(batch.equal_seqs);
+
+        int32_t min = cache.size - 1;
+        int32_t max = 0;
+
+        // everything should fit if all seq_ids are smaller than the max
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const uint32_t n_seq_id = batch.n_seq_id[s];
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[s][j];
+
+                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
+                    // too big seq_id
+                    // TODO: would it be possible to resize the cache instead?
+                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
+                    return false;
+                }
+                if (j > 0) {
+                    llama_kv_cell & seq = cache.cells[seq_id];
+                    if (seq.tail >= 0) {
+                        llama_kv_cell & cell = cache.cells[seq.tail];
+                        // clear cells from seq_ids that become shared
+                        // (should not normally happen, but let's handle it anyway)
+                        cell.seq_id.erase(seq_id);
+                        seq.tail = -1;
+                        if (cell.seq_id.empty()) {
+                            cell.pos = -1;
+                            cell.src = -1;
+                            cache.used -= 1;
+                        }
                     }
-                    // Assuming the tokens are in-order
-                    if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
-                        // What should happen when the pos backtracks or skips a value?
-                        // Clearing the state mid-batch would require special-casing which isn't done.
-                        LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
-                            __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
+                }
+            }
+        }
+
+#ifndef NDEBUG
+        {
+            std::vector<int32_t> tails_verif;
+            tails_verif.assign(cache.size, -1);
+            for (uint32_t i = 0; i < cache.size; ++i) {
+                llama_kv_cell & cell = cache.cells[i];
+                for (llama_seq_id seq_id : cell.seq_id) {
+                    if (tails_verif[seq_id] != -1) {
+                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
                     }
-                    if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
-                        cache.used += 1;
+                    tails_verif[seq_id] = i;
+                }
+            }
+            for (uint32_t i = 0; i < cache.size; ++i) {
+                if (tails_verif[i] != cache.cells[i].tail) {
+                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
+                }
+            }
+        }
+#endif
+
+        // find next empty cell
+        uint32_t next_empty_cell = cache.head;
+
+        for (uint32_t i = 0; i < cache.size; ++i) {
+            if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
+            llama_kv_cell & cell = cache.cells[next_empty_cell];
+            if (cell.is_empty()) { break; }
+            next_empty_cell += 1;
+        }
+
+        // find usable cell range
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
+            llama_kv_cell & seq_meta = cache.cells[seq_id];
+            bool has_cell = false;
+            if (seq_meta.tail >= 0) {
+                llama_kv_cell & cell = cache.cells[seq_meta.tail];
+                GGML_ASSERT(cell.has_seq_id(seq_id));
+                // does this seq_id "own" the cell?
+                if (cell.seq_id.size() == 1) { has_cell = true; }
+            }
+            if (!has_cell) {
+                llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
+                GGML_ASSERT(empty_cell.is_empty());
+                // copy old tail into the empty cell
+                if (seq_meta.tail >= 0) {
+                    llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
+                    empty_cell.pos = orig_cell.pos;
+                    empty_cell.src = orig_cell.src;
+                    orig_cell.seq_id.erase(seq_id);
+                    empty_cell.seq_id.insert(seq_id); // will be overwritten
+                }
+                seq_meta.tail = next_empty_cell;
+                // find next empty cell
+                if (s + 1 < n_seqs) {
+                    next_empty_cell += 1;
+                    for (uint32_t i = 0; i < cache.size; ++i) {
+                        if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
+                        llama_kv_cell & cell = cache.cells[next_empty_cell];
+                        if (cell.is_empty()) { break; }
+                        next_empty_cell += 1;
                     }
-                    cache.cells[seq_id].pos = batch.pos[i];
-                    // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
-                } else {
-                    // too big seq_id
-                    // TODO: would it be possible to resize the KV cache size instead?
-                    LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                    return false;
                 }
             }
+            if (min > seq_meta.tail) { min = seq_meta.tail; }
+            if (max < seq_meta.tail) { max = seq_meta.tail; }
+        }
+
+        // gather and re-order
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            int32_t dst_id = s + min;
+            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            if (dst_id != src_id) {
+                llama_kv_cell & dst_cell = cache.cells[dst_id];
+                llama_kv_cell & src_cell = cache.cells[src_id];
+
+                std::swap(dst_cell.pos, src_cell.pos);
+                std::swap(dst_cell.src, src_cell.src);
+                std::swap(dst_cell.seq_id, src_cell.seq_id);
+
+                // swap tails (assuming they NEVER overlap)
+                for (const llama_seq_id seq_id : src_cell.seq_id) {
+                    cache.cells[seq_id].tail = src_id;
+                }
+                for (const llama_seq_id seq_id : dst_cell.seq_id) {
+                    cache.cells[seq_id].tail = dst_id;
+                }
+            }
+        }
+
+        // update the pos of the used seqs
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            int32_t cell_id = s + min;
+            llama_kv_cell & cell = cache.cells[cell_id];
+
+            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+                // What should happen when the pos backtracks or skips a value?
+                // Clearing the state mid-batch would require special-casing which isn't done.
+                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
+                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+            }
+            cell.pos = last_pos;
+            cell.seq_id.clear();
+            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[s][j];
+                cell.seq_id.insert(seq_id);
+                cache.cells[seq_id].tail = cell_id;
+            }
         }
 
         // allow getting the range of used cells, from head to head + n
@@ -3119,7 +3580,7 @@ static bool llama_kv_cache_find_slot(
         cache.n    = max - min + 1;
 
         // sanity check
-        return max >= min;
+        return cache.n >= n_seqs;
     }
     // otherwise, one cell per token.
 
@@ -3157,11 +3618,14 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
-    for (uint32_t i = 0; i < n_tokens; i++) {
-        cache.cells[cache.head + i].pos = batch.pos[i];
+    for (uint32_t s = 0; s < n_seqs; s++) {
+        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
+            uint32_t k = s*n_seq_tokens + i;
+            cache.cells[cache.head + k].pos = batch.pos[k];
 
-        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
-            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            }
         }
     }
 
@@ -3187,6 +3651,8 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
     for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
+        cache.cells[i].src = -1;
+        cache.cells[i].tail = -1;
     }
     cache.head = 0;
     cache.used = 0;
@@ -3213,9 +3679,16 @@ static bool llama_kv_cache_seq_rm(
             return false;
         }
         if (0 <= seq_id) {
-            // partial intersection is invalid
-            if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
-                return false;
+            int32_t & tail_id = cache.cells[seq_id].tail;
+            if (tail_id >= 0) {
+                const llama_kv_cell & cell = cache.cells[tail_id];
+                // partial intersection is invalid
+                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                    return false;
+                }
+                if (p0 <= cell.pos && p1 < cell.pos) {
+                    tail_id = -1;
+                }
             }
         } else {
             // seq_id is negative, then the range should include everything or nothing
@@ -3239,6 +3712,7 @@ static bool llama_kv_cache_seq_rm(
                 if (cache.cells[i].pos >= 0) cache.used--;
 
                 cache.cells[i].pos = -1;
+                cache.cells[i].src = -1;
                 if (new_head == cache.size) new_head = i;
             }
         }
@@ -3261,23 +3735,29 @@ static void llama_kv_cache_seq_cp(
 
     if (cache.recurrent) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].src;
-            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
-            // intent to "copy from"
-            // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].src = seq_id_src;
-
-            // preserve the "keep or clear" status of the copied sequence
-            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
-                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
-            } else {
-                cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
+            llama_kv_cell & tail_src = cache.cells[seq_id_src];
+            llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
+            if (tail_dst.tail >= 0) {
+                // clear destination seq_id if it wasn't empty
+                llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
+
+                cell_dst.seq_id.erase(seq_id_dst);
+                tail_dst.tail = -1;
+                if (cell_dst.seq_id.empty()) {
+                    cell_dst.pos = -1;
+                    cell_dst.delta = -1;
+                    cell_dst.src = -1;
+                    cache.used -= 1;
+                }
             }
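+            // make the destination sequence share the source's tail cell;
+            // the state data itself is copied later in the graph via inp_s_copy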
+            if (tail_src.tail >= 0) {
+                llama_kv_cell & cell_src = cache.cells[tail_src.tail];
 
-            cache.do_copy = true;
-
-            cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
+                cell_src.seq_id.insert(seq_id_dst);
+                tail_dst.tail = tail_src.tail;
+            }
         }
+
         return;
     }
     // otherwise, this is the KV cache of a Transformer-like model
@@ -3295,9 +3775,13 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     uint32_t new_head = cache.size;
 
     for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.recurrent && (llama_seq_id) i != seq_id) {
+            cache.cells[i].tail = -1;
+        }
         if (!cache.cells[i].has_seq_id(seq_id)) {
             if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
+            cache.cells[i].src = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
         } else {
@@ -3326,9 +3810,12 @@ static void llama_kv_cache_seq_add(
     if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            llama_kv_cell & cell = cache.cells[seq_id];
-            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                cell.pos += delta;
+            const int32_t tail_id = cache.cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cache.cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos += delta;
+                }
             }
         }
         return;
@@ -3372,9 +3859,12 @@ static void llama_kv_cache_seq_div(
     if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            llama_kv_cell & cell = cache.cells[seq_id];
-            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                cell.pos /= d;
+            const int32_t tail_id = cache.cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cache.cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos /= d;
+                }
             }
         }
         return;
@@ -3406,7 +3896,9 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
 }
 
 static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
-    cache.do_defrag = true;
+    if (!cache.recurrent) {
+        cache.do_defrag = true;
+    }
 }
 
 static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
@@ -7948,7 +8440,7 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-          const llama_batch & batch,
+         const llama_ubatch & batch,
          struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
@@ -8497,12 +8989,180 @@ static struct ggml_tensor * llm_build_kv(
     return cur;
 }
 
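+// gathers the recurrent states selected by state_copy, clears the ones masked out
+// by state_mask, and returns a view of the n_seqs states used by the current ubatch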
+static struct ggml_tensor * llm_build_copy_mask_state(
+        struct ggml_context * ctx,
+         struct ggml_cgraph * graph,
+         struct ggml_tensor * s,
+         struct ggml_tensor * state_copy,
+         struct ggml_tensor * state_mask,
+                    int32_t   n_state,
+                    int32_t   kv_size,
+                    int32_t   kv_head,
+                    int32_t   n_kv,
+                    int32_t   n_seqs) {
+    struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size);
+
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // this shrinks the tensor's ne[1] to n_kv
+    states = ggml_get_rows(ctx, states, state_copy);
+
+    // clear states of sequences which are starting at the beginning of this batch
+    // FIXME: zero-out NANs?
+    states = ggml_mul(ctx, states, state_mask);
+
+    // copy states which won't be changed further (between n_seqs and n_kv)
+    ggml_build_forward_expand(graph,
+        ggml_cpy(ctx,
+            ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
+            ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+
+    // the part of the states that will be used and modified
+    return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
+}
+
+// TODO: split
+static struct ggml_tensor * llm_build_mamba(
+        struct ggml_context * ctx,
+       struct llama_context & lctx,
+         const llama_ubatch & batch,
+         struct ggml_cgraph * graph,
+         struct ggml_tensor * cur,
+         struct ggml_tensor * state_copy,
+         struct ggml_tensor * state_mask,
+                    int32_t   kv_head,
+                    int32_t   n_kv,
+         const llm_build_cb & cb,
+                    int       il) {
+    const llama_model    & model   = lctx.model;
+    const llama_hparams  & hparams = model.hparams;
+    const llama_kv_cache & kv      = lctx.kv_self;
+    const int64_t d_conv  = hparams.ssm_d_conv;
+    const int64_t d_inner = hparams.ssm_d_inner;
+    const int64_t d_state = hparams.ssm_d_state;
+    const int64_t dt_rank = hparams.ssm_dt_rank;
+    const int64_t n_seqs  = batch.n_seqs;
+    // Some variants of the Mamba arch (e.g. FalconMamba) apply layer norm on the B and Dt layers
+    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+    // Use the same RMS norm as the final layer norm
+    const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+    const int64_t n_seq_tokens = batch.n_seq_tokens;
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(batch.equal_seqs);
+    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+
+    struct ggml_tensor * conv_states_all = kv.k_l[il];
+    struct ggml_tensor * ssm_states_all  = kv.v_l[il];
+
+    // (ab)using the KV cache to store the states
+    struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
+            graph, conv_states_all, state_copy, state_mask,
+            hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
+    conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
+    struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
+            graph, ssm_states_all, state_copy, state_mask,
+            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
+    ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs);
+
+    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+    cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+    struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
+    // split the above in two
+    // => {d_inner, n_seq_tokens, n_seqs}
+    struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
+    struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz));
+
+    // conv
+    {
+        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+        struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0);
+
+        // copy last (d_conv - 1) columns back into the state cache
+        struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+        ggml_build_forward_expand(graph,
+            ggml_cpy(ctx, last_conv,
+                ggml_view_1d(ctx, conv_states_all,
+                    (d_conv - 1)*(d_inner)*(n_seqs),
+                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+
+        // 1D convolution
+        // The equivalent is to make a self-overlapping view of conv_x
+        // over d_conv columns at each stride in the 3rd dimension,
+        // then element-wise multiply that with the conv1d weight,
+        // then sum the elements of each row,
+        // (the last two steps are a dot product over rows (also doable with mul_mat))
+        // then permute away the ne[0] dimension,
+        // and then you're left with the resulting x tensor.
+        // For simultaneous sequences, all sequences need to have the same length.
+        x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
+
+        // bias
+        x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
+
+        x = ggml_silu(ctx, x);
+    }
+
+    // ssm
+    {
+        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+        struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
+        // split
+        struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
+        struct ggml_tensor * B  = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
+        struct ggml_tensor * C  = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
+
+        // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
+        if (ssm_dt_b_c_rms) {
+            dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
+            B = ggml_rms_norm(ctx, B, norm_rms_eps);
+            C = ggml_rms_norm(ctx, C, norm_rms_eps);
+        }
+
+        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+        dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
+        dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
+
+        // Custom operator to optimize the parallel associative scan
+        // as described in Annex D of the Mamba paper.
+        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+        struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
+
+        // store last states
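+        // ggml_ssm_scan returns y and the updated states concatenated in one tensor;
+        // the states begin right after the y part, i.e. at byte offset x->nb[3]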
+        ggml_build_forward_expand(graph,
+            ggml_cpy(ctx,
+                ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+                ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+
+        struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
+
+        // TODO: skip computing output earlier for unused tokens
+
+        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
+        y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
+        y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
+
+        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+        cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
+    }
+
+    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+    cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
+    cb(cur, "mamba_out", il);
+
+    return cur;
+}
+
 struct llm_build_context {
     const llama_model    & model;
           llama_context  & lctx;
     const llama_hparams  & hparams;
     const llama_cparams  & cparams;
-    const llama_batch    & batch;
+    const llama_ubatch   & batch;
     const llama_kv_cache & kv_self;
 
     const int64_t n_embd;
@@ -8548,7 +9208,7 @@ struct llm_build_context {
     // TODO: consider making the entire interface noexcept
     llm_build_context(
         llama_context  & lctx,
-    const llama_batch  & batch,
+    const llama_ubatch & batch,
     const llm_build_cb & cb,
                   bool   worst_case) :
         model            (lctx.model),
@@ -8655,29 +9315,6 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_s_copy() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
-
-        GGML_ASSERT(kv_self.recurrent);
-
-        struct ggml_tensor * state_copy = build_inp_s_copy();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
-
-            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
-            ssm_states  = ggml_get_rows(ctx0,  ssm_states, state_copy);
-
-            // TODO: name the intermediate tensors with cb()
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0,  ssm_states, kv_self.v_l[il]));
-        }
-
-        return gf;
-    }
-
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
@@ -8812,7 +9449,7 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_s_copy() {
-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
         cb(lctx.inp_s_copy, "inp_s_copy", -1);
         ggml_set_input(lctx.inp_s_copy);
         return lctx.inp_s_copy;
@@ -8825,13 +9462,6 @@ struct llm_build_context {
         return lctx.inp_s_mask;
     }
 
-    struct ggml_tensor * build_inp_s_seq() {
-        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
-        cb(lctx.inp_s_seq, "inp_s_seq", -1);
-        ggml_set_input(lctx.inp_s_seq);
-        return lctx.inp_s_seq;
-    }
-
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
         // find result_norm tensor for input
         struct ggml_tensor * inp = nullptr;
@@ -12161,136 +12791,31 @@ struct llm_build_context {
     struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        const int64_t d_model = n_embd;
-        const int64_t d_conv  = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t dt_rank = hparams.ssm_dt_rank;
-        // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
-        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
-        // Use the same RMS norm as the final layer norm
-        const float norm_rms_eps = hparams.f_norm_rms_eps;
-
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
+        struct ggml_tensor * state_copy = build_inp_s_copy();
         struct ggml_tensor * state_mask = build_inp_s_mask();
-        struct ggml_tensor * state_seq  = build_inp_s_seq();
 
         for (int il = 0; il < n_layer; ++il) {
-            // (ab)using the KV cache to store the states
-            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
-
-            // clear states of sequences which are starting at the beginning of this batch
-            {
-                conv_states = ggml_mul(ctx0,
-                    ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
-                    state_mask);
-                ssm_states  = ggml_mul(ctx0,
-                    ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
-                    state_mask);
-            }
-
-            conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
-            ssm_states  = ggml_reshape_3d(ctx0,  ssm_states,    d_state, d_inner, n_kv);
-
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
-            struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur);
-            // split the above in two
-            // => {d_inner, n_tokens}
-            struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
-            struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
-
-            // conv
-            {
-                // Custom operator which is needed only to ease simultaneous sequence processing.
-                // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
-                // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
-                // then element-wise multiply that with the conv1d weigth,
-                // then sum the elements of each row,
-                // (the last two steps are a dot product over rows (also doable with mul_mat))
-                // then permute away the ne[0] dimension,
-                // and then you're left with the resulting x tensor.
-                // The new conv_states is the last (d_conv - 1) columns
-                // of the last 3rd dimensional "layer" of the self-overlapping view.
-                // For simultaneous sequences, it's more complicated.
-                struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
-
-                // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
-                ggml_build_forward_expand(gf,
-                    ggml_cpy(ctx0,
-                        ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
-                        ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
-
-                // extract x from x_conv
-                x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
-
-                // bias
-                x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
-
-                x = ggml_silu(ctx0, x);
-            }
-
-            // ssm
-            {
-                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
-                struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x);
-                // split
-                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
-                struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
-                struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
-
-                // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
-                if (ssm_dt_b_c_rms) {
-                    dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
-                    B = ggml_rms_norm(ctx0, B, norm_rms_eps);
-                    C = ggml_rms_norm(ctx0, C, norm_rms_eps);
-                }
-
-                // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
-                dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
-                dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
-
-                // Custom operator to optimize the parallel associative scan
-                // as described in the Annex D of the Mamba paper.
-                // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
-                // because only a single tensor can be returned.
-                struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
-
-                // store last states (the second part of y_ssm_states)
-                ggml_build_forward_expand(gf,
-                    ggml_cpy(ctx0,
-                        ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
-                        ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));
-
-                struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
-
-                if (il == n_layer - 1) {
-                    // skip computing output for unused tokens
-                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                    x    = ggml_get_rows(ctx0,    x, inp_out_ids);
-                    y    = ggml_get_rows(ctx0,    y, inp_out_ids);
-                    z    = ggml_get_rows(ctx0,    z, inp_out_ids);
-                    inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-                }
-
-                // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
-                y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-                y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+            cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
+                    state_copy, state_mask,
+                    kv_head, n_kv, cb, il);
 
-                // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y);
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
 
             // residual
@@ -14156,8 +14681,8 @@ struct llm_build_context {
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
-    llama_batch dummy;
-    dummy.n_tokens = 0;
+    llama_ubatch dummy = {};
+    dummy.equal_seqs = true;
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
@@ -14173,8 +14698,8 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
 }
 
 static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
-    llama_batch dummy;
-    dummy.n_tokens = 0;
+    llama_ubatch dummy = {};
+    dummy.equal_seqs = true;
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
@@ -14189,26 +14714,9 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }
 
-static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
-    llama_batch dummy;
-    dummy.n_tokens = 0;
-
-    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct ggml_cgraph * result = llm.build_s_copy();
-
-    llm.free();
-
-    return result;
-}
-
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
-     const llama_batch & batch,
+    const llama_ubatch & batch,
                   bool   worst_case) {
     const auto & model = lctx.model;
 
@@ -14478,7 +14986,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
-static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     //
     // set input data
     //
@@ -14517,10 +15025,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             for (int i = 0; i < n_tokens; ++i) {
                 data[i] = i;
             }
-        } else if (batch.logits) {
+        } else if (batch.output) {
             int32_t n_outputs = 0;
             for (int i = 0; i < n_tokens; ++i) {
-                if (batch.logits[i]) {
+                if (batch.output[i]) {
                     data[n_outputs++] = i;
                 }
             }
@@ -14544,8 +15052,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
-            const int64_t n_kv     = kv_self.n;
-            const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_kv         = kv_self.n;
+            const int64_t n_tokens     = batch.n_tokens;
+            const int64_t n_seq_tokens = batch.n_seq_tokens;
+            const int64_t n_seqs       = batch.n_seqs;
 
 
             float * data     = nullptr;
@@ -14565,32 +15075,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             // of the correct sequence for each token of the batch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
             for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_pos    pos    = batch.pos[j];
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
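+                // mask rows follow the ubatch token order,
+                // i.e. row s*n_seq_tokens + j corresponds to token j of sequence s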
+                for (int s = 0; s < n_seqs; ++s) {
+                    const llama_seq_id seq_id = batch.seq_id[s][0];
 
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(lctx.kv_self.cells[i].pos - pos);
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const llama_pos pos = batch.pos[s*n_seq_tokens + j];
+
+                        for (int i = 0; i < n_kv; ++i) {
+                            float f;
+                            if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                                f = -INFINITY;
                             } else {
-                                f = 0.0f;
+                                if (hparams.use_alibi) {
+                                    f = -std::abs(kv_self.cells[i].pos - pos);
+                                } else {
+                                    f = 0.0f;
+                                }
                             }
-                        }
 
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-                        }
+                            if (data) {
+                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                            }
 
-                        // may need to cut off old tokens for sliding window
-                        if (data_swa) {
-                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
-                                f = -INFINITY;
+                            // may need to cut off old tokens for sliding window
+                            if (data_swa) {
+                                if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                                    f = -INFINITY;
+                                }
+                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                             }
-                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
@@ -14612,8 +15125,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         } else {
+            const int64_t n_tokens     = batch.n_tokens;
+            const int64_t n_seq_tokens = batch.n_seq_tokens;
+            const int64_t n_seqs       = batch.n_seqs;
             // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_tokens = batch.n_tokens;
             const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
 
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
@@ -14621,27 +15136,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data = (float *) lctx.inp_KQ_mask->data;
 
             for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
-
-                    for (int i = 0; i < n_tokens; ++i) {
-                        float f = -INFINITY;
-                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                            if (batch.seq_id[i][s] == seq_id) {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(batch.pos[i] - batch.pos[j]);
-                                } else {
-                                    f = 0.0f;
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = batch.seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < batch.n_seq_id[s0]; ++s) {
+                                    if (batch.seq_id[s0][s] == seq_id) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(batch.pos[ti] - batch.pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
                                 }
-                                break;
+
+                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
                             }
                         }
 
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                    }
-
-                    for (int i = n_tokens; i < n_stride; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                        for (int i = n_tokens; i < n_stride; ++i) {
+                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                        }
                     }
                 }
             }
@@ -14649,7 +15172,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens     = batch.n_tokens;
+        const int64_t n_seq_tokens = batch.n_seq_tokens;
+        const int64_t n_seqs       = batch.n_seqs;
 
         GGML_ASSERT(lctx.inp_mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -14658,12 +15183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
 
         std::vector<uint64_t> sum(n_tokens, 0);
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
 
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
+
+            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
 
-            sum[seq_id] += 1;
+            sum[seq_id] += batch.n_seq_tokens;
         }
 
         std::vector<float> div(n_tokens, 0.0f);
@@ -14674,14 +15201,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
 
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
-            data[seq_id*n_tokens + i] = div[seq_id];
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
+
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+            }
         }
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens     = batch.n_tokens;
+        const int64_t n_seq_tokens = batch.n_seq_tokens;
+        const int64_t n_seqs       = batch.n_seqs;
 
         GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -14689,20 +15221,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
         memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
 
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
-            const llama_pos    pos    = batch.pos[i];
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
 
+            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
 
-            if (pos == 0) {
-                data[seq_id] = i;
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+
+                if (pos == 0) {
+                    data[seq_id] = s*n_seq_tokens + i;
+                }
             }
         }
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens     = batch.n_tokens;
+        const int64_t n_seq_tokens = batch.n_seq_tokens;
+        const int64_t n_seqs       = batch.n_seqs;
 
         GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -14713,15 +15251,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         std::vector<int> last_pos(n_tokens, -1);
         std::vector<int> last_row(n_tokens, -1);
 
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
-            const llama_pos    pos    = batch.pos[i];
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
 
+            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
 
-            if (pos >= last_pos[seq_id]) {
-                last_pos[seq_id] = pos;
-                last_row[seq_id] = i;
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
+                }
             }
         }
 
@@ -14739,41 +15281,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;
 
-            // states which are not affected by the current batch are left untouched
+            // clear unused states
             for (int i = 0; i < n_kv; ++i) {
-                llama_seq_id    seq_id       = i + lctx.kv_self.head;
-                llama_kv_cell & kv_cell      = lctx.kv_self.cells[seq_id];
-                bool            has_self_seq = kv_cell.has_seq_id(seq_id);
+                uint32_t        cell_id = i + kv_self.head;
+                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
 
-                data[i] = (float) has_self_seq;
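+                // src < 0 marks a state that was never written for this cell,
+                // so its mask entry is 0 and the state gets cleared in the graph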
+                data[i] = (float) (kv_cell.src >= 0);
 
-                // ensure current sequences will be kept
-                if (!has_self_seq && kv_cell.pos >= 0) {
-                    kv_cell.seq_id.insert(seq_id);
+                // only clear once
+                if (kv_cell.src < 0) {
+                    kv_cell.src = cell_id;
                 }
             }
         }
-        // For Mamba (and other recurrent architectures),
-        // update the correct state(s)/sequence(s) for each token of the batch.
-        // Like with the KQ_mask, if a token in the batch has multiple sequences,
-        // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
-        if (lctx.inp_s_seq) {
-            const int64_t n_tokens = batch.n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
-            int32_t * data = (int32_t *) lctx.inp_s_seq->data;
+        if (lctx.inp_s_copy) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+            int32_t * data = (int32_t *) lctx.inp_s_copy->data;
 
-            for (int j = 0; j < n_tokens; ++j) {
-                const int32_t n_seq = batch.n_seq_id[j];
-                GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
+            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
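+            // data[i] is the cell to gather the state for cell (head + i) from;
+            // a cell which already holds its own state copies from itself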
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                const uint32_t  cell_id = i + kv_self.head;
+                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
 
-                for (int i = 0; i < n_kv; ++i) {
-                    if (i < n_seq) {
-                        // for this type of model, the head is the minimum seq_id of the batch
-                        data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
-                    } else {
-                        data[j*n_kv + i] = -1;
-                    }
+                // prevent out-of-bound sources
+                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                    kv_cell.src = cell_id;
+                }
+
+                data[i] = kv_cell.src;
+
+                // ensure copy only happens once
+                if (kv_cell.src != (int32_t) cell_id) {
+                    kv_cell.src = cell_id;
                 }
             }
         }
@@ -14783,6 +15323,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
+        GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
 
         int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
 
@@ -14818,6 +15359,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
+        GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
 
         float * data = (float *) lctx.inp_KQ_mask_cross->data;
 
@@ -14911,6 +15453,43 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     return n_outputs_max;
 }
 
+// make the outputs have the same order they had in the user-provided batch
+static void llama_output_reorder(struct llama_context * ctx) {
+    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
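+    // out_ids[i] is the position in the user-provided batch of the i-th stored output row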
+    if (!out_ids.empty()) {
+        uint32_t n_vocab = ctx->model.hparams.n_vocab;
+        uint32_t n_embd  = ctx->model.hparams.n_embd;
+        int32_t n_outputs = ctx->n_outputs;
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (int32_t i = 0; i < n_outputs - 1; ++i) {
+            int32_t j_min = i;
+            for (int32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) { continue; }
+            std::swap(out_ids[i], out_ids[j_min]);
+            if (ctx->logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
+                }
+            }
+            if (ctx->embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
+                }
+            }
+        }
+        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
+        for (int32_t i = 0; i < n_outputs; ++i) {
+            ctx->output_ids[out_ids[i]] = i;
+        }
+        out_ids.clear();
+    }
+}
 
 static void llama_graph_compute(
         llama_context & lctx,
@@ -14983,15 +15562,11 @@ static int llama_decode_internal(
 
     const auto n_ubatch = cparams.n_ubatch;
 
-    // TODO: simplify or deprecate
-    std::vector<llama_pos> pos;
-    std::vector<int32_t>                   n_seq_id;
-    std::vector<llama_seq_id *>            seq_id_arr;
-    std::vector<std::vector<llama_seq_id>> seq_id;
-
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    lctx.embd_seq.clear();
+
     // count outputs
     if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
@@ -15004,55 +15579,42 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
+    lctx.sbatch.from_batch(batch_all, n_embd,
+        /* simple_split */ !kv_self.recurrent,
+        /* logits_all   */ n_outputs == n_tokens_all);
+
     // reserve output buffer
     if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
         return -2;
     };
 
-    // set output mappings
-    if (batch_all.logits) {
-        int32_t i_logits = 0;
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.logits[i]) {
-                lctx.output_ids[i] = i_logits++;
+    while (lctx.sbatch.n_tokens > 0) {
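+        // take one ubatch from the splitter per iteration until the whole batch is consumed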
+        llama_ubatch ubatch;
+        if (kv_self.recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = lctx.sbatch.split_seq(n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = lctx.sbatch.split_equal(n_ubatch);
             }
+        } else {
+            ubatch = lctx.sbatch.split_simple(n_ubatch);
         }
-    } else {
-        for (uint32_t i = 0; i < n_outputs; ++i) {
-            lctx.output_ids[i] = i;
-        }
-    }
-
-    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
-        const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
-        llama_batch u_batch = {
-            /* .n_tokens   = */ (int32_t) n_tokens,
-            /* .token      = */ batch_all.token     ? batch_all.token    + cur_token        : nullptr,
-            /* .embd       = */ batch_all.embd      ? batch_all.embd     + cur_token*n_embd : nullptr,
-            /* .pos        = */ batch_all.pos       ? batch_all.pos      + cur_token        : nullptr,
-            /* .n_seq_id   = */ batch_all.n_seq_id  ? batch_all.n_seq_id + cur_token        : nullptr,
-            /* .seq_id     = */ batch_all.seq_id    ? batch_all.seq_id   + cur_token        : nullptr,
-            /* .logits     = */ batch_all.logits    ? batch_all.logits   + cur_token        : nullptr,
-            /* .all_pos_0  = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
-            /* .all_pos_1  = */ batch_all.all_pos_1,
-            /* .all_seq_id = */ batch_all.all_seq_id,
-        };
+        const uint32_t n_tokens = ubatch.n_tokens;
 
         // count the outputs in this u_batch
         {
             int32_t n_outputs_new = 0;
 
-            if (u_batch.logits && !embd_pooled) {
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += u_batch.logits[i] != 0;
-                }
-            } else if (n_outputs == n_tokens_all) {
+            if (n_outputs == n_tokens_all) {
                 n_outputs_new = n_tokens;
             } else {
-                // keep last output only
-                if (cur_token + n_tokens >= n_tokens_all) {
-                    n_outputs_new = 1;
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                 }
             }
 
@@ -15063,32 +15625,6 @@ static int llama_decode_internal(
         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
         GGML_ASSERT(n_threads > 0);
 
-        // helpers for smoother batch API transition
-        // after deprecating the llama_eval calls, these will be removed
-        if (u_batch.pos == nullptr) {
-            pos.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
-            }
-
-            u_batch.pos = pos.data();
-        }
-
-        if (u_batch.seq_id == nullptr) {
-            n_seq_id.resize(n_tokens);
-            seq_id.resize(n_tokens);
-            seq_id_arr.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                n_seq_id[i] = 1;
-                seq_id[i].resize(1);
-                seq_id[i][0] = u_batch.all_seq_id;
-                seq_id_arr[i] = seq_id[i].data();
-            }
-
-            u_batch.n_seq_id = n_seq_id.data();
-            u_batch.seq_id = seq_id_arr.data();
-        }
-
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
             llama_kv_cache_update(&lctx);
@@ -15099,7 +15635,7 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
                 return 1;
             }
 
@@ -15118,7 +15654,7 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
+        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
         // the output is always the last tensor in the graph
         struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
@@ -15146,7 +15682,7 @@ static int llama_decode_internal(
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
-        llama_set_inputs(lctx, u_batch);
+        llama_set_inputs(lctx, ubatch);
 
         llama_graph_compute(lctx, gf, n_threads);
 
@@ -15204,12 +15740,11 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_LAST:
                     {
-                        // extract sequence embeddings
+                        // extract sequence embeddings (cleared before processing each batch)
                         auto & embd_seq_out = lctx.embd_seq;
-                        embd_seq_out.clear();
 
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = u_batch.seq_id[i][0];
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
                             if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                 continue;
                             }
@@ -15226,6 +15761,25 @@ static int llama_decode_internal(
         n_outputs_prev += lctx.n_outputs;
     }
 
+    // set output mappings
+    {
+        bool sorted_output = true;
+
+        GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
+
+        for (size_t i = 0; i < n_outputs; ++i) {
+            size_t out_id = lctx.sbatch.out_ids[i];
+            lctx.output_ids[out_id] = i;
+            if (out_id != i) {
+                sorted_output = false;
+            }
+        }
+
+        if (sorted_output) {
+            lctx.sbatch.out_ids.clear();
+        }
+    }
+
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     lctx.n_outputs = n_outputs;
 
@@ -15290,11 +15844,9 @@ static int llama_encode_internal(
 
     const int64_t n_embd = hparams.n_embd;
 
-    // TODO: simplify or deprecate
-    std::vector<llama_pos> pos;
-    std::vector<int32_t>                   n_seq_id;
-    std::vector<llama_seq_id *>            seq_id_arr;
-    std::vector<std::vector<llama_seq_id>> seq_id;
+    lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+
+    const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
 
     // reserve output buffer
     if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
@@ -15312,36 +15864,10 @@ static int llama_encode_internal(
     const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
     GGML_ASSERT(n_threads > 0);
 
-    // helpers for smoother batch API transition
-    // after deprecating the llama_eval calls, these will be removed
-    if (batch.pos == nullptr) {
-        pos.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
-        }
-
-        batch.pos = pos.data();
-    }
-
-    if (batch.seq_id == nullptr) {
-        n_seq_id.resize(n_tokens);
-        seq_id.resize(n_tokens);
-        seq_id_arr.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            n_seq_id[i] = 1;
-            seq_id[i].resize(1);
-            seq_id[i][0] = batch.all_seq_id;
-            seq_id_arr[i] = seq_id[i].data();
-        }
-
-        batch.n_seq_id = n_seq_id.data();
-        batch.seq_id = seq_id_arr.data();
-    }
-
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
+    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
     // the output embeddings after the final encoder normalization
     struct ggml_tensor * embd = nullptr;
@@ -15365,7 +15891,7 @@ static int llama_encode_internal(
 
     ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
-    llama_set_inputs(lctx, batch);
+    llama_set_inputs(lctx, ubatch);
 
     llama_graph_compute(lctx, gf, n_threads);
 
@@ -15379,12 +15905,13 @@ static int llama_encode_internal(
             float * embd_out = lctx.embd_enc.data();
 
             ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
             // remember the sequence ids used during the encoding - needed for cross attention later
             lctx.seq_ids_enc.resize(n_tokens);
             for (uint32_t i = 0; i < n_tokens; i++) {
-                for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                    llama_seq_id seq_id = batch.seq_id[i][s];
+                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
+                    llama_seq_id seq_id = ubatch.seq_id[i][s];
                     lctx.seq_ids_enc[i].insert(seq_id);
                 }
             }
@@ -15409,8 +15936,10 @@ static int llama_encode_internal(
                         auto & embd_seq_out = lctx.embd_seq;
                         embd_seq_out.clear();
 
+                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+
                         for (uint32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = batch.seq_id[i][0];
+                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
                             if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                 continue;
                             }
@@ -15688,32 +16217,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
     }
 
-    if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        {
-            ggml_backend_sched_reset(lctx.sched);
-
-            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
-
-            ggml_backend_sched_alloc_graph(lctx.sched, gf);
-
-            llama_set_s_copy(lctx);
-
-            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
-
-            need_reserve = true;
-        }
-
-        {
-            auto & kv_self = lctx.kv_self;
-
-            kv_self.do_copy = false;
-
-            for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].src = i;
-            }
-        }
-    }
-
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
         llama_kv_cache_defrag_internal(lctx);
@@ -15727,10 +16230,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     if (need_reserve) {
         // TODO: extract to a function
         // build worst-case graph
-        int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        int n_past = lctx.cparams.n_ctx - n_tokens;
+        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
         llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
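+        // worst-case ubatch: a single sequence of n_tokens tokens; only the token pointer is set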
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
 
         // initialize scheduler with the worst-case graph
         ggml_backend_sched_reset(lctx.sched);
@@ -16326,12 +16830,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
     // sanity checks
-    //
-    //  - qs.n_attention_wv == 0                         for Mamba           models
-    //  - qs.n_attention_wv == model.hparams.n_layer     for Transformer     models
-    //  - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
-    //
-    GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
+    {
+        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
+        // attention layers have a non-zero number of kv heads
+        int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
+        if (llama_model_has_encoder(&model)) {
+            n_attn_layer *= 3;
+        }
+        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+    }
 
     size_t total_size_org = 0;
     size_t total_size_new = 0;
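
The rewritten sanity check derives the expected `n_attention_wv` from the hyperparameters: a layer counts as an attention layer exactly when its KV-head count is non-zero (recurrent layers report zero KV heads), and encoder-decoder models carry three times as many attention wv tensors. A self-contained sketch of that counting logic, using hypothetical layer data:

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical per-layer KV-head counts: two attention layers and two
    // recurrent (Mamba-style) layers, which report 0 KV heads.
    const std::array<uint32_t, 4> n_head_kv_arr = { 8, 0, 8, 0 };
    const uint32_t n_layer = n_head_kv_arr.size();

    // attention layers are exactly those with a non-zero number of KV heads
    int32_t n_attn_layer = n_layer - std::count(n_head_kv_arr.begin(), n_head_kv_arr.end(), 0u);

    const bool has_encoder = false; // would be llama_model_has_encoder(&model) in llama.cpp
    if (has_encoder) {
        n_attn_layer *= 3; // encoder-decoder models have 3x the attention wv tensors
    }

    assert(n_attn_layer == 2);
    return 0;
}
```
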
@@ -17140,7 +17647,7 @@ struct llama_context * llama_new_context_with_model(
     ggml_type type_v = params.type_v;
 
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
+    if (llama_model_is_recurrent(model)) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
@@ -17372,10 +17879,11 @@ struct llama_context * llama_new_context_with_model(
             }
 
             // build worst-case graph
-            int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
-            int n_past = cparams.n_ctx - n_tokens;
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
             llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
 
             // initialize scheduler with the worst-case graph
             if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
@@ -17615,6 +18123,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
     return model->hparams.dec_start_token_id;
 }
 
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA:  return true;
+        default:              return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
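
With `llama_model_is_recurrent` exposed in the public API, callers can branch on recurrent architectures without comparing against `LLM_ARCH_MAMBA` themselves. A minimal sketch of such a caller, mirroring the KV-cache sizing logic changed earlier in this patch (`pick_kv_size` is a hypothetical helper, not part of llama.h):

```cpp
#include "llama.h"

#include <algorithm>
#include <cstdint>

// Hypothetical helper: choose the KV cache size the same way
// llama_new_context_with_model does after this patch.
static uint32_t pick_kv_size(const struct llama_model * model, uint32_t n_ctx, uint32_t n_seq_max) {
    if (llama_model_is_recurrent(model)) {
        // recurrent models (currently Mamba) only need one cell per tracked sequence
        return std::max((uint32_t) 1, n_seq_max);
    }
    return n_ctx; // attention models need one cell per context position
}
```
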
@@ -17936,7 +18451,9 @@ struct llama_data_write {
         write_string(rng_str);
     }
 
-    void write_output_ids(const struct llama_context * ctx) {
+    void write_output_ids(struct llama_context * ctx) {
+        llama_output_reorder(ctx);
+
         const uint32_t n_outputs = ctx->n_outputs;
 
         std::vector<int32_t> output_pos;
@@ -18224,8 +18741,11 @@ struct llama_data_read {
 
             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
-            llama_batch batch = llama_batch_init(cell_count, 0, 1);
+            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
             batch.n_tokens = cell_count;
+            batch.n_seq_tokens = cell_count;
+            batch.n_seqs = 1;
+
             for (uint32_t i = 0; i < cell_count; ++i) {
                 llama_pos pos;
                 uint32_t n_seq_id;
@@ -18239,11 +18759,10 @@ struct llama_data_read {
                 }
 
                 batch.pos[i] = pos;
-                batch.n_seq_id[i] = 1;
-                batch.seq_id[i][0] = dest_seq_id;
             }
+            batch.n_seq_id[0] = 1;
+            batch.seq_id[0] = &dest_seq_id;
             if (!llama_kv_cache_find_slot(kv_self, batch)) {
-                llama_batch_free(batch);
                 LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
                 return false;
             }
@@ -18255,9 +18774,6 @@ struct llama_data_read {
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
             GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-
-            // Cleanup
-            llama_batch_free(batch);
         } else {
             // whole KV cache restore
 
@@ -18289,6 +18805,15 @@ struct llama_data_read {
                     }
 
                     cell.seq_id.insert(seq_id);
+
+                    if (kv_self.recurrent) {
+                        int32_t & tail = kv_self.cells[seq_id].tail;
+                        if (tail != -1) {
+                            LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
+                            return false;
+                        }
+                        tail = i;
+                    }
                 }
             }
 
@@ -18296,6 +18821,14 @@ struct llama_data_read {
             kv_self.used = cell_count;
         }
 
+        if (kv_self.recurrent) {
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                uint32_t cell_id = kv_self.head + i;
+                // make sure the recurrent states will keep their restored state
+                kv_self.cells[cell_id].src = cell_id;
+            }
+        }
+
         return true;
     }
 
@@ -18883,7 +19416,18 @@ struct llama_batch llama_batch_get_one(
 }
 
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+    llama_batch batch = {
+        /*n_tokens       =*/ 0,
+        /*tokens         =*/ nullptr,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
+        /*all_pos_0      =*/ 0,
+        /*all_pos_1      =*/ 0,
+        /*all_seq_id     =*/ 0,
+    };
 
     if (embd) {
         batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
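
The commented field names above make the zero-initialization of `llama_batch` easier to audit. For context, a minimal sketch of the usual fill-and-free pattern for a batch allocated with `llama_batch_init` (sizes and token values are placeholders):

```cpp
#include "llama.h"

#include <vector>

// Sketch: build a single-sequence batch of prompt tokens and release it.
static void fill_batch_sketch(const std::vector<llama_token> & prompt) {
    llama_batch batch = llama_batch_init(/*n_tokens_alloc =*/ (int32_t) prompt.size(), /*embd =*/ 0, /*n_seq_max =*/ 1);

    batch.n_tokens = (int32_t) prompt.size();
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        batch.token   [i]    = prompt[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = (i == batch.n_tokens - 1); // request logits only for the last token
    }

    // ... pass to llama_decode(ctx, batch) here ...

    llama_batch_free(batch);
}
```
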
@@ -18969,6 +19513,10 @@ void llama_synchronize(struct llama_context * ctx) {
 float * llama_get_logits(struct llama_context * ctx) {
     llama_synchronize(ctx);
 
+    // reorder logits for backward compatibility
+    // TODO: maybe deprecate this
+    llama_output_reorder(ctx);
+
     return ctx->logits;
 }
 
@@ -19013,6 +19561,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 float * llama_get_embeddings(struct llama_context * ctx) {
     llama_synchronize(ctx);
 
+    // reorder embeddings for backward compatibility
+    // TODO: maybe deprecate this
+    llama_output_reorder(ctx);
+
     return ctx->embd;
 }
 

From 1731d4238f9e4f925a750810e7f5480827c66dcf Mon Sep 17 00:00:00 2001
From: luoyu-intel <yu.luo@intel.com>
Date: Thu, 22 Aug 2024 12:50:10 +0800
Subject: [PATCH 3/6] [SYCL] Add oneDNN primitive support (#9091)

* add onednn

* add sycl_f16

* add dnnl stream

* add engine map

* use dnnl for intel only

* use fp16fp16fp16

* update doc
---
 CMakePresets.json             |   5 +-
 docs/backend/SYCL.md          |  14 ++---
 ggml/src/CMakeLists.txt       |  10 ++++
 ggml/src/ggml-sycl.cpp        |  16 +++++-
 ggml/src/ggml-sycl/common.hpp |  50 +++++++++++++++++
 ggml/src/ggml-sycl/gemm.hpp   | 101 ++++++++++++++++++++++++++++++++++
 6 files changed, 186 insertions(+), 10 deletions(-)
 create mode 100644 ggml/src/ggml-sycl/gemm.hpp

diff --git a/CMakePresets.json b/CMakePresets.json
index bdad38952d3cb..ce627b4d39e0c 100644
--- a/CMakePresets.json
+++ b/CMakePresets.json
@@ -28,6 +28,7 @@
     { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
     { "name": "reldbg",  "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
     { "name": "static",  "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
+    { "name": "sycl_f16",  "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
 
     {
         "name": "arm64-windows-msvc", "hidden": true,
@@ -60,6 +61,8 @@
     { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
 
     { "name": "x64-windows-sycl-debug"  , "inherits": [ "sycl-base", "debug"   ] },
-    { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }
+    { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
+    { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
+    { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }
   ]
 }
diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index 59a39fbb67395..e838b2be6b11c 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -20,7 +20,7 @@
 **oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include:
 
 - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
-- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL - Math Kernel Library)*.
+- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*.
 - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs.
 - **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.
 
@@ -28,10 +28,6 @@
 
 The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*).
 
-When targeting **Intel CPU**, it is recommended to use llama.cpp for [Intel oneMKL](README.md#intel-onemkl) backend.
-
-It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS, cuBLAS, etc..*. In beginning work, the oneAPI's [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) open-source migration tool (Commercial release [IntelĀ® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) was used for this purpose.
-
 ## Recommended Release
 
 The SYCL backend would be broken by some PRs due to no online CI.
@@ -45,6 +41,10 @@ The following release is verified with good quality:
 
 ## News
 
+
+- 2024.8
+  - Use oneDNN as the default GEMM library, improving compatibility with new Intel GPUs.
+
 - 2024.5
   - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
   - Arch Linux is verified successfully.
@@ -196,7 +196,7 @@ Please follow the instructions for downloading and installing the Toolkit for Li
 
 Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable.
 
-Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI MKL for intel GPUs.
+Upon a successful installation, SYCL is enabled for the available Intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs.
 
 - **Adding support to Nvidia GPUs**
 
@@ -255,8 +255,6 @@ or
 # Export relevant ENV variables
 source /opt/intel/oneapi/setvars.sh
 
-# Build LLAMA with MKL BLAS acceleration for intel GPU
-
 # Option 1: Use FP32 (recommended for better performance in most cases)
 cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
 
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 1775ef3cc9146..951cec6941076 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -549,6 +549,13 @@ if (GGML_SYCL)
     file(GLOB   GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
     list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
 
+    find_package(DNNL)
+    message("-- DNNL found:"${DNNL_FOUND})
+    if (GGML_SYCL_TARGET STREQUAL "INTEL")
+        add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
+    else()
+        add_compile_definitions(GGML_SYCL_DNNL=0)
+    endif()
     if (WIN32)
         find_package(IntelSYCL REQUIRED)
         find_package(MKL REQUIRED)
@@ -561,6 +568,9 @@ if (GGML_SYCL)
             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
         endif()
     endif()
+    if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+        list(APPEND GGML_EXTRA_LIBS DNNL::dnnl)
+    endif()
 endif()
 
 if (GGML_RPC)
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
index 94cd4b11052e1..0d884f89a4e7b 100644
--- a/ggml/src/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl.cpp
@@ -38,6 +38,7 @@
 
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/presets.hpp"
+#include "ggml-sycl/gemm.hpp"
 
 bool   ggml_sycl_loaded(void);
 void   ggml_sycl_free_data(struct ggml_tensor * tensor);
@@ -2482,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
         const sycl::half alpha_f16 = 1.0f;
         const sycl::half beta_f16 = 0.0f;
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@@ -2491,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
             dpct::library_data_t::real_half)));
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#endif
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2513,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
         const float alpha = 1.0f;
         const float beta = 0.0f;
-
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
             dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
             src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
             dst_dd_i, ldc)));
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
+            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
     }
     (void) dst;
     (void) src1_ddq_i;
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 78cd682ad4057..05947ccb746f2 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -19,6 +19,10 @@
 #include "dpct/helper.hpp"
 #include "ggml-sycl.h"
 #include "presets.hpp"
+#if GGML_SYCL_DNNL
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+#endif
 
 #define GGML_COMMON_DECL_SYCL
 #define GGML_COMMON_IMPL_SYCL
@@ -277,6 +281,52 @@ struct ggml_backend_sycl_context {
         return stream(device, 0);
     }
 
+#if GGML_SYCL_DNNL
+    dnnl::engine make_engine(sycl::queue* q) {
+        // Get the device associated with the queue
+        sycl::device dev = q->get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q->get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        return eng;
+    }
+
+    std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
+    std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
+    dnnl::stream stream_dnnl(int device, int _stream) {
+        auto q = stream(device, _stream);
+        return stream_dnnl(q);
+    }
+    dnnl::engine engine_dnnl(sycl::queue* qptr) {
+        auto it = engine_map.find(qptr);
+        if (it == engine_map.end()) {
+            auto eng = make_engine(qptr);
+            engine_map[qptr] = eng;
+            return eng;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl(sycl::queue* qptr) {
+        auto it = stream_map.find(qptr);
+        if (it == stream_map.end()) {
+            auto eng = engine_dnnl(qptr);
+            auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
+            stream_map[qptr] = stream;
+            return stream;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl() {
+        return stream_dnnl(device, 0);
+    }
+#endif
+
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
 
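
`engine_dnnl` and `stream_dnnl` implement a lazy per-queue cache: look the queue pointer up in a map, create and store the oneDNN object on a miss, and return the cached one afterwards. A dependency-free sketch of that idiom, with made-up `Resource`/`ResourceCache` names standing in for the oneDNN types:

```cpp
#include <string>
#include <unordered_map>

struct Resource {            // stands in for dnnl::engine / dnnl::stream
    std::string tag;
};

struct ResourceCache {
    std::unordered_map<const void *, Resource> cache;

    // Return the Resource associated with `key`, creating it lazily on first use.
    Resource & get(const void * key) {
        auto it = cache.find(key);
        if (it == cache.end()) {
            it = cache.emplace(key, Resource{ "created on first use" }).first;
        }
        return it->second;
    }
};
```
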
diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp
new file mode 100644
index 0000000000000..2ad9b36f419ce
--- /dev/null
+++ b/ggml/src/ggml-sycl/gemm.hpp
@@ -0,0 +1,101 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GEMM_HPP
+#define GGML_SYCL_GEMM_HPP
+
+#include <fstream>
+#include <iostream>
+
+#include "ggml-sycl.h"
+
+#if GGML_SYCL_DNNL
+
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+
+class DnnlGemmWrapper {
+public:
+    using dt = dnnl::memory::data_type;
+    using tag = dnnl::memory::format_tag;
+
+    template<typename T>
+    static constexpr dt to_dt() {
+        if constexpr (std::is_same_v<T, float>) return dt::f32;
+        else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+        else static_assert(0);
+    }
+
+    static inline void row_gemm(sycl::queue& q, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        // Get the device associated with the queue
+        sycl::device dev = q.get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q.get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+
+
+    static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        auto const eng = stream.get_engine();
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+};
+
+#endif
+
+#endif // GGML_SYCL_GEMM_HPP
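
Both `row_gemm` overloads run a row-major oneDNN matmul with C as (m, n), A as (m, k), and B as (k, n), reinterpreted as transposed when `a_trans`/`b_trans` are set; the llama.cpp call sites pass `b_trans = true` so the weight buffer is read as (n, k). A fp32 usage sketch based only on the signatures in this header, assuming it is compiled inside ggml/src with `GGML_SYCL_DNNL` enabled (`gemm_f32_sketch` is illustrative, not part of the patch):

```cpp
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/gemm.hpp"

#if GGML_SYCL_DNNL
// Sketch: C (m x n) = A (m x k) * B^T, where the B buffer is laid out as (n x k),
// mirroring the fp32 call added to ggml_sycl_op_mul_mat_sycl above.
static void gemm_f32_sketch(ggml_backend_sycl_context & ctx, sycl::queue * q,
                            const float * A, const float * B, float * C,
                            int m, int n, int k) {
    auto dnnl_stream = ctx.stream_dnnl(q);
    DnnlGemmWrapper::row_gemm(dnnl_stream,
        /*a_trans =*/ false, /*b_trans =*/ true,
        m, n, k,
        A, DnnlGemmWrapper::to_dt<float>(),
        B, DnnlGemmWrapper::to_dt<float>(),
        C, DnnlGemmWrapper::to_dt<float>());
}
#endif
```
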

From 11b84eb4578864827afcf956db5b571003f18180 Mon Sep 17 00:00:00 2001
From: Akarshan Biswas <akarshan.biswas@gmail.com>
Date: Thu, 22 Aug 2024 19:39:47 +0530
Subject: [PATCH 4/6] [SYCL] Add a space to suppress a cmake warning (#9133)

---
 ggml/src/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 951cec6941076..ff84b9bb5f0f2 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -550,7 +550,7 @@ if (GGML_SYCL)
     list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
 
     find_package(DNNL)
-    message("-- DNNL found:"${DNNL_FOUND})
+    message("-- DNNL found:" ${DNNL_FOUND})
     if (GGML_SYCL_TARGET STREQUAL "INTEL")
         add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
     else()

From b77d7f6f0d05dfe0ae2581de33d2e1337ca5aafb Mon Sep 17 00:00:00 2001
From: Carsten Kragelund <carsten@kragelund.me>
Date: Fri, 23 Aug 2024 05:03:46 +0000
Subject: [PATCH 5/6] fix: llama3.1 rope_freqs not respecting custom head_dim

---
 convert_hf_to_gguf.py | 2 +-
 src/llama.cpp         | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 108c822cff5d2..dcd54e0e34a6e 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1570,7 +1570,7 @@ def prepare_tensors(self):
         if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
             if rope_scaling.get("rope_type", '').lower() == "llama3":
                 base = self.hparams.get("rope_theta", 10000.0)
-                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
                 freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
 
                 factor = rope_scaling.get("factor", 8.0)
diff --git a/src/llama.cpp b/src/llama.cpp
index bd7f1508b2644..f9502befa87ee 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6605,6 +6605,7 @@ static bool llm_load_tensors(
         const int64_t n_embd_gqa    = n_embd_v_gqa;
         const int64_t n_vocab       = hparams.n_vocab;
         const int64_t n_vocab_type  = hparams.n_vocab_type;
+        const int64_t n_rot         = hparams.n_rot;
         const int64_t n_expert      = hparams.n_expert;
         const int64_t n_expert_used = hparams.n_expert_used;
         const int64_t n_ctx_train   = hparams.n_ctx_train;
@@ -6662,7 +6663,7 @@ static bool llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
 
                         if (n_expert == 0) {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
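
The shape change only matters when a model sets `head_dim` explicitly and it differs from `hidden_size / num_attention_heads`; `n_rot` then matches the exported `rope_freqs` tensor while the old expression does not. A small arithmetic sketch with hypothetical hyperparameters:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical config where head_dim is set explicitly:
    const int64_t hidden_size         = 3072;
    const int64_t num_attention_heads = 32;
    const int64_t head_dim            = 128;      // != hidden_size / num_attention_heads (= 96)
    const int64_t n_rot               = head_dim; // assumed: n_rot tracks the real head dimension

    const int64_t old_shape = hidden_size / num_attention_heads / 2; // 48 -- too small for this config
    const int64_t new_shape = n_rot / 2;                             // 64 -- matches the exported rope_freqs tensor

    assert(old_shape == 48 && new_shape == 64);
    return 0;
}
```
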

From 1a88919759dabbc33f1db3d0cb23eedad11718e5 Mon Sep 17 00:00:00 2001
From: Carsten Kragelund <carsten@kragelund.me>
Date: Fri, 23 Aug 2024 08:27:50 +0000
Subject: [PATCH 6/6] fix: use head_dim if present for Exaone

---
 convert_hf_to_gguf.py | 2 +-
 src/llama.cpp         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index dcd54e0e34a6e..69498e01b2c8c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3816,7 +3816,7 @@ def prepare_tensors(self):
         if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
             if rope_scaling.get("rope_type", '').lower() == "llama3":
                 base = self.hparams.get("rope_theta", 10000.0)
-                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
                 freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
 
                 factor = rope_scaling.get("factor", 8.0)
diff --git a/src/llama.cpp b/src/llama.cpp
index f9502befa87ee..0ee1f36a84ea2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8194,7 +8194,7 @@ static bool llm_load_tensors(
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});