llama : refactor k-shift implementation + initial defragmentation #5691

Merged: ggerganov merged 25 commits into master from gg/refactor-k-shift on Feb 25, 2024

Conversation

@ggerganov (Member) commented on Feb 23, 2024

ref #3380

The goal is to separate the K-shift graph from the main LLM compute graph so that the shift can be applied on demand, instead of lazily through llama_decode(). This is a first step towards implementing more KV cache operations, such as defragmentation and compression.

API changes:

  • rename llama_kv_cache_seq_shift() to llama_kv_cache_seq_add()
  • add llama_kv_cache_seq_pos_max()
  • add llama_kv_cache_defrag()
  • add llama_kv_cache_update()

The defragmentation should work, though it is probably not as performant as it could be, since we are copying the KV cache data to host memory. However, since we expect to do this operation rarely, it could still be useful as it is. For now, using llama_kv_cache_defrag() is completely optional and is only demonstrated with the passkey example. Later I think we can calculate how many cells will be recovered by the defragmentation and start applying it automatically only when a certain fragmentation threshold is exceeded.
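For illustration, triggering the defragmentation on demand from application code boils down to two calls (a minimal sketch, assuming ctx is an initialized llama_context, mirroring what the passkey example does):

```c
// llama_kv_cache_defrag() only queues the defragmentation; the pending work is
// applied by llama_kv_cache_update() (or lazily by the next llama_decode()).
llama_kv_cache_defrag(ctx);
llama_kv_cache_update(ctx);
```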

The next step for improving KV cache management will be to provide a simpler API for typical usages (e.g. context shift, self-extend). A hedged sketch of what such a usage currently looks like is shown below.
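In this sketch of a context shift for sequence 0 using the API from this PR, the function name and the n_keep/n_discard values are illustrative, not part of the library:

```c
#include "llama.h"

// Hedged sketch: drop n_discard tokens after the first n_keep tokens of seq 0,
// shift the remaining positions back, and compact the cache right away.
static void example_context_shift(struct llama_context * ctx, int n_keep, int n_discard) {
    // free the cells holding the discarded tokens
    llama_kv_cache_seq_rm (ctx, 0, n_keep, n_keep + n_discard);
    // shift the positions of the remaining tokens back (p1 < 0 means "until the end")
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, -1, -n_discard);
    // optionally queue a defragmentation of the freed cells and apply everything now
    llama_kv_cache_defrag (ctx);
    llama_kv_cache_update (ctx);
}
```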

  • Self-extend test:
make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250 2 90
main: n_len = 6083, n_ctx = 8192, n_kv_req = 8224, n_grp = 2, n_batch = 512, n_junk = 250, i_pos = 90

prefix tokens: 32
prompt tokens: 6067
main: processed: [     0,    512)
main: processed: [   512,   1024)
main: processed: [  1024,   1536)
main: processed: [  1536,   2048)
main: processed: [  2048,   2560)
main: processed: [  2560,   3072)
main: processed: [  3072,   3584)
main: processed: [  3584,   4096)
main: processed: [  4096,   4608)
main: processed: [  4608,   5120)
main: processed: [  5120,   5632)
main: processed: [  5632,   6067)

main: passkey = 24287, inserted at position 90 / 250 (token pos: ~2184)

 What is the pass key? The pass key is 24287. Remember it. 24287

main: decoded 16 tokens in 0.55 s, speed: 29.11 t/s

llama_print_timings:        load time =    1275.57 ms
llama_print_timings:      sample time =       0.35 ms /    17 runs   (    0.02 ms per token, 48991.35 tokens per second)
llama_print_timings: prompt eval time =    5447.01 ms /  6067 tokens (    0.90 ms per token,  1113.82 tokens per second)
llama_print_timings:        eval time =     545.46 ms /    16 runs   (   34.09 ms per token,    29.33 tokens per second)
llama_print_timings:       total time =    6558.84 ms /  6083 tokens
  • Context shift test:
make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250 1 130
main: n_len = 6083, n_ctx = 4096, n_kv_req = 4128, n_grp = 1, n_batch = 512, n_junk = 250, i_pos = 130

prefix tokens: 32
prompt tokens: 6067
main: processed: [     0,    512)
main: processed: [   512,   1024)
main: processed: [  1024,   1536)
main: processed: [  1536,   2048)
main: processed: [  2048,   2560)
main: processed: [  2560,   3072)
main: processed: [  3072,   3584)
main: processed: [  3584,   4096)
main: shifting KV cache with 512
(tmp log) KV defrag: move [3584, 4096) to [32, 544)
(tmp log) KV defrag cell moves: 512
(tmp log) KV defrag time: 8.020 ms
main: processed: [  4096,   4608)
main: shifting KV cache with 512
(tmp log) KV defrag: move [3584, 4096) to [544, 1056)
(tmp log) KV defrag cell moves: 512
(tmp log) KV defrag time: 7.944 ms
main: processed: [  4608,   5120)
main: shifting KV cache with 512
(tmp log) KV defrag: move [3584, 4096) to [1056, 1568)
(tmp log) KV defrag cell moves: 512
(tmp log) KV defrag time: 7.850 ms
main: processed: [  5120,   5632)
main: shifting KV cache with 512
(tmp log) KV defrag: move [3584, 4096) to [1568, 2080)
(tmp log) KV defrag cell moves: 512
(tmp log) KV defrag time: 7.908 ms
main: processed: [  5632,   6067)

main: passkey = 33977, inserted at position 130 / 250 (token pos: ~3154)

 What is the pass key? The pass key is 33977. Here we go. There and back again.

main: decoded 16 tokens in 0.49 s, speed: 32.49 t/s

llama_print_timings:        load time =    1053.46 ms
llama_print_timings:      sample time =       0.35 ms /    17 runs   (    0.02 ms per token, 49132.95 tokens per second)
llama_print_timings: prompt eval time =    5267.30 ms /  6067 tokens (    0.87 ms per token,  1151.82 tokens per second)
llama_print_timings:        eval time =     490.78 ms /    16 runs   (   30.67 ms per token,    32.60 tokens per second)
llama_print_timings:       total time =    6077.97 ms /  6083 tokens

TODO:

  • [ ] avoid node reallocations
  • rename llama_kv_cache_seq_shift() to llama_kv_cache_seq_add()
  • test that llama_kv_cache_update() works when called directly

@slaren (Member) commented on Feb 23, 2024

It's ok to reuse the same scheduler; being able to use the same instance with different graphs was one of the goals. The reallocations are not really a performance problem because it is a very fast operation, and previously it was essentially done on every eval. The exception is if it causes a buffer reallocation, but that shouldn't happen if the scheduler was properly initialized with a worst-case graph. Currently, reallocations happen all the time precisely because the K-shift is part of the worst-case graph used to initialize the scheduler, and it shouldn't be, because it has a different graph topology than normal evaluations.

Reallocations, however, are a problem for the implementation of pipeline parallelism, because they can cause the addresses of some tensors to change, which in turn requires a full synchronization to avoid overwriting those tensors, and that breaks the parallelism. This is something that I am working on addressing. The obvious solution is to call ggml_backend_sched_reserve again with a normal, worst-case graph after the K-shift operation. It also needs to take the type of graph into account, because other changes in the graph would also force a reallocation, such as switching between token and embedding inputs. So as it is, we would need to keep track of the last type of graph used, and if we are evaluating a different type of graph, call reserve first with a worst-case graph of the same type. I really do not like adding that complexity to the applications, but I am still trying to figure out a better solution.
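A minimal sketch of that bookkeeping, assuming the application tracks the last graph type itself; the graph_type enum and the build_worst_case callback are hypothetical placeholders, and only ggml_backend_sched_reserve() and ggml_backend_sched_graph_compute() are real ggml APIs:

```c
#include "ggml-backend.h"

// Hypothetical graph "types" for this sketch: token-input and embedding-input
// evaluations and the K-shift graph all have different topologies.
enum graph_type { GRAPH_TOKENS, GRAPH_EMBEDDINGS, GRAPH_KV_SHIFT };

// Evaluate a graph, re-reserving the scheduler with a caller-provided worst-case
// graph whenever the graph type changed since the last evaluation, so that tensor
// addresses stay stable for pipeline parallelism.
static void sched_compute_typed(
        ggml_backend_sched_t    sched,
        struct ggml_cgraph    * gf,
        enum graph_type         type,
        enum graph_type       * last_type,                                  // state kept by the caller
        struct ggml_cgraph  * (*build_worst_case)(enum graph_type type)) {  // hypothetical callback
    if (type != *last_type) {
        ggml_backend_sched_reserve(sched, build_worst_case(type));
        *last_type = type;
    }
    ggml_backend_sched_graph_compute(sched, gf);
}
```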

@ggerganov (Member, Author):

I see; I hadn't noticed that the reserves occur on master too.

Yeah, it's a bit cumbersome to have to keep track of the graph topology manually. I'm thinking of something like passing an optional callback that constructs a worst-case graph together with the graph we want to evaluate, and letting the backend automatically decide whether to materialize the worst-case graph and reserve, based on a hash of the topology (nodes + types). But it still sounds complicated, so it's probably not worth it.
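A minimal sketch of such a topology hash over a ggml_cgraph (nodes + types), purely illustrative; only the ggml_cgraph/ggml_tensor fields used here come from ggml:

```c
#include <stdint.h>
#include "ggml.h"

// FNV-1a style hash over the graph "shape": the op of each node and the type of
// each node's output tensor. If the hash differs from the one the scheduler was
// last reserved with, a new worst-case reserve would be needed before evaluating.
static uint64_t graph_topology_hash(const struct ggml_cgraph * gf) {
    uint64_t h = 1469598103934665603ULL;                             // FNV offset basis
    for (int i = 0; i < gf->n_nodes; ++i) {
        h = (h ^ (uint64_t) gf->nodes[i]->op)   * 1099511628211ULL;  // FNV prime
        h = (h ^ (uint64_t) gf->nodes[i]->type) * 1099511628211ULL;
    }
    return h ^ (uint64_t) gf->n_nodes;
}
```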

@ggerganov ggerganov changed the title llama : refactor k-shift implementation llama : refactor k-shift implementation + initial defragmentation Feb 25, 2024
@ggerganov ggerganov marked this pull request as ready for review February 25, 2024 13:43
@ggerganov ggerganov requested a review from slaren February 25, 2024 13:47
@ggerganov ggerganov marked this pull request as draft February 25, 2024 14:32
@ggerganov (Member, Author):

I just realized that there is a substantial optimization possible in the defragmentation strategy. I will try to implement it, so I have converted the PR back to draft.

@ggerganov (Member, Author) commented on Feb 25, 2024

The initial defragmentation strategy was very naive: starting from the beginning of the cache, whenever we find an empty cell we move the next non-empty cell into it. This meant moving essentially the entire cache when the "hole" is at the start (a typical case during context shift, where we remove the oldest cells). The implementation also prevented multi-threading and the use of ggml_graph, because the src and dst tensors could overlap.

With the new strategy, we find the "holes" starting from the beginning of the cache and fill them with data from the end of the cache. This avoids overlapping src and dst buffers and is trivial to implement as a ggml_graph, which allows running the defragmentation directly on the GPU. The performance is very good now.
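A hedged sketch of the move planning behind the new strategy, operating on a plain occupancy array instead of the real llama_kv_cell bookkeeping (the logs above show the resulting moves additionally coalesced into contiguous ranges):

```c
#include <stdbool.h>

// Plan non-overlapping defrag moves: scan for holes from the front of the cache
// and fill them with occupied cells taken from the back. The source (tail) and
// destination (head) regions never overlap, so the copies can be expressed as a
// ggml graph and run directly on the GPU.
static int plan_defrag_moves(const bool * occupied, int n_kv, int * move_src, int * move_dst) {
    int n_moves = 0;
    int i = 0;         // next candidate hole, scanning from the front
    int j = n_kv - 1;  // next candidate source, scanning from the back

    while (i < j) {
        if (occupied[i])  { ++i; continue; }  // not a hole
        if (!occupied[j]) { --j; continue; }  // nothing to move from here
        move_src[n_moves] = j;                // move cell j ...
        move_dst[n_moves] = i;                // ... into hole i
        ++n_moves; ++i; --j;
    }
    return n_moves;                           // number of planned single-cell moves
}
```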

There is still some chance of a bug, so I will leave wider application for future PRs.

@slaren Let me know if you have concerns about the new K-shift and defrag graphs introduced here.

@ggerganov ggerganov marked this pull request as ready for review February 25, 2024 15:48
@slaren (Member) left a review comment:

From what I can tell, these graphs always do the operations in place, and they don't allocate any tensors in the compute buffer. So I don't think they will cause any issues with the scheduler or buffer reallocations.

@ggerganov ggerganov merged commit bf08e00 into master Feb 25, 2024
73 of 111 checks passed
@ggerganov ggerganov deleted the gg/refactor-k-shift branch February 25, 2024 20:12
@dranger003 (Contributor) commented on Feb 26, 2024

I think this commit introduces a bug causing bad model responses, at least per the tests below.

commit bf08e00 (this commit / bugged)
build\bin\Release\main.exe -p "<start_of_turn>user\nWrite an essay about AI.<end_of_turn>\n<start_of_turn>model\n" -e --temp 0 --repeat-penalty 1.0 -m ggml-gemma-7b-it-f16.gguf -ngl 128
Log start
main: build = 2264 (bf08e006)
main: built with MSVC 19.39.33520.0 for x64
main: seed  = 1708910944
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
llama_model_loader: loaded meta data with 22 key-value pairs and 254 tensors from ggml-gemma-7b-it-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = gemma
llama_model_loader: - kv   1:                               general.name str              = gemma-7b-it
llama_model_loader: - kv   2:                       gemma.context_length u32              = 8192
llama_model_loader: - kv   3:                     gemma.embedding_length u32              = 3072
llama_model_loader: - kv   4:                          gemma.block_count u32              = 28
llama_model_loader: - kv   5:                  gemma.feed_forward_length u32              = 24576
llama_model_loader: - kv   6:                 gemma.attention.head_count u32              = 16
llama_model_loader: - kv   7:              gemma.attention.head_count_kv u32              = 16
llama_model_loader: - kv   8:     gemma.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv   9:                 gemma.attention.key_length u32              = 256
llama_model_loader: - kv  10:               gemma.attention.value_length u32              = 256
llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  12:                      tokenizer.ggml.tokens arr[str,256000]  = ["<pad>", "<eos>", "<bos>", "<unk>", ...
llama_model_loader: - kv  13:                      tokenizer.ggml.scores arr[f32,256000]  = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  14:                  tokenizer.ggml.token_type arr[i32,256000]  = [3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  15:                tokenizer.ggml.bos_token_id u32              = 2
llama_model_loader: - kv  16:                tokenizer.ggml.eos_token_id u32              = 1
llama_model_loader: - kv  17:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  18:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  19:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  20:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  21:                    tokenizer.chat_template str              = {% if messages[0]['role'] == 'system'...
llama_model_loader: - type  f32:   57 tensors
llama_model_loader: - type  f16:  197 tensors
llm_load_vocab: mismatch in special tokens definition ( 416/256000 vs 260/256000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = gemma
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 256000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_head           = 16
llm_load_print_meta: n_head_kv        = 16
llm_load_print_meta: n_layer          = 28
llm_load_print_meta: n_rot            = 192
llm_load_print_meta: n_embd_head_k    = 256
llm_load_print_meta: n_embd_head_v    = 256
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 4096
llm_load_print_meta: n_embd_v_gqa     = 4096
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 24576
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = F16 (guessed)
llm_load_print_meta: model params     = 8.54 B
llm_load_print_meta: model size       = 15.90 GiB (16.00 BPW)
llm_load_print_meta: general.name     = gemma-7b-it
llm_load_print_meta: BOS token        = 2 '<bos>'
llm_load_print_meta: EOS token        = 1 '<eos>'
llm_load_print_meta: UNK token        = 3 '<unk>'
llm_load_print_meta: PAD token        = 0 '<pad>'
llm_load_print_meta: LF token         = 227 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.19 MiB
llm_load_tensors: offloading 28 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 29/29 layers to GPU
llm_load_tensors:        CPU buffer size =  1500.00 MiB
llm_load_tensors:      CUDA0 buffer size = 16284.67 MiB
......................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   224.00 MiB
llama_new_context_with_model: KV self size  =  224.00 MiB, K (f16):  112.00 MiB, V (f16):  112.00 MiB
llama_new_context_with_model:  CUDA_Host input buffer size   =     8.01 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   506.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     6.00 MiB
llama_new_context_with_model: graph splits (measure): 2

system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 |
sampling:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 512, n_batch = 512, n_predict = -1, n_keep = 1


<start_of_turn>user
Write an essay about AI.<end_of_turn>
<start_of_turn>model
## The Looming Large: Artificial Intelligence and the Future of Humanity

The dawn of the digital age has ushered in an era of unprecedented technological advancement, one that promises to reshape the very fabric of our future, artificial intelligence (AI) revolution, a transformative force, artificial intelligence (AI)

The dawn of artificial intelligence (AI)

The tapestry of the future, a transformative force, a revolution, a transformative power and potential, a transformative force of the future, a dawn of a revolution, a tapestry of the future, a transformative force, a revolution, a transformative force, a revolution, a transformative, a tapestry of the future, a revolution, a transformative force, a revolution, a transformative, a revolution, a transformative force, a transformative, a revolution, a transformative, a burgeoning, a tapestry of the, a transformative, a, a, a transformative, the, a revolution, a transformative, a revolution, a transformative, a revolution, a landscape, a, a transformative, a, a, a technological, a force, a transformative, AI, a, one that AI, a, the, the, AI, one, a force, a force, a force, a digital, a force, a force, one that, a force, a force, a force, one that, a digital, one that AI, a, one that, a force, a force, a force, a, one that, a, a, a force, a force, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry of the, the, a, the one that has woven into the one that has the one that has the one that has woven into the one that has the one that has the the one that has the the tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry tapestry of of the the the the one that has the the the one that has the the the the the the one that has the one that weaves a tapestry tapestry tapestry tapestry of of the tapestry tapestry tapestry tapestry tapestry tapestry tapestry of of the the the the the the the the the the the the tapestry of of the the the tapestry tapestry of unparalleled in the the the the the the the the the tapestry tapestry of unparalleled in the tapestry of unparalleled in the tapestry tapestry of unparalleled in the tapestry of unparalleled in the tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of a tapestry of of dreams dreams dreams dreams dreams dreams dreams of of dreams, a tapestry of dreams, a tapestry of dreams dreams, the tapestry of dreams dreams, a tapestry of dreams, the tapestry of dreams dreams, a weave of of dreams dreams, a dream dreams, a dream dreams, a dream dreams, a dream dreams, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream, a dream dreams, a dream dreams. [end of text]

llama_print_timings:        load time =    7183.53 ms
llama_print_timings:      sample time =     419.85 ms /   724 runs   (    0.58 ms per token,  1724.42 tokens per second)
llama_print_timings: prompt eval time =      51.91 ms /    15 tokens (    3.46 ms per token,   288.93 tokens per second)
llama_print_timings:        eval time =   18186.32 ms /   723 runs   (   25.15 ms per token,    39.76 tokens per second)
llama_print_timings:       total time =   18956.57 ms /   738 tokens
commit f762501 (previous commit / fine)
build\bin\Release\main.exe -p "<start_of_turn>user\nWrite an essay about AI.<end_of_turn>\n<start_of_turn>model\n" -e --temp 0 --repeat-penalty 1.0 -m ggml-gemma-7b-it-f16.gguf -ngl 128
Log start
main: build = 2263 (f7625019)
main: built with MSVC 19.39.33520.0 for x64
main: seed  = 1708910756
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
llama_model_loader: loaded meta data with 22 key-value pairs and 254 tensors from ggml-gemma-7b-it-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = gemma
llama_model_loader: - kv   1:                               general.name str              = gemma-7b-it
llama_model_loader: - kv   2:                       gemma.context_length u32              = 8192
llama_model_loader: - kv   3:                     gemma.embedding_length u32              = 3072
llama_model_loader: - kv   4:                          gemma.block_count u32              = 28
llama_model_loader: - kv   5:                  gemma.feed_forward_length u32              = 24576
llama_model_loader: - kv   6:                 gemma.attention.head_count u32              = 16
llama_model_loader: - kv   7:              gemma.attention.head_count_kv u32              = 16
llama_model_loader: - kv   8:     gemma.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv   9:                 gemma.attention.key_length u32              = 256
llama_model_loader: - kv  10:               gemma.attention.value_length u32              = 256
llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  12:                      tokenizer.ggml.tokens arr[str,256000]  = ["<pad>", "<eos>", "<bos>", "<unk>", ...
llama_model_loader: - kv  13:                      tokenizer.ggml.scores arr[f32,256000]  = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  14:                  tokenizer.ggml.token_type arr[i32,256000]  = [3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  15:                tokenizer.ggml.bos_token_id u32              = 2
llama_model_loader: - kv  16:                tokenizer.ggml.eos_token_id u32              = 1
llama_model_loader: - kv  17:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  18:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  19:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  20:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  21:                    tokenizer.chat_template str              = {% if messages[0]['role'] == 'system'...
llama_model_loader: - type  f32:   57 tensors
llama_model_loader: - type  f16:  197 tensors
llm_load_vocab: mismatch in special tokens definition ( 416/256000 vs 260/256000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = gemma
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 256000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_head           = 16
llm_load_print_meta: n_head_kv        = 16
llm_load_print_meta: n_layer          = 28
llm_load_print_meta: n_rot            = 192
llm_load_print_meta: n_embd_head_k    = 256
llm_load_print_meta: n_embd_head_v    = 256
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 4096
llm_load_print_meta: n_embd_v_gqa     = 4096
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 24576
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = F16 (guessed)
llm_load_print_meta: model params     = 8.54 B
llm_load_print_meta: model size       = 15.90 GiB (16.00 BPW)
llm_load_print_meta: general.name     = gemma-7b-it
llm_load_print_meta: BOS token        = 2 '<bos>'
llm_load_print_meta: EOS token        = 1 '<eos>'
llm_load_print_meta: UNK token        = 3 '<unk>'
llm_load_print_meta: PAD token        = 0 '<pad>'
llm_load_print_meta: LF token         = 227 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.19 MiB
llm_load_tensors: offloading 28 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 29/29 layers to GPU
llm_load_tensors:        CPU buffer size =  1500.00 MiB
llm_load_tensors:      CUDA0 buffer size = 16284.67 MiB
......................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   224.00 MiB
llama_new_context_with_model: KV self size  =  224.00 MiB, K (f16):  112.00 MiB, V (f16):  112.00 MiB
llama_new_context_with_model:  CUDA_Host input buffer size   =     8.01 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   506.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     6.00 MiB
llama_new_context_with_model: graph splits (measure): 3

system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 |
sampling:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 512, n_batch = 512, n_predict = -1, n_keep = 1


<start_of_turn>user
Write an essay about AI.<end_of_turn>
<start_of_turn>model
## The Looming Shadow of Artificial Intelligence

The burgeoning field of artificial intelligence (AI) casts a long shadow across the landscape of our modern world. From the moment we wake up to the moment we drift off to sleep, AI's presence permeates our daily routines, shaping our interactions with technology and influencing our very understanding of the universe. While the potential benefits of AI are vast and undeniable, its burgeoning presence raises crucial questions about its ethical implications and potential misuse.

One of the most tangible impacts of AI is its transformative power in the realm of healthcare. AI algorithms are revolutionizing diagnoses, drug discovery, and personalized treatment. By analyzing vast amounts of medical data, AI can identify patterns and make predictions that aid doctors in making more accurate and timely decisions. This has the potential to save lives, improve treatment outcomes, and reduce healthcare costs.

Beyond the medical sphere, AI is also revolutionizing other industries. From finance to retail, transportation to entertainment, AI is automating tasks, optimizing processes, and creating new opportunities. Self-driving cars, powered by AI, promise to revolutionize transportation, making roads safer and more efficient. AI is also transforming the way we interact with information, enabling us to access and process vast amounts of data with unprecedented ease.

However, the burgeoning presence of AI also raises concerns about its potential misuse and ethical implications. One of the most significant concerns is the potential for AI to perpetuate bias and discrimination. Algorithms are only as good as the data they are trained on, and if the data reflects existing biases, the AI system will perpetuate those biases. This raises concerns about AI perpetuating discrimination against marginalized groups, further marginalizing them in society.

Another ethical concern surrounding AI is the potential for job displacement. As AI becomes more sophisticated, it has the potential to automate many tasks currently performed by humans. This raises concerns about widespread job displacement, particularly in low-skill jobs. While this may lead to economic growth in certain sectors, it also has the potential to exacerbate economic inequality and social unrest.

The ethical implications of AI extend beyond the realm of employment and discrimination. The use of AI in warfare raises concerns about the potential for autonomous weapons systems to engage in conflict without human intervention. This raises questions about the morality of AI and its potential to lead to unintended consequences.

In conclusion, AI holds the potential to revolutionize numerous aspects of our lives, offering vast benefits across various fields. However, its burgeoning presence also raises crucial ethical questions that can lead to the potential for the human rights and the potential for human rights and biases and discrimination and bias and the potential for discrimination and bias and discrimination. It is important for human bias and discrimination, and bias and discrimination, and biases. It is important for human oversight and bias and discrimination, which raises concerns.

In conclusion, it is important for human oversight and bias and discrimination and inequality. [end of text]

llama_print_timings:        load time =    7229.30 ms
llama_print_timings:      sample time =     336.09 ms /   579 runs   (    0.58 ms per token,  1722.77 tokens per second)
llama_print_timings: prompt eval time =      50.70 ms /    15 tokens (    3.38 ms per token,   295.83 tokens per second)
llama_print_timings:        eval time =   14462.17 ms /   578 runs   (   25.02 ms per token,    39.97 tokens per second)
llama_print_timings:       total time =   15085.24 ms /   593 tokens

EDIT: This seems to affect Gemma only; I tested with Mistral and Phi and they both produce the right output. Also tested on Linux, with the same outcome.

ggerganov added a commit that referenced this pull request Feb 26, 2024
@ggerganov (Member, Author):
@dranger003 Thanks for spotting this - should be fixed now (269de86)

@dranger003 (Contributor):
Thanks, it's working now.

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
…g#5691)

* llama : refactor k-shift implementation

ggml-ci

* llama : rename llama_kv_cache_seq_shift to llama_kv_cache_seq_add

* llama : cont k-shift refactoring + normalize type names

ggml-ci

* minor : fix MPI builds

* llama : reuse n_rot from the build context

ggml-ci

* llama : revert enum name changes from this PR

ggml-ci

* llama : update llama_rope_type

* llama : add comment about rope values

* llama : fix build

* passkey : apply kv cache updates explicitly

ggml-ci

* llama : change name to llama_kv_cache_update()

* llama : add llama_kv_cache_seq_pos_max()

* passkey : fix llama_kv_cache_seq_pos_max() usage

* llama : some llama_kv_cell simplifications

* llama : add llama_kv_cache_compress (EXPERIMENTAL)

* llama : add alternative KV cache merging (EXPERIMENTAL)

* llama : add llama_kv_cache_defrag

* llama : comments

* llama : remove llama_kv_cache_compress

will add in a separate PR

ggml-ci

* llama : defragment via non-overlapping moves

* llama : ggml_graph based defrag implementation

ggml-ci

* llama : switch the loop order in build_defrag

* llama : add comments
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024