-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathscheduler.py
1462 lines (1255 loc) · 61.8 KB
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import enum
import os
import random
import time
from collections import deque
from dataclasses import dataclass, field
from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
from vllm.utils import Device, PyObjectCache
# Module-level logger, created through vLLM's `init_logger` helper with this
# module's name.
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted artificially
# (ARTIFICIAL_PREEMPTION_PROB is presumably the probability of doing so --
# its use is outside this part of the file).
#
# The flag arrives from the environment as text, so it is parsed explicitly:
# unset, empty, "0", "false", "no" and "off" (case-insensitive, surrounding
# whitespace ignored) disable it; any other value enables it. Calling bool()
# on the raw string would treat every non-empty value, including "0", as
# enabled.
ENABLE_ARTIFICIAL_PREEMPT = os.getenv(
    "VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT",
    "").strip().lower() not in ("", "0", "false", "no", "off")  # noqa
ARTIFICIAL_PREEMPTION_PROB = 0.5
# Initial value of the scheduler's artificial-preemption counter when the
# flag above is enabled.
ARTIFICIAL_PREEMPTION_MAX_CNT = 500
class PreemptionMode(enum.Enum):
    """How the blocks of a preempted sequence are handled.

    SWAP: the blocks of the preempted sequences are swapped out to CPU
        memory and swapped back in when the sequences are resumed.
    RECOMPUTE: the blocks of the preempted sequences are discarded and
        recomputed when the sequences are resumed, treating the sequences
        as new prompts.
    """

    SWAP = enum.auto()
    RECOMPUTE = enum.auto()
@dataclass
class SchedulingBudget:
    """Token and sequence slots still available in one scheduling pass.

    TODO(sang): The budget is currently request_id-aware: a second update
    for the same request_id is ignored. This exists because the normal
    scheduling path updates the RUNNING num_seqs ahead of time, so the same
    request could otherwise be counted more than once while scheduling
    RUNNING requests. That cannot happen with chunked-prefill-only
    scheduling, so the feature can be dropped from the API once chunked
    prefill is enabled by default.
    """

    # Upper bound on batched tokens.
    token_budget: int
    # Upper bound on concurrently scheduled sequences.
    max_num_seqs: int
    # Request ids already counted towards each running total.
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    # Running totals.
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
        """Return True if both the extra tokens and extra seqs still fit."""
        assert num_new_tokens != 0
        assert num_new_seqs != 0
        tokens_fit = (self.num_batched_tokens + num_new_tokens
                      <= self.token_budget)
        if not tokens_fit:
            return False
        return self.num_curr_seqs + num_new_seqs <= self.max_num_seqs

    def remaining_token_budget(self):
        """Return how many tokens can still be batched."""
        used = self.num_batched_tokens
        return self.token_budget - used

    def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
        """Count tokens for `req_id`; a repeated `req_id` is ignored."""
        counted = self._request_ids_num_batched_tokens
        if req_id not in counted:
            counted.add(req_id)
            self._num_batched_tokens += num_batched_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int):
        """Undo a token count for `req_id`; no-op if it was never counted."""
        counted = self._request_ids_num_batched_tokens
        if req_id not in counted:
            return
        counted.remove(req_id)
        self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Count sequences for `req_id`; a repeated `req_id` is ignored."""
        counted = self._request_ids_num_curr_seqs
        if req_id not in counted:
            counted.add(req_id)
            self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Undo a sequence count for `req_id`; no-op if never counted."""
        counted = self._request_ids_num_curr_seqs
        if req_id not in counted:
            return
        counted.remove(req_id)
        self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self):
        """Tokens counted so far."""
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self):
        """Sequences counted so far."""
        return self._num_curr_seqs
@dataclass
class ScheduledSequenceGroup:
    """A sequence group paired with its token chunk size for the next step."""

    # A sequence group that's scheduled.
    seq_group: SequenceGroup
    # The total chunk size (number of tokens) to process for next iteration.
    # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
    # chunked, it can be smaller than that.
    token_chunk_size: int
@dataclass
class SchedulerOutputs:
    """The scheduling decision made from a scheduler."""

    # Sequence groups chosen for this step.
    scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
    # How many of the scheduled groups are prefills.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # Blocks to swap in, as (CPU block number, GPU block number) pairs.
    blocks_to_swap_in: List[Tuple[int, int]]
    # Blocks to swap out, as (GPU block number, CPU block number) pairs.
    blocks_to_swap_out: List[Tuple[int, int]]
    # Blocks to copy, as (source block, destination block) pairs.
    blocks_to_copy: List[Tuple[int, int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # The number of requests in the running queue.
    running_queue_size: int
    preempted: int

    def __post_init__(self):
        """Validate the swap lists and derive the LoRA/adapter counts."""
        # Swapping in and swapping out must never be requested together.
        assert not self.blocks_to_swap_in or not self.blocks_to_swap_out

        lora_count = len(self.lora_requests)
        self.num_loras: int = lora_count
        if lora_count > 0:
            self._sort_by_lora_ids()

        self.num_prompt_adapters: int = len(self.prompt_adapter_requests)

    def is_empty(self) -> bool:
        """Return True when nothing is scheduled and no block op is pending."""
        # NOTE: We do not consider the ignored sequence groups.
        return not (self.scheduled_seq_groups or self.blocks_to_swap_in
                    or self.blocks_to_swap_out or self.blocks_to_copy)

    def _sort_by_lora_ids(self):
        """Order the scheduled groups by (lora_int_id, request_id)."""

        def _lora_then_request(scheduled):
            group = scheduled.seq_group
            return (group.lora_int_id, group.request_id)

        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
                                           key=_lora_then_request)

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        """Distinct non-None LoRA requests among the scheduled groups."""
        requests = set()
        for scheduled in self.scheduled_seq_groups:
            if scheduled.seq_group.lora_request is not None:
                requests.add(scheduled.seq_group.lora_request)
        return requests

    @property
    def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
        """Distinct non-None prompt-adapter requests among scheduled groups."""
        requests = set()
        for scheduled in self.scheduled_seq_groups:
            if scheduled.seq_group.prompt_adapter_request is not None:
                requests.add(scheduled.seq_group.prompt_adapter_request)
        return requests
@dataclass
class SchedulerRunningOutputs:
    """The requests that are scheduled from a running queue.

    Could contain prefill (prefill that's chunked) or decodes. If there's not
    enough memory, it can be preempted (for recompute) or swapped out.
    """

    # Selected sequences that are running and in a decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are running and in a prefill phase.
    # I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The preempted sequences.
    preempted: List[SequenceGroup]
    # Sequences that are swapped out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    # Optimization for fast-access to seq_group lists
    decode_seq_groups_list: List[SequenceGroup]
    prefill_seq_groups_list: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an instance with empty lists and zero lookahead slots.

        The instance is built through ``cls`` rather than the hard-coded
        class name, so a subclass calling this gets an instance of itself.
        Every list field is a fresh list, not shared between instances.
        """
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            decode_seq_groups_list=[],
            prefill_seq_groups_list=[],
        )
@dataclass
class SchedulerSwappedInOutputs:
    """The requests that are scheduled from a swap queue.

    Could contain prefill (prefill that's chunked) or decodes.
    """

    # Selected sequences that are going to be swapped in and is in a
    # decoding phase. `_schedule_swapped` fills this with
    # ScheduledSequenceGroup entries, hence the element type.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are going to be swapped in and in a prefill
    # phase. I.e., it means the prefill has been chunked. Also holds
    # ScheduledSequenceGroup entries.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: List[Tuple[int, int]]
    # The blocks to copy.
    blocks_to_copy: List[Tuple[int, int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # Infeasible sequence groups.
    infeasible_seq_groups: List[SequenceGroup]

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance with empty lists and zero lookahead slots.

        The instance is built through ``cls`` rather than the hard-coded
        class name, so a subclass calling this gets an instance of itself.
        """
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            blocks_to_swap_in=[],
            blocks_to_copy=[],
            num_lookahead_slots=0,
            infeasible_seq_groups=[],
        )
@dataclass
class SchedulerPrefillOutputs:
    """The requests that are scheduled from a waiting queue.

    Could contain a fresh prefill requests or preempted requests that need
    to be recomputed from scratch.
    """

    # Selected sequences for prefill. `_schedule_prefills` fills this with
    # ScheduledSequenceGroup entries, hence the element type.
    seq_groups: List[ScheduledSequenceGroup]
    # Ignored sequence groups.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an instance with empty lists and zero lookahead slots.

        The instance is built through ``cls`` rather than the hard-coded
        class name, so a subclass calling this gets an instance of itself.
        """
        return cls(
            seq_groups=[],
            ignored_seq_groups=[],
            num_lookahead_slots=0,
        )
def seq_group_metadata_builder():
    """Build a placeholder SequenceGroupMetadata.

    Used as the factory handed to `PyObjectCache` in `Scheduler.__init__`.
    """
    placeholder = SequenceGroupMetadata(
        request_id="",
        is_prompt=False,
        seq_data={},
        sampling_params=None,
        block_tables={},
    )
    return placeholder
def scheduler_running_outputs_builder():
    """Build an empty SchedulerRunningOutputs.

    Used as the factory handed to `PyObjectCache` in `Scheduler.__init__`.
    `SchedulerRunningOutputs.create_empty` produces the same all-empty
    instance (fresh lists, zero lookahead slots), so it is reused here.
    """
    return SchedulerRunningOutputs.create_empty()
def scheduled_seq_group_builder():
    """Build a placeholder ScheduledSequenceGroup (no group, chunk size 0).

    Used as the factory handed to `PyObjectCache` in `Scheduler.__init__`.
    """
    placeholder = ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
    return placeholder
class Scheduler:
def __init__(
    self,
    scheduler_config: SchedulerConfig,
    cache_config: CacheConfig,
    lora_config: Optional[LoRAConfig],
    pipeline_parallel_size: int = 1,
    output_proc_callback_fn: Optional[Callable] = None,
) -> None:
    """Set up configuration, block manager, queues and object caches.

    Args:
        scheduler_config: Scheduler settings. Read here for
            `use_v2_block_manager`, `embedding_mode` and `preemption_mode`.
        cache_config: Supplies `block_size`, `num_gpu_blocks`,
            `num_cpu_blocks`, `sliding_window` and `enable_prefix_caching`
            for the block manager.
        lora_config: LoRA configuration, or None. Stored as given.
        pipeline_parallel_size: The GPU and CPU block counts are
            floor-divided by this value before the block manager is built.
        output_proc_callback_fn: When not None, async output processing is
            treated as enabled and two sets of object caches are kept
            instead of one.
    """
    self.scheduler_config = scheduler_config
    self.cache_config = cache_config
    # LoRA scheduling caveat: the current policy is extremely simple and
    # NOT fair. It can starve some LoRAs and should be improved in the
    # future.
    self.lora_config = lora_config

    # Choose the block manager flavour; embedding mode takes precedence
    # over the v2 flag.
    use_v2 = self.scheduler_config.use_v2_block_manager
    embedding = self.scheduler_config.embedding_mode
    if embedding:
        version = "embedding"
    elif use_v2:
        version = "v2"
    else:
        version = "v1"
    BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
        version)

    # Block counts are divided by the pipeline-parallel size. Falsy counts
    # (None or 0) are passed through unchanged.
    gpu_blocks = cache_config.num_gpu_blocks
    if gpu_blocks:
        gpu_blocks = gpu_blocks // pipeline_parallel_size
    cpu_blocks = cache_config.num_cpu_blocks
    if cpu_blocks:
        cpu_blocks = cpu_blocks // pipeline_parallel_size

    # Create the block space manager.
    self.block_manager = BlockSpaceManagerImpl(
        block_size=self.cache_config.block_size,
        num_gpu_blocks=gpu_blocks,
        num_cpu_blocks=cpu_blocks,
        sliding_window=self.cache_config.sliding_window,
        enable_caching=self.cache_config.enable_prefix_caching)

    # WAITING sequence groups: new prefills or preempted requests.
    self.waiting: Deque[SequenceGroup] = deque()
    # RUNNING sequence groups: decode requests.
    self.running: Deque[SequenceGroup] = deque()
    # SWAPPED sequence groups: decode requests that are swapped out.
    self.swapped: Deque[SequenceGroup] = deque()
    # Ids of requests finished since the last step iteration. They tell
    # the model that any state tied to these requests can and must be
    # released after the current step. Used to evict the finished requests
    # from the Mamba cache.
    self._finished_requests_ids: List[str] = []
    # Time at previous scheduling step
    self.prev_time = 0.0
    # Did we schedule a prompt at previous step?
    self.prev_prompt = False
    # Latency of the last prompt step
    self.last_prompt_latency = 0.0
    # preemption mode, RECOMPUTE or SWAP
    self.user_specified_preemption_mode = scheduler_config.preemption_mode

    # Test-only: used to inject artificial preemption.
    self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
    if self.enable_artificial_preemption:
        self.artificial_preempt_cnt = ARTIFICIAL_PREEMPTION_MAX_CNT
    else:
        self.artificial_preempt_cnt = 0
    self.num_cumulative_preemption: int = 0

    # Used to cache python objects
    self._seq_group_metadata_cache: List[PyObjectCache] = []
    self._scheduler_running_outputs_cache: List[PyObjectCache] = []
    self._scheduled_seq_group_cache: List[PyObjectCache] = []

    # For async output processing the cache buffers are swapped between
    # iterations: output processing lags one step behind, so cached
    # objects cannot be reused on the very next schedule() call, only on
    # the one after that.
    self.output_proc_callback_fn = output_proc_callback_fn
    self.use_async_output_proc = self.output_proc_callback_fn is not None
    self.num_cache_iters = 2 if self.use_async_output_proc else 1

    self.cache_id = 0
    for _ in range(self.num_cache_iters):
        self._seq_group_metadata_cache.append(
            PyObjectCache(seq_group_metadata_builder))
        self._scheduler_running_outputs_cache.append(
            PyObjectCache(scheduler_running_outputs_builder))
        self._scheduled_seq_group_cache.append(
            PyObjectCache(scheduled_seq_group_builder))

    # For async postprocessor, the extra decode run cannot be done
    # when the request reaches max_model_len. In this case, the request
    # will be stopped during schedule() call and added to this stop list
    # for processing and deallocation by the free_finished_seq_groups()
    self._async_stopped: List[SequenceGroup] = []
@property
def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters
@property
def lora_enabled(self) -> bool:
return bool(self.lora_config)
@property
def num_decoding_tokens_per_seq(self) -> int:
"""The number of new tokens."""
return 1
def add_seq_group(self, seq_group: SequenceGroup) -> None:
    """Append `seq_group` to the end of the waiting queue."""
    waiting_queue = self.waiting
    waiting_queue.append(seq_group)
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
    """Append `seq_group` to the running queue. Only for testing purposes."""
    running_queue = self.running
    running_queue.append(seq_group)
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
    """Append `seq_group` to the swapped queue. Only for testing purposes."""
    swapped_queue = self.swapped
    swapped_queue.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
    """Abort the sequence group(s) with the given request ID(s).

    The waiting, running and swapped queues are searched in that order.
    Every matching sequence group is removed from its queue, its request id
    is recorded as finished, each of its sequences that is not already
    finished is marked `FINISHED_ABORTED` and freed, and its
    cross-attention blocks are released. IDs that match nothing are
    ignored.

    Args:
        request_id: The ID(s) of the sequence group to abort.
    """
    if isinstance(request_id, str):
        pending_ids = {request_id}
    else:
        pending_ids = set(request_id)

    for queue in (self.waiting, self.running, self.swapped):
        # Collect the matches first so the queue is not mutated while it
        # is being iterated.
        matched: List[SequenceGroup] = []
        for candidate in queue:
            if not pending_ids:
                # Using 'break' here may add two extra iterations,
                # but is acceptable to reduce complexity.
                break
            if candidate.request_id not in pending_ids:
                continue
            matched.append(candidate)
            pending_ids.remove(candidate.request_id)

        for group in matched:
            # Remove the sequence group from the state queue.
            queue.remove(group)
            # Remove the aborted request from the Mamba cache.
            self._finished_requests_ids.append(group.request_id)
            for seq in group.get_seqs():
                if not seq.is_finished():
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)

            self._free_seq_group_cross_attn_blocks(group)
def _free_seq_group_cross_attn_blocks(
    self,
    seq_group: SequenceGroup,
) -> None:
    """Release the cross-attention blocks of an encoder-decoder group.

    Does nothing unless `seq_group.is_encoder_decoder()` is true, so it
    has no effect on decoder-only models.
    """
    if not seq_group.is_encoder_decoder():
        return
    self.block_manager.free_cross(seq_group)
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the block manager's prefix-cache hit rate for `device`."""
    hit_rate = self.block_manager.get_prefix_cache_hit_rate(device)
    return hit_rate
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
def get_and_reset_finished_requests_ids(self) -> List[str]:
"""Flushes the list of request ids of previously finished seq_groups."""
finished_requests_ids = self._finished_requests_ids
self._finished_requests_ids = list()
return finished_requests_ids
def _schedule_running(
    self,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    enable_chunking: bool = False,
) -> SchedulerRunningOutputs:
    """Schedule sequence groups that are running.

    Running queue should include decode and chunked prefill requests.

    Groups are taken from the front of `self.running`. While the current
    group cannot be given the slots it needs, other groups are preempted
    from the back of the queue; if none are left, the current group itself
    is preempted.

    Args:
        budget: The scheduling budget. The argument is in-place updated
            when any decodes are preempted.
        curr_loras: Currently batched lora request ids. The argument is
            in-place updated when any decodes are preempted.
        enable_chunking: If True, seq group can be chunked and only a
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.

    Returns:
        SchedulerRunningOutputs. The object comes from
        `_scheduler_running_outputs_cache[self.cache_id]` and is cleared
        before being filled.
    """
    # Fetch the output object from the object cache for the current
    # cache_id and wipe whatever it still holds.
    ret: SchedulerRunningOutputs = \
        self._scheduler_running_outputs_cache[self.cache_id].get_object()
    ret.blocks_to_swap_out.clear()
    ret.blocks_to_copy.clear()
    ret.decode_seq_groups.clear()
    ret.prefill_seq_groups.clear()
    ret.preempted.clear()
    ret.swapped_out.clear()

    ret.num_lookahead_slots = self._get_num_lookahead_slots(
        is_prefill=False)

    ret.decode_seq_groups_list.clear()
    ret.prefill_seq_groups_list.clear()

    # Blocks that need to be swapped or copied before model execution.
    # These locals alias the lists owned by `ret`, so appending to them
    # fills `ret` directly.
    blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
    blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy

    decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
    prefill_seq_groups: List[
        ScheduledSequenceGroup] = ret.prefill_seq_groups
    preempted: List[SequenceGroup] = ret.preempted
    swapped_out: List[SequenceGroup] = ret.swapped_out

    # NOTE(woosuk): Preemption happens only when there is no available slot
    # to keep all the sequence groups in the RUNNING state.

    # Store original running requests for the case of async + preemption.
    # `orig_running` is only defined, and only read, when async output
    # processing is in use.
    if self.use_async_output_proc:
        orig_running = self.running.copy()

    running_queue = self.running
    assert len(self._async_stopped) == 0
    while running_queue:
        seq_group = running_queue[0]
        num_running_tokens = self._get_num_new_tokens(
            seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

        # Nothing can be scheduled for the head of the queue: stop here
        # and leave it (and everything behind it) in the running queue.
        if num_running_tokens == 0:
            break

        running_queue.popleft()

        # With async postprocessor, an extra decode run is done
        # to process the final tokens. The check below avoids this extra
        # decode run when the model max len is reached, in order to avoid
        # a memory overflow.
        if self.use_async_output_proc and seq_group.seqs[0].get_len(
        ) > self.scheduler_config.max_model_len:
            self._async_stopped.append(seq_group)
            continue

        # With async postprocessor, when preemption kicks in, we need
        # first to drain the async postprocessor, so that all async
        # block_table freeing is applied before the preemption freeing
        # is applied.
        if self.use_async_output_proc and not self._can_append_slots(
                seq_group):
            # The callback runs against the running queue as it was when
            # this method started; the working queue is restored after.
            tmp = self.running
            self.running = orig_running
            assert self.output_proc_callback_fn is not None
            self.output_proc_callback_fn(is_async=True)
            self.running = tmp

        while not self._can_append_slots(seq_group):
            # Take this group's tokens and seqs back out of the budget
            # (a no-op in the budget if its request id was not counted).
            budget.subtract_num_batched_tokens(seq_group.request_id,
                                               num_running_tokens)
            num_running_seqs = seq_group.get_max_num_running_seqs()
            budget.subtract_num_seqs(seq_group.request_id,
                                     num_running_seqs)

            if (curr_loras is not None and seq_group.lora_int_id > 0
                    and seq_group.lora_int_id in curr_loras):
                curr_loras.remove(seq_group.lora_int_id)

            if running_queue:
                # Preempt the lowest-priority sequence groups.
                victim_seq_group = running_queue.pop()
                preempted_mode = self._preempt(victim_seq_group,
                                               blocks_to_swap_out)
                if preempted_mode == PreemptionMode.RECOMPUTE:
                    preempted.append(victim_seq_group)
                else:
                    swapped_out.append(victim_seq_group)
            else:
                # No other sequence groups can be preempted.
                # Preempt the current sequence group.
                preempted_mode = self._preempt(seq_group,
                                               blocks_to_swap_out)
                if preempted_mode == PreemptionMode.RECOMPUTE:
                    preempted.append(seq_group)
                else:
                    swapped_out.append(seq_group)
                break
        else:
            # while/else: reached only when the loop above ended without
            # `break`, i.e. slots can be appended for this group and it
            # was not preempted itself.
            self._append_slots(seq_group, blocks_to_copy)
            is_prefill = seq_group.is_prefill()

            scheduled_seq_group: ScheduledSequenceGroup = \
                self._scheduled_seq_group_cache[self.cache_id].get_object()
            scheduled_seq_group.seq_group = seq_group
            if is_prefill:
                scheduled_seq_group.token_chunk_size = num_running_tokens
                prefill_seq_groups.append(scheduled_seq_group)
                ret.prefill_seq_groups_list.append(seq_group)
            else:
                scheduled_seq_group.token_chunk_size = 1
                decode_seq_groups.append(scheduled_seq_group)
                ret.decode_seq_groups_list.append(seq_group)

            budget.add_num_batched_tokens(seq_group.request_id,
                                          num_running_tokens)
            # OPTIMIZATION: Note that get_max_num_running_seqs is
            # expensive. For the default scheduling case where
            # enable_chunking is False, num_seqs are updated before running
            # this method, so we don't have to update it again here.
            if enable_chunking:
                num_running_seqs = seq_group.get_max_num_running_seqs()
                budget.add_num_seqs(seq_group.request_id, num_running_seqs)
            if curr_loras is not None and seq_group.lora_int_id > 0:
                curr_loras.add(seq_group.lora_int_id)

    # Reset the object caches of the *next* cache slot, not the one whose
    # objects were just handed out above.
    self._scheduler_running_outputs_cache[self.next_cache_id].reset()
    self._scheduled_seq_group_cache[self.next_cache_id].reset()

    return ret
def _schedule_swapped(
    self,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    enable_chunking: bool = False,
) -> SchedulerSwappedInOutputs:
    """Schedule sequence groups that are swapped out.

    It schedules swapped requests as long as it fits `budget` and
    curr_loras <= max_lora from the scheduling config. The input arguments
    `budget` and `curr_loras` are updated based on scheduled seq_groups.

    Args:
        budget: The scheduling budget. The argument is in-place updated
            when any requests are swapped in.
        curr_loras: Currently batched lora request ids. The argument is
            in-place updated when any requests are swapped in.
        enable_chunking: If True, seq group can be chunked and only a
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.

    Returns:
        SchedulerSwappedInOutputs.
    """
    # Blocks that need to be swapped or copied before model execution.
    blocks_to_swap_in: List[Tuple[int, int]] = []
    blocks_to_copy: List[Tuple[int, int]] = []
    decode_seq_groups: List[ScheduledSequenceGroup] = []
    prefill_seq_groups: List[ScheduledSequenceGroup] = []
    infeasible_seq_groups: List[SequenceGroup] = []

    swapped_queue = self.swapped

    # Groups skipped because no LoRA slot is free; they are put back at
    # the front of the swapped queue after the loop.
    leftover_swapped: Deque[SequenceGroup] = deque()
    while swapped_queue:
        seq_group = swapped_queue[0]

        # If the sequence group cannot be swapped in, stop.
        is_prefill = seq_group.is_prefill()
        alloc_status = self.block_manager.can_swap_in(
            seq_group, self._get_num_lookahead_slots(is_prefill))
        if alloc_status == AllocStatus.LATER:
            break
        elif alloc_status == AllocStatus.NEVER:
            # The group is failed: every sequence is marked
            # FINISHED_IGNORED and the group is dropped from the queue.
            logger.warning(
                "Failing the request %s because there's not enough kv "
                "cache blocks to run the entire sequence.",
                seq_group.request_id)
            for seq in seq_group.get_seqs():
                seq.status = SequenceStatus.FINISHED_IGNORED
            infeasible_seq_groups.append(seq_group)
            swapped_queue.popleft()
            continue

        lora_int_id = 0
        if self.lora_enabled:
            lora_int_id = seq_group.lora_int_id
            assert curr_loras is not None
            assert self.lora_config is not None
            if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                    and len(curr_loras) >= self.lora_config.max_loras):
                # We don't have a space for another LoRA, so
                # we ignore this request for now.
                leftover_swapped.appendleft(seq_group)
                swapped_queue.popleft()
                continue

        # The total number of sequences in the RUNNING state should not
        # exceed the maximum number of sequences.
        num_new_seqs = seq_group.get_max_num_running_seqs()
        num_new_tokens = self._get_num_new_tokens(seq_group,
                                                  SequenceStatus.SWAPPED,
                                                  enable_chunking, budget)

        # Stop at the first group that does not fit the budget; it stays
        # at the head of the swapped queue.
        if (num_new_tokens == 0
                or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                           num_new_seqs=num_new_seqs)):
            break

        if lora_int_id > 0 and curr_loras is not None:
            curr_loras.add(lora_int_id)
        swapped_queue.popleft()
        self._swap_in(seq_group, blocks_to_swap_in)
        self._append_slots(seq_group, blocks_to_copy)
        is_prefill = seq_group.is_prefill()
        if is_prefill:
            prefill_seq_groups.append(
                ScheduledSequenceGroup(seq_group,
                                       token_chunk_size=num_new_tokens))
        else:
            decode_seq_groups.append(
                ScheduledSequenceGroup(seq_group, token_chunk_size=1))
        budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
        budget.add_num_seqs(seq_group.request_id, num_new_seqs)

    # `appendleft` stored the skipped groups in reverse order and
    # `extendleft` reverses again while inserting, so they return to the
    # front of the queue in their original relative order.
    swapped_queue.extendleft(leftover_swapped)

    return SchedulerSwappedInOutputs(
        decode_seq_groups=decode_seq_groups,
        prefill_seq_groups=prefill_seq_groups,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_copy=blocks_to_copy,
        num_lookahead_slots=self._get_num_lookahead_slots(
            is_prefill=False),
        infeasible_seq_groups=infeasible_seq_groups,
    )
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
    """Return the maximum prompt length allowed for `seq_group`.

    With chunked prefill enabled the limit is `max_model_len`; otherwise it
    is the smaller of `max_model_len` and `max_num_batched_tokens`. If the
    group's LoRA request carries a truthy `long_lora_max_len`, that value
    is returned instead (it is asserted to be at least the base limit).
    """
    if self.scheduler_config.chunked_prefill_enabled:
        base_limit = self.scheduler_config.max_model_len
    else:
        base_limit = min(self.scheduler_config.max_model_len,
                         self.scheduler_config.max_num_batched_tokens)

    has_long_lora = (seq_group.lora_request
                     and seq_group.lora_request.long_lora_max_len)
    if not has_long_lora:
        return base_limit

    # Model is fine tuned with long context. Return the fine tuned max_len.
    assert base_limit <= seq_group.lora_request.long_lora_max_len
    return seq_group.lora_request.long_lora_max_len
def _schedule_prefills(
    self,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    enable_chunking: bool = False,
) -> SchedulerPrefillOutputs:
    """Schedule sequence groups that are in prefill stage.

    Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
    as a new prefill (that starts from beginning -> most recently generated
    tokens).

    It schedules waiting requests as long as it fits `budget` and
    curr_loras <= max_lora from the scheduling config. The input arguments
    `budget` and `curr_loras` are updated based on scheduled seq_groups.

    Args:
        budget: The scheduling budget. The argument is in-place updated
            when any requests are scheduled.
        curr_loras: Currently batched lora request ids. The argument is
            in-place updated when any requests are scheduled.
        enable_chunking: If True, seq group can be chunked and only a
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.

    Returns:
        SchedulerPrefillOutputs. `self.prev_prompt` is set to True when at
        least one group was scheduled.
    """
    ignored_seq_groups: List[SequenceGroup] = []
    # Holds ScheduledSequenceGroup wrappers, not bare SequenceGroups.
    seq_groups: List[ScheduledSequenceGroup] = []

    waiting_queue = self.waiting

    # Groups skipped because no LoRA slot is free; they are put back at
    # the front of the waiting queue after the loop.
    leftover_waiting_sequences: Deque[SequenceGroup] = deque()
    while self._passed_delay(time.time()) and waiting_queue:
        seq_group = waiting_queue[0]

        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        assert len(waiting_seqs) == 1, (
            "Waiting sequence group should have only one prompt "
            "sequence.")
        num_new_tokens = self._get_num_new_tokens(seq_group,
                                                  SequenceStatus.WAITING,
                                                  enable_chunking, budget)
        if not enable_chunking:
            # Without chunking the whole prompt must be scheduled at once.
            num_prompt_tokens = waiting_seqs[0].get_len()
            assert num_new_tokens == num_prompt_tokens

        prompt_limit = self._get_prompt_limit(seq_group)
        if num_new_tokens > prompt_limit:
            # Too long for the configured limit: mark as ignored and drop
            # the group from the queue.
            logger.warning(
                "Input prompt (%d tokens) is too long"
                " and exceeds limit of %d", num_new_tokens, prompt_limit)
            for seq in waiting_seqs:
                seq.status = SequenceStatus.FINISHED_IGNORED
            ignored_seq_groups.append(seq_group)
            waiting_queue.popleft()
            continue

        # If the sequence group cannot be allocated, stop.
        can_allocate = self.block_manager.can_allocate(seq_group)
        if can_allocate == AllocStatus.LATER:
            break
        elif can_allocate == AllocStatus.NEVER:
            logger.warning(
                "Input prompt (%d tokens) is too long"
                " and exceeds the capacity of block_manager",
                num_new_tokens)
            for seq in waiting_seqs:
                seq.status = SequenceStatus.FINISHED_IGNORED
            ignored_seq_groups.append(seq_group)
            waiting_queue.popleft()
            continue

        lora_int_id = 0
        if self.lora_enabled:
            lora_int_id = seq_group.lora_int_id
            assert curr_loras is not None
            assert self.lora_config is not None
            if (self.lora_enabled and lora_int_id > 0
                    and lora_int_id not in curr_loras
                    and len(curr_loras) >= self.lora_config.max_loras):
                # We don't have a space for another LoRA, so
                # we ignore this request for now.
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue

        # Stop at the first group that does not fit the budget; it stays
        # at the head of the waiting queue.
        num_new_seqs = seq_group.get_max_num_running_seqs()
        if (num_new_tokens == 0
                or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                           num_new_seqs=num_new_seqs)):
            break

        # Can schedule this request.
        if curr_loras is not None and lora_int_id > 0:
            curr_loras.add(lora_int_id)
        waiting_queue.popleft()
        self._allocate_and_set_running(seq_group)
        seq_group.init_multi_step(
            num_scheduler_steps=self._get_num_lookahead_slots(
                is_prefill=True) + 1)
        seq_groups.append(
            ScheduledSequenceGroup(seq_group=seq_group,
                                   token_chunk_size=num_new_tokens))
        budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
        budget.add_num_seqs(seq_group.request_id, num_new_seqs)

    # Queue requests that couldn't be scheduled.
    # `appendleft` stored them in reverse order and `extendleft` reverses
    # again while inserting, so they return to the front of the queue in
    # their original relative order.
    waiting_queue.extendleft(leftover_waiting_sequences)
    if len(seq_groups) > 0:
        self.prev_prompt = True

    return SchedulerPrefillOutputs(
        seq_groups=seq_groups,
        ignored_seq_groups=ignored_seq_groups,
        num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
def _schedule_default(self) -> SchedulerOutputs:
    """Run one scheduling pass with chunked prefill disabled.

    The policy favors throughput: waiting prefill requests are batched
    first, and decode requests are scheduled only when no prefill was
    batched in this pass. Under GPU memory pressure, decode requests may
    be swapped out or preempted.

    As a side effect, preempted groups are pushed onto the left end of
    ``self.waiting``, swapped-out groups are appended to ``self.swapped``,
    and every group scheduled here is appended to ``self.running``.

    Returns:
        A ``SchedulerOutputs`` merging the prefill, running and swapped-in
        results of this pass.
    """
    config = self.scheduler_config
    budget = SchedulingBudget(
        token_budget=config.max_num_batched_tokens,
        max_num_seqs=config.max_num_seqs,
    )

    # Charge the sequences that are already running against the budget
    # up front, so prefill scheduling cannot push past max_num_seqs.
    for group in self.running:
        budget.add_num_seqs(group.request_id,
                            group.get_max_num_running_seqs())

    # LoRA ids currently in use by running groups (None if LoRA is off).
    if self.lora_enabled:
        curr_loras = {
            group.lora_int_id
            for group in self.running if group.lora_int_id > 0
        }
    else:
        curr_loras = None

    # Start from empty results; each stage below may replace its own.
    prefills = SchedulerPrefillOutputs.create_empty()
    running_scheduled = SchedulerRunningOutputs.create_empty()
    swapped_in = SchedulerSwappedInOutputs.create_empty()

    # Swapped requests take priority: new prefills are only attempted
    # while nothing is swapped out.
    if not self.swapped:
        prefills = self._schedule_prefills(budget,
                                           curr_loras,
                                           enable_chunking=False)
    num_prefill_groups = len(prefills.seq_groups)

    # Decodes are skipped entirely when any prefill was scheduled.
    # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
    # only contains decode requests, not chunked prefills.
    if num_prefill_groups == 0:
        running_scheduled = self._schedule_running(budget,
                                                   curr_loras,
                                                   enable_chunking=False)
        num_evicted = (len(running_scheduled.preempted) +
                       len(running_scheduled.swapped_out))
        # An eviction means there is no room for more running requests,
        # so swapping anything back in would be pointless.
        if num_evicted == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)

    assert budget.num_batched_tokens <= config.max_num_batched_tokens
    assert budget.num_curr_seqs <= config.max_num_seqs

    # Preempted groups return to the waiting queue.
    self.waiting.extendleft(running_scheduled.preempted)

    # Everything scheduled in this pass joins the running queue.
    if num_prefill_groups > 0:
        self.running.extend(
            [scheduled.seq_group for scheduled in prefills.seq_groups])
    self.running.extend(running_scheduled.decode_seq_groups_list)
    if swapped_in.decode_seq_groups:
        self.running.extend([
            scheduled.seq_group
            for scheduled in swapped_in.decode_seq_groups
        ])

    # Swapped-out groups join the swapped queue.
    self.swapped.extend(running_scheduled.swapped_out)
    preempted = (len(running_scheduled.preempted) +
                 len(running_scheduled.swapped_out))

    # This policy never chunks prefills, so neither the running queue nor
    # the swapped queue may have produced a prefill.
    assert len(running_scheduled.prefill_seq_groups) == 0
    assert len(swapped_in.prefill_seq_groups) == 0

    # Merge the per-stage results. The result lists are extended in place
    # rather than copied: prefills first, then decodes, then swapped-in.
    scheduled_seq_groups = running_scheduled.decode_seq_groups
    if num_prefill_groups > 0:
        prefills.seq_groups.extend(scheduled_seq_groups)
        scheduled_seq_groups = prefills.seq_groups
    scheduled_seq_groups.extend(swapped_in.decode_seq_groups)

    merged_blocks_to_copy = running_scheduled.blocks_to_copy
    merged_blocks_to_copy.extend(swapped_in.blocks_to_copy)

    merged_ignored = prefills.ignored_seq_groups
    merged_ignored.extend(swapped_in.infeasible_seq_groups)

    return SchedulerOutputs(
        scheduled_seq_groups=scheduled_seq_groups,
        num_prefill_groups=num_prefill_groups,
        num_batched_tokens=budget.num_batched_tokens,
        blocks_to_swap_in=swapped_in.blocks_to_swap_in,
        blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
        blocks_to_copy=merged_blocks_to_copy,
        ignored_seq_groups=merged_ignored,
        num_lookahead_slots=running_scheduled.num_lookahead_slots,
        running_queue_size=len(self.running),
        preempted=preempted,
    )
def _schedule_chunked_prefill(self) -> SchedulerOutputs:
"""Schedule queued requests.
Chunked prefill allows to chunk prefill requests, batch them together
with decode requests. This policy 1. schedule as many decoding requests
as possible. 2. schedule chunked prefill requests that are not
finished. 3. schedule swapped request. 4. schedule new prefill
requests.
The policy can sustain the high GPU utilization because it can put
prefill and decodes requests to the same batch, while it improves
inter token latency because decodes requests don't need to be blocked
by prefill requests.
"""
budget = SchedulingBudget(
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)