
Continuous batching load test stuck #5827

Closed
Kev1ntan opened this issue Mar 2, 2024 · 12 comments · Fixed by #5836
Comments

Kev1ntan commented Mar 2, 2024

OS: Linux 2d078bb41859 5.15.0-83-generic #92~20.04.1-Ubuntu SMP Mon Aug 21 14:00:49 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux

instance: 1xRTX 3090

load test tool: k6

Hi, I am doing a load test of the llama.cpp server, but somehow the concurrent requests are capped at the --parallel n value. Below is the evidence:
[Screenshot: Screen Shot 2024-03-02 at 10 05 07]
[Screenshot: Screen Shot 2024-03-02 at 10 04 36]
[Screenshot: Screen Shot 2024-03-02 at 10 06 46]

Is the command for batch inference wrong? When the load test completed, I tried to manually send one request, but the model did not respond at all (it seems the slots were not released yet?).

Any help is appreciated, thank you.

phymbert commented Mar 2, 2024

Hi, please share the steps to reproduce your bench.

By default, the maximum number of concurrent HTTP requests is set to the number of CPU cores. You can use --threads-http to increase it to the number of slots (--parallel). @ggerganov I got your point; let's initialize it by default to n_slots.
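
For reference, a quick way to see what that default would be on a Linux host is to check the reported core count:

nproc    # prints the number of available processing units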

Kev1ntan commented Mar 3, 2024

@phymbert you can try running the server with the command below:
./server -m ../models/mistral-7b-v0.1.Q8_0.gguf -c 2048 --port 9000 -ngl 33 -tb 64 -cb -np 64

then run k6 with 100 VUs:

export const options = {
  vus: 100, // simulate 100 virtual users
  duration: '60s', // running the test for 60 seconds
};

If any HTTP calls fail, try to hit the server manually; in my case the server did not respond at all...
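
For completeness, a minimal full script around those options might look like the sketch below (the endpoint, prompt, and n_predict value are illustrative, not from the original report):

import http from 'k6/http';
import { sleep } from 'k6';

export const options = {
  vus: 100,        // simulate 100 virtual users
  duration: '60s', // run the test for 60 seconds
};

export default function () {
  const headers = { 'Content-Type': 'application/json' };
  // Cap generation so a single request cannot run unbounded.
  http.post('http://localhost:9000/completion', JSON.stringify({
    prompt: 'i am',
    n_predict: 32,
  }), { headers });
  sleep(1); // each VU waits 1 second before its next request
}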

phymbert commented Mar 3, 2024

I guess you have 8-16 CPU cores, so without specifying --threads-http=102 your server will get stuck with 100 users.
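
For example, the earlier launch command with the HTTP thread pool raised accordingly might look like this (a sketch based on the command shared above; adjust paths and values to your setup):

./server -m ../models/mistral-7b-v0.1.Q8_0.gguf -c 2048 --port 9000 -ngl 33 -tb 64 -cb -np 64 --threads-http 102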

phymbert commented Mar 3, 2024

I did some tests; unfortunately, I only have an RTX 3050, so I tested with Phi-2 and only 32 slots and 32 users.

Using:

server --host localhost --port 8080 --model phi-2.Q4_K_M.gguf --alias phi-2 --cont-batching --metrics --parallel 32 --n-predict 32 -ngl 33 --threads-http 34 -tb 8 --batch-size 96 --ctx-size 4096 --log-format text

On:
Device 0: NVIDIA GeForce RTX 3050 Laptop GPU, compute capability 8.6, VMM: yes

K6 Script
import http from 'k6/http'
import {check, sleep} from 'k6'

export default function() {
    const data = {
        "messages": [
            {
                "role": "system",
                "content": "You are a kind AI assistant.",
            },
            {
                "role": "user",
                "content": "I believe the meaning of life is",
            }
        ],
        "model": "model",
        "max_tokens": 32,
        "stream": false,
    }
    let res = http.post('http://localhost:8080/v1/chat/completions',JSON.stringify(data), {
        headers: { 'Content-Type': 'application/json' },
    })

    check(res, {'success completion': (r) => r.status === 200})

    sleep(0.3)
}

export const options = {
    vus: 32, // simulate 32 virtual users
    duration: '300s', // run the test for 300 seconds
};
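
If the script is saved as, say, script.js (filename illustrative), it can be run with:

k6 run script.js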

I have no issue if the parameter --threads-http 34 is set.

You can export server metrics during the test:

curl http://localhost:8080/metrics
# HELP llamacpp:prompt_tokens_total Number of prompt tokens processed.
# TYPE llamacpp:prompt_tokens_total counter
llamacpp:prompt_tokens_total 73074
# HELP llamacpp:tokens_predicted_total Number of generation tokens processed.
# TYPE llamacpp:tokens_predicted_total counter
llamacpp:tokens_predicted_total 39707
# HELP llamacpp:prompt_tokens_seconds Average prompt throughput in tokens/s.
# TYPE llamacpp:prompt_tokens_seconds gauge
llamacpp:prompt_tokens_seconds 68
# HELP llamacpp:predicted_tokens_seconds Average generation throughput in tokens/s.
# TYPE llamacpp:predicted_tokens_seconds gauge
llamacpp:predicted_tokens_seconds 3
# HELP llamacpp:kv_cache_usage_ratio KV-cache usage. 1 means 100 percent usage.
# TYPE llamacpp:kv_cache_usage_ratio gauge
llamacpp:kv_cache_usage_ratio 0
# HELP llamacpp:kv_cache_tokens KV-cache tokens.
# TYPE llamacpp:kv_cache_tokens gauge
llamacpp:kv_cache_tokens 1907
# HELP llamacpp:requests_processing Number of request processing.
# TYPE llamacpp:requests_processing gauge
llamacpp:requests_processing 32
# HELP llamacpp:requests_deferred Number of request deferred.
# TYPE llamacpp:requests_deferred gauge
llamacpp:requests_deferred 0
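
One way to watch these values evolve while k6 is running (assuming the same host and port as above) is to poll the endpoint periodically:

watch -n 1 'curl -s http://localhost:8080/metrics'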

If you still face issues, please share all these steps on your end.

Kev1ntan commented Mar 3, 2024

@phymbert I just tried with ./server -m ../models/mistral-7b-v0.1.Q8_0.gguf -c 2048 --host localhost --port 9000 -ngl 33 -tb 64 -cb -np 64 --threads-http 66 and found something weird:

  1. In my earlier test using the /completion endpoint, requests got stuck at the --parallel number.
  2. The current test still hits the same issue with the /completion endpoint.
  3. Using /v1/chat/completions, it no longer gets stuck at the --parallel number, but I still got some request timeouts. Below are both tests I just ran.

with /completion:
[Screenshot: Screen Shot 2024-03-03 at 19 16 27]

with /v1/chat/completions:
[Screenshot: Screen Shot 2024-03-03 at 19 16 59]
This one is better, but there were still 6 request timeouts out of 632 requests.

Do you have any insight?

phymbert commented Mar 3, 2024

Nice to hear. Without your k6 script, I cannot help much. Sometimes requests are slower for various reasons; you can either accept that or increase the timeout. If you need HA, you can scale out the number of servers.

Note that, fundamentally, there is no difference between /completion and /chat/completions.
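
As an illustration of that equivalence, the two calls below should produce comparable generations (payload values are illustrative; the port matches the command shared earlier in this thread):

curl http://localhost:9000/completion -H 'Content-Type: application/json' -d '{"prompt": "I believe the meaning of life is", "n_predict": 32}'

curl http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{"messages": [{"role": "user", "content": "I believe the meaning of life is"}], "max_tokens": 32}'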

Kev1ntan commented Mar 3, 2024

Here is the k6 script, @phymbert:

import http from 'k6/http';
import { sleep } from 'k6';
export const options = {
  vus: 50,
  duration: '60s',
};
export default function () {
    let headers = { 'Content-Type': 'application/json' };
    http.post('http://localhost:9000/v1/chat/completions', JSON.stringify({
        "messages": [
            {
                "role": "user",
                "content": "do you know indonesia?, if yes please describe indonesia in details"
            }
        ]
    }), { headers: headers });
    sleep(1); // virtual user will wait for 1 second before the next request
}

Kev1ntan commented Mar 3, 2024

please try with /completion also @phymbert

phymbert commented Mar 3, 2024

please try with /completion also @phymbert

I will test later on, but from my understanding, they are equivalent, just different data structures.
Do you see any difference in terms of tokens/s or RPM?

Kev1ntan commented Mar 3, 2024

please try with /completion also @phymbert

I will test later on, but from my understanding, they are equivalent, just different data structures. Do you see any difference in terms of tokens/s or RPM?

/completion

import http from 'k6/http';
import { sleep } from 'k6';
export const options = {
  vus: 60, 
  duration: '60s', 
};
export default function () {
    let headers = { 'Content-Type': 'application/json' };
    let response = http.post('http://localhost:9000/completion', JSON.stringify({
        "prompt": "i am"
    }), { headers: headers });
    sleep(1); // virtual user will wait for 1 second before the next request
}

results:

# HELP llamacpp:prompt_tokens_total Number of prompt tokens processed.
# TYPE llamacpp:prompt_tokens_total counter
llamacpp:prompt_tokens_total 216
# HELP llamacpp:tokens_predicted_total Number of generation tokens processed.
# TYPE llamacpp:tokens_predicted_total counter
llamacpp:tokens_predicted_total 6075
# HELP llamacpp:prompt_tokens_seconds Average prompt throughput in tokens/s.
# TYPE llamacpp:prompt_tokens_seconds gauge
llamacpp:prompt_tokens_seconds 27
# HELP llamacpp:predicted_tokens_seconds Average generation throughput in tokens/s.
# TYPE llamacpp:predicted_tokens_seconds gauge
llamacpp:predicted_tokens_seconds 3
# HELP llamacpp:kv_cache_usage_ratio KV-cache usage. 1 means 100 percent usage.
# TYPE llamacpp:kv_cache_usage_ratio gauge
llamacpp:kv_cache_usage_ratio 0
# HELP llamacpp:kv_cache_tokens KV-cache tokens.
# TYPE llamacpp:kv_cache_tokens gauge
llamacpp:kv_cache_tokens 1664
# HELP llamacpp:requests_processing Number of request processing.
# TYPE llamacpp:requests_processing gauge
llamacpp:requests_processing 49
# HELP llamacpp:requests_deferred Number of request deferred.
# TYPE llamacpp:requests_deferred gauge
llamacpp:requests_deferred 0
[Screenshot: Screen Shot 2024-03-03 at 21 00 35]

/chat/completions

import http from 'k6/http';
import { sleep } from 'k6';
export const options = {
  vus: 60,
  duration: '60s',
};
export default function () {
    let headers = { 'Content-Type': 'application/json' };
    http.post('http://localhost:9000/v1/chat/completions', JSON.stringify({
        "messages": [
            {
                "role": "user",
                "content": "do you know indonesia?, if yes please describe indonesia in details"
            }
        ],
    }), { headers: headers });
    sleep(1); // virtual user will wait for 1 second before the next request
}

results:

# HELP llamacpp:prompt_tokens_total Number of prompt tokens processed.
# TYPE llamacpp:prompt_tokens_total counter
llamacpp:prompt_tokens_total 18788
# HELP llamacpp:tokens_predicted_total Number of generation tokens processed.
# TYPE llamacpp:tokens_predicted_total counter
llamacpp:tokens_predicted_total 4697
# HELP llamacpp:prompt_tokens_seconds Average prompt throughput in tokens/s.
# TYPE llamacpp:prompt_tokens_seconds gauge
llamacpp:prompt_tokens_seconds 72
# HELP llamacpp:predicted_tokens_seconds Average generation throughput in tokens/s.
# TYPE llamacpp:predicted_tokens_seconds gauge
llamacpp:predicted_tokens_seconds 0
# HELP llamacpp:kv_cache_usage_ratio KV-cache usage. 1 means 100 percent usage.
# TYPE llamacpp:kv_cache_usage_ratio gauge
llamacpp:kv_cache_usage_ratio 0
# HELP llamacpp:kv_cache_tokens KV-cache tokens.
# TYPE llamacpp:kv_cache_tokens gauge
llamacpp:kv_cache_tokens 1216
# HELP llamacpp:requests_processing Number of request processing.
# TYPE llamacpp:requests_processing gauge
llamacpp:requests_processing 0
# HELP llamacpp:requests_deferred Number of request deferred.
# TYPE llamacpp:requests_deferred gauge
llamacpp:requests_deferred 0
[Screenshot: Screen Shot 2024-03-03 at 21 01 08]

phymbert commented Mar 3, 2024

max_tokens is n_predict in /completion, and the messages content goes in the prompt input. Please see the README.
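
A sketch of what that mapping looks like as a /completion request body (values illustrative):

{
  "prompt": "I believe the meaning of life is",
  "n_predict": 32
}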

@Kev1ntan
Copy link
Author

Kev1ntan commented Mar 3, 2024

max_tokens is n_predict in /completion, and the messages content goes in the prompt input. Please see the README.

[Screenshot: Screen Shot 2024-03-03 at 21 16 03] That one, right? Using prompt also works: [Screenshot: Screen Shot 2024-03-03 at 21 18 32]

phymbert added a commit that referenced this issue Mar 8, 2024
phymbert added a commit that referenced this issue Mar 9, 2024
…mparison (#5941)

* server: bench: Init a bench scenario with K6
See #5827

* server: bench: EOL EOF

* server: bench: PR feedback and improved k6 script configuration

* server: bench: remove llamacpp_completions_tokens_seconds as it include prompt processing time and it's misleading

server: bench: add max_tokens from SERVER_BENCH_MAX_TOKENS

server: bench: increase truncated rate to 80% before failing

* server: bench: fix doc

* server: bench: change gauge custom metrics to trend

* server: bench: change gauge custom metrics to trend
server: bench: add trend custom metrics for total tokens per second average

* server: bench: doc add an option to debug http request

* server: bench: filter dataset too short and too long sequences

* server: bench: allow to filter out conversation in the dataset based on env variable

* server: bench: fix assistant message sent instead of user message

* server: bench: fix assistant message sent instead of user message

* server : add defrag thold parameter

* server: bench: select prompts based on the current iteration id not randomly to make the bench more reproducible

---------

Co-authored-by: Georgi Gerganov <[email protected]>
hazelnutcloud pushed a commit to hazelnutcloud/llama.cpp that referenced this issue Mar 10, 2024
NeoZhangJianyu pushed a commit to NeoZhangJianyu/llama.cpp that referenced this issue Mar 12, 2024
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this issue Mar 13, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this issue Apr 1, 2024