From b32fb2e55d0eb67fd979a54aca4042a6f8d546af Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 13 Apr 2023 11:20:56 +0000 Subject: [PATCH 1/4] Support prefix --- cacheflow/master/block_manager.py | 26 +++++++++ cacheflow/master/scheduler.py | 71 +++++++++++++++++++++- cacheflow/models/attention.py | 94 +++++++++++++++++++++++++----- cacheflow/models/input_metadata.py | 19 +++++- cacheflow/models/llama.py | 3 +- cacheflow/models/opt.py | 3 +- cacheflow/sampling_params.py | 6 +- cacheflow/sequence.py | 3 + cacheflow/worker/worker.py | 76 +++++++++++++++++++++++- 9 files changed, 278 insertions(+), 23 deletions(-) diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index 30dfa1e8c28ed..26dc72ce9be10 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -100,6 +100,32 @@ def allocate(self, seq_group: SequenceGroup) -> None: for seq in seq_group.seqs: self.block_tables[seq.seq_id] = block_table.copy() + def allocate_with_prefix( + self, + seq_group: SequenceGroup, + prefix_id: int, + ) -> Optional[Tuple[int, List[int]]]: + # NOTE(woosuk): We ensure that every prefix must be a multiple of the + # block size, by recomputing the last few prefix tokens if they are + # not a multiple of the block size. + block_table = self.block_tables[prefix_id].copy() + # Increase the reference count of the prefix blocks. + for block in block_table: + block.ref_count += seq_group.num_seqs() + + # Allocate new physical token blocks that will store the prompt tokens. + # NOTE: Here we assume that all sequences in the group have the same prompt. + seq = seq_group.seqs[0] + for _ in range(len(seq.logical_token_blocks)): + block = self.gpu_allocator.allocate() + # Set the reference counts of the token blocks. + block.ref_count = seq_group.num_seqs() + block_table.append(block) + + # Assign the same block table for every sequence. + for seq in seq_group.seqs: + self.block_tables[seq.seq_id] = block_table.copy() + def can_append(self, seq_group: SequenceGroup) -> bool: # Simple heuristic: If there is at least one free block # for each sequence, we can append. diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 66d1fa4e2c565..08f3672a88104 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -58,6 +58,11 @@ def __init__( num_cpu_blocks=num_cpu_blocks, ) + # Prefix sequence groups: prefix_id -> SequenceGroup. + self.prefix: Dict[int, SequenceGroup] = {} + # Prefix that has not been processed yet. + self.waiting_prefix: List[SequenceGroup] = [] + # Sequence groups in the WAITING state. self.waiting: List[SequenceGroup] = [] # Sequence groups in the RUNNING state. @@ -72,6 +77,18 @@ def __init__( # Performance-related statistics. 
self.stats = Stats(num_gpu_blocks, num_cpu_blocks) + def register_prefix( + self, + seq_group: SequenceGroup, + ) -> None: + num_seqs = seq_group.num_seqs() + if num_seqs > 1: + raise ValueError( + 'The prefix must be a single sequence, ' + f'but got {num_seqs} sequences.') + self.waiting_prefix.append(seq_group) + self.sampling_params[seq_group.group_id] = SamplingParams.from_dict({}) + def add_sequence_groups( self, seq_groups: List[Tuple[SequenceGroup, SamplingParams]], @@ -80,10 +97,36 @@ def add_sequence_groups( for seq_group, sampling_params in seq_groups: self.waiting.append(seq_group) self.sampling_params[seq_group.group_id] = sampling_params + if sampling_params.prefix_id is not None: + if sampling_params.prefix_id not in self.prefix: + raise ValueError( + f'Invalid prefix id: {sampling_params.prefix_id}') def _schedule( self, ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]: + # Prioritize the processing of prefix if there is any. + # This must happen at the initialization phase, before start serving. + if self.waiting_prefix: + assert not self.waiting + assert not self.running + assert not self.swapped + group_ids = [] + for seq_group in self.waiting_prefix: + assert seq_group.num_seqs() == 1 + seq = seq_group.seqs[0] + seq.status = SequenceStatus.PREFIX + self.running.append(seq_group) + + # NOTE(woosuk): The prefix id is the same as the sequence id, + # not the group id. + self.prefix[seq.seq_id] = seq_group + self.block_manager.allocate(seq_group) + group_ids.append(seq_group.group_id) + + self.waiting_prefix = [] + return ({}, {}, {}, group_ids) + # Blocks that need to be swaped or copied before model execution. blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} @@ -273,14 +316,24 @@ def step(self) -> List[SequenceGroup]: # sequence length seq_len = seq.get_len() + sampling_params = self.sampling_params[group_id] + if sampling_params.prefix_id is None: + num_prefix_tokens = 0 + else: + prefix_seq_group = self.prefix[sampling_params.prefix_id] + prefix_seq = prefix_seq_group.seqs[0] + num_prefix_tokens = prefix_seq.get_len() + seq_len += num_prefix_tokens + input_seq_group = SequenceGroupInputs( group_id=group_id, is_prompt=is_prompt, input_tokens=input_tokens, context_len=seq_len, seq_logprobs=seq_logprobs, - sampling_params=self.sampling_params[group_id], + sampling_params=sampling_params, block_tables=block_tables, + num_prefix_tokens=num_prefix_tokens, ) input_seq_groups.append(input_seq_group) @@ -301,6 +354,13 @@ def post_step( self, seq_outputs: Dict[int, SequenceOutputs], ) -> None: + if self.prefix: + # Skip prefix sequences. + self.running = [ + seq_group for seq_group in self.running + if seq_group.seqs[0].seq_id not in self.prefix + ] + # Update the running sequences and free blocks. for seq_group in self.running: group_id = seq_group.group_id @@ -352,7 +412,14 @@ def post_step( self.running = running def _allocate(self, seq_group: SequenceGroup) -> None: - self.block_manager.allocate(seq_group) + group_id = seq_group.group_id + sampling_params = self.sampling_params[group_id] + if sampling_params.prefix_id is not None: + self.block_manager.allocate_with_prefix( + seq_group, sampling_params.prefix_id) + else: + self.block_manager.allocate(seq_group) + for seq in seq_group.seqs: seq.status = SequenceStatus.RUNNING # FIXME(woosuk): Support interactive generation. 
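Note: the block_manager and scheduler changes in this patch share a registered prefix's physical KV-cache blocks across later requests by raising reference counts, then append private blocks for each request's own prompt tokens. Below is a minimal standalone sketch of that bookkeeping; the Block and ToyBlockManager names, block counts, and per-sequence increments are illustrative assumptions, not the CacheFlow API.

from typing import Dict, List

class Block:
    def __init__(self, number: int) -> None:
        self.number = number
        self.ref_count = 0

class ToyBlockManager:
    def __init__(self, num_blocks: int) -> None:
        self.free = [Block(i) for i in range(num_blocks)]
        self.block_tables: Dict[int, List[Block]] = {}

    def allocate_prefix(self, prefix_id: int, num_prefix_blocks: int) -> None:
        # The prefix owns its blocks once; later prompts only add references.
        table = [self.free.pop() for _ in range(num_prefix_blocks)]
        for block in table:
            block.ref_count = 1
        self.block_tables[prefix_id] = table

    def allocate_with_prefix(self, seq_id: int, prefix_id: int,
                             num_prompt_blocks: int) -> None:
        # Start from a copy of the prefix's block table and share its blocks.
        table = self.block_tables[prefix_id].copy()
        for block in table:
            block.ref_count += 1
        # Append private blocks for the request's own prompt tokens.
        for _ in range(num_prompt_blocks):
            block = self.free.pop()
            block.ref_count = 1
            table.append(block)
        self.block_tables[seq_id] = table

manager = ToyBlockManager(num_blocks=16)
manager.allocate_prefix(prefix_id=100, num_prefix_blocks=4)
manager.allocate_with_prefix(seq_id=1, prefix_id=100, num_prompt_blocks=2)
manager.allocate_with_prefix(seq_id=2, prefix_id=100, num_prompt_blocks=1)
# The four prefix blocks are now shared with ref_count == 3 (the prefix itself
# plus two sharers); each request's prompt blocks stay private with ref_count 1.
print([block.ref_count for block in manager.block_tables[100]])  # [3, 3, 3, 3]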
diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 6fa197e7c8b90..ad0940696e5b3 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -12,9 +12,23 @@ class GPTCacheFlowAttention(nn.Module): - def __init__(self, scale: float) -> None: + def __init__( + self, + scale: float, + num_heads: int, + head_size: int, + kv_buffer_size: int = 2048, + ) -> None: super().__init__() self.scale = float(scale) + self.num_heads = num_heads + self.head_size = head_size + + kv_buffer = torch.empty( + size=(kv_buffer_size, 3, num_heads, head_size), + dtype=torch.get_default_dtype(), + ) + self.register_buffer('kv_buffer', kv_buffer, persistent=False) def multi_query_kv_attention( self, @@ -28,8 +42,7 @@ def multi_query_kv_attention( if query.dtype == torch.float: raise ValueError('The float data type is not supported by ' 'FlashAttention. Use the half data type instead.') - head_size = query.shape[-1] - if head_size > 128: + if self.head_size > 128: raise ValueError('FlashAttention does not support head_size > 128.') # Directly call FlashAttention's internal function to avoid allocating @@ -49,6 +62,42 @@ def multi_query_kv_attention( return_softmax=False, ) + def multi_query_cached_kv_attention( + self, + output: torch.Tensor, # [num_prefix_prompt_tokens, num_heads, head_size] + query: torch.Tensor, # [num_prefix_prompt_tokens, num_heads, head_size] + key: torch.Tensor, # [num_prefix_prompt_tokens, num_heads, head_size] + value: torch.Tensor, # [num_prefix_prompt_tokens, num_heads, head_size] + key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + slots: torch.Tensor, # [num_prefix_prompt_tokens] + cumulative_query_lens: torch.Tensor, # [num_prompts + 1] + cumulative_context_lens: torch.Tensor, # [num_prompts + 1] + max_query_len: int, + max_context_len: int, + ) -> None: + cache_ops.gather_kv( + key, + value, + key_cache, + value_cache, + slots, + ) + _flash_attn_forward( + query, + key, + value, + output, + cumulative_query_lens, + cumulative_context_lens, + max_query_len, + max_context_len, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + return_softmax=False, + ) + def single_query_cached_kv_attention( self, output: torch.Tensor, # [num_generation_tokens, num_heads, head_size] @@ -57,10 +106,9 @@ def single_query_cached_kv_attention( value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] input_metadata: InputMetadata, ) -> None: - head_size = value_cache.shape[2] - supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256] - if head_size not in supported_head_sizes: - raise ValueError(f'head_size ({head_size}) is not supported by ' + supported_head_sizes = {32, 64, 80, 96, 128, 160, 192, 256} + if self.head_size not in supported_head_sizes: + raise ValueError(f'head_size ({self.head_size}) is not supported by ' 'the single_query_cached_kv_attention kernel. ' 'Use one of the following head sizes: ' f'{supported_head_sizes}.') @@ -92,11 +140,9 @@ def forward( # tensor of shape [num_tokens, 3 * num_heads * head_size]. # Reshape the query, key, and value tensors. 
- num_heads = value_cache.shape[1] - head_size = value_cache.shape[2] - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_heads, head_size) - value = value.view(-1, num_heads, head_size) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) # Pre-allocate the output tensor. output = torch.empty_like(query) @@ -129,6 +175,25 @@ def forward( input_metadata.slot_mapping, ) + # Compute the attetion op for prompt with cached prefix. + num_query_tokens = input_metadata.num_query_tokens + if num_query_tokens > 0: + start = num_prompt_tokens + end = num_prompt_tokens + num_query_tokens + self.multi_query_cached_kv_attention( + output[start:end], + query[start:end], + key[start:end], + value[start:end], + key_cache, + value_cache, + input_metadata.slots_including_prefix, + input_metadata.cumulative_query_lens, + input_metadata.cumulative_context_lens_including_prefix, + input_metadata.max_query_len, + input_metadata.max_context_len_including_prefix, + ) + if input_metadata.num_generation_tokens > 0: # Compute the attention op for generation tokens. self.single_query_cached_kv_attention( @@ -140,7 +205,7 @@ def forward( # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. - return output.view(-1, num_heads * head_size) + return output.view(-1, self.num_heads * self.head_size) class OPTCacheFlowAttention(GPTCacheFlowAttention): @@ -156,11 +221,12 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention): def __init__( self, scale: float, + num_heads: int, head_size: int, max_position: int = 8192, base: int = 10000, ) -> None: - super().__init__(scale) + super().__init__(scale, num_heads, head_size) # Create the cos and sin cache. 
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size)) diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index c61bfff20a66b..1b39eb965e4eb 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -17,6 +17,11 @@ def __init__( context_lens: torch.Tensor, max_context_len: int, block_tables: torch.Tensor, + query_lens: List[int], + cumulative_query_lens: torch.Tensor, + max_context_len_including_prefix: int, + cumulative_context_lens_including_prefix: torch.Tensor, + slots_including_prefix: torch.Tensor, ) -> None: self.seq_groups = seq_groups self.seq_logprobs = seq_logprobs @@ -27,9 +32,20 @@ def __init__( self.max_context_len = max_context_len self.block_tables = block_tables + self.query_lens = query_lens + self.cumulative_query_lens = cumulative_query_lens + self.max_context_len_including_prefix = max_context_len_including_prefix + self.cumulative_context_lens_including_prefix = cumulative_context_lens_including_prefix + self.slots_including_prefix = slots_including_prefix + self.num_prompts = len(prompt_lens) self.num_prompt_tokens = sum(prompt_lens) self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 + + self.num_queries = len(query_lens) + self.num_query_tokens = sum(query_lens) + self.max_query_len = max(query_lens) if query_lens else 0 + self.num_generation_tokens = context_lens.shape[0] self.num_valid_tokens = slot_mapping.shape[0] if block_tables.numel() > 0: @@ -37,7 +53,8 @@ def __init__( else: self.max_num_blocks_per_seq = 0 assert block_tables.shape[0] == self.num_generation_tokens - assert context_lens.shape[0] == self.num_generation_tokens + assert self.num_valid_tokens == ( + self.num_prompt_tokens + self.num_query_tokens + self.num_generation_tokens) def __repr__(self) -> str: return (f'InputMetadata(' diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index a8572e4f362a9..eeaf4c94a9ea0 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -82,7 +82,8 @@ def __init__( input_is_parallel=True, perform_initialization=False, ) - self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim) + self.attn = LlamaCacheFlowAttention( + self.scaling, self.num_heads, self.head_dim) def forward( self, diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 90c6f54e3fcad..d3c485c4dbfa2 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -59,7 +59,8 @@ def __init__( self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True, perform_initialization=False) - self.attn = OPTCacheFlowAttention(scale=self.scaling) + self.attn = OPTCacheFlowAttention( + self.scaling, self.num_heads, self.head_dim) def forward( self, diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index 4daeaa486e569..1c59a469a8584 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -13,6 +13,7 @@ def __init__( max_num_steps: int, num_logprobs: int, context_window_size: Optional[int], + prefix_id: Optional[int], ) -> None: if n < 1: raise ValueError(f'n must be at least 1, got {n}.') @@ -59,6 +60,7 @@ def __init__( self.max_num_steps = max_num_steps self.num_logprobs = num_logprobs self.context_window_size = context_window_size + self.prefix_id = prefix_id def __repr__(self) -> str: return (f'SamplingParams(n={self.n}, ' @@ -68,7 +70,8 @@ def __repr__(self) -> str: f'stop_token_ids={self.stop_token_ids}, ' f'max_num_steps={self.max_num_steps}, ' f'num_logprobs={self.num_logprobs}, ' 
- f'context_window_size={self.context_window_size})') + f'context_window_size={self.context_window_size}, ' + f'prefix_id={self.prefix_id})') @classmethod def from_dict(cls, d: Dict) -> 'SamplingParams': @@ -81,4 +84,5 @@ def from_dict(cls, d: Dict) -> 'SamplingParams': max_num_steps=d.get('max_num_steps', 16), num_logprobs=d.get('num_logprobs', 0), context_window_size=d.get('context_window_size', None), + prefix_id=d.get('prefix_id', None), ) diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 6f5501a994685..66e0d2eeaf75d 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -7,6 +7,7 @@ class SequenceStatus(enum.Enum): + PREFIX = enum.auto() WAITING = enum.auto() RUNNING = enum.auto() SWAPPED = enum.auto() @@ -132,6 +133,7 @@ def __init__( seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs. sampling_params: SamplingParams, block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers. + num_prefix_tokens: int, ) -> None: self.group_id = group_id self.is_prompt = is_prompt @@ -140,6 +142,7 @@ def __init__( self.seq_logprobs = seq_logprobs self.sampling_params = sampling_params self.block_tables = block_tables + self.num_prefix_tokens = num_prefix_tokens class SequenceOutputs: diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 95ce2c6a869e8..3377e1f59da48 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -98,19 +98,21 @@ def prepare_inputs( ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]: seq_groups: List[Tuple[List[int], SamplingParams]] = [] seq_logprobs: Dict[int, float] = {} - sampling_params: Dict[int, SamplingParams] = {} input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - # Add prompt tokens. + # Add prompt tokens without prefix. prompt_lens: List[int] = [] for input_seq_group in input_seq_groups: if not input_seq_group.is_prompt: continue + if input_seq_group.num_prefix_tokens > 0: + continue + sampling_params = input_seq_group.sampling_params + assert sampling_params.prefix_id is None seq_ids = list(input_seq_group.input_tokens.keys()) - sampling_params = input_seq_group.sampling_params seq_groups.append((seq_ids, sampling_params)) seq_logprobs.update(input_seq_group.seq_logprobs) @@ -139,6 +141,62 @@ def prepare_inputs( cumulative_prompt_lens.append( cumulative_prompt_lens[-1] + prompt_len) + # Add prompt tokens with prefix. + query_lens: List[int] = [] + context_lens_including_prefix: List[int] = [] + slots_including_prefix: List[int] = [] + for input_seq_group in input_seq_groups: + if not input_seq_group.is_prompt: + continue + + if input_seq_group.num_prefix_tokens == 0: + continue + sampling_params = input_seq_group.sampling_params + assert sampling_params.prefix_id is not None + + seq_ids = list(input_seq_group.input_tokens.keys()) + seq_groups.append((seq_ids, sampling_params)) + seq_logprobs.update(input_seq_group.seq_logprobs) + + # Use any sequence in the group. + seq_id = seq_ids[0] + + prompt_tokens = input_seq_group.input_tokens[seq_id] + prompt_len = len(prompt_tokens) + query_lens.append(prompt_len) + + num_prefix_tokens = input_seq_group.num_prefix_tokens + assert num_prefix_tokens % self.block_size == 0 + num_prefix_blocks = num_prefix_tokens // self.block_size + + input_tokens.extend(prompt_tokens) + input_positions.extend(range(num_prefix_tokens + prompt_len)) + + # Compute the slot mapping. 
+            block_table = input_seq_group.block_tables[seq_id]
+            block_table = block_table[num_prefix_blocks:]
+            for i in range(prompt_len):
+                block_number = block_table[i // self.block_size]
+                block_offset = i % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)
+
+            for i in range(num_prefix_tokens + prompt_len):
+                block_number = block_table[i // self.block_size]
+                block_offset = i % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slots_including_prefix.append(slot)
+            context_lens_including_prefix.append(num_prefix_tokens + prompt_len)
+
+        cumulative_query_lens: List[int] = [0]
+        for query_len in query_lens:
+            cumulative_query_lens.append(
+                cumulative_query_lens[-1] + query_len)
+        cumulative_context_lens_including_prefix: List[int] = [0]
+        for context_len in context_lens_including_prefix:
+            cumulative_context_lens_including_prefix.append(
+                cumulative_context_lens_including_prefix[-1] + context_len)
+
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -197,6 +255,14 @@ def prepare_inputs(
         cumulative_prompt_lens_tensor = torch.tensor(
             cumulative_prompt_lens, dtype=torch.int, device='cuda')

+        # Data structure for prefix.
+        slots_including_prefix_tensor = torch.tensor(
+            slots_including_prefix, dtype=torch.int, device='cuda')
+        cumulative_query_lens_tensor = torch.tensor(
+            cumulative_query_lens, dtype=torch.int, device='cuda')
+        cumulative_context_lens_including_prefix_tensor = torch.tensor(
+            cumulative_context_lens_including_prefix, dtype=torch.int, device='cuda')
+
         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
@@ -206,6 +272,10 @@ def prepare_inputs(
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
+            query_lens=query_lens,
+            cumulative_query_lens=cumulative_query_lens_tensor,
+            cumulative_context_lens_including_prefix=cumulative_context_lens_including_prefix_tensor,
+            slots_including_prefix=slots_including_prefix_tensor,
         )
         return tokens_tensor, positions_tensor, input_metadata

From 16e39413d1c28f538d01b97ff2254b90a73d950f Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 13 Apr 2023 23:32:57 +0000
Subject: [PATCH 2/4] Implement prefix sharing & add experiment script

---
 benchmark/benchmark_prefix_translation.py | 300 ++++++++++++++++++++++
 benchmark/trace.py                        |  99 +++++++
 cacheflow/master/block_manager.py         |   3 +-
 cacheflow/master/scheduler.py             |   8 +-
 cacheflow/master/server.py                |   6 +
 cacheflow/models/attention.py             |  99 +++----
 cacheflow/models/input_metadata.py        |   5 +-
 cacheflow/models/llama.py                 |  17 +-
 cacheflow/models/memory_analyzer.py       |   6 +-
 cacheflow/models/sample.py                |   8 +-
 cacheflow/sequence.py                     |   1 -
 cacheflow/worker/worker.py                |   4 +-
 12 files changed, 499 insertions(+), 57 deletions(-)
 create mode 100644 benchmark/benchmark_prefix_translation.py

diff --git a/benchmark/benchmark_prefix_translation.py b/benchmark/benchmark_prefix_translation.py
new file mode 100644
index 0000000000000..34a2a975b7297
--- /dev/null
+++ b/benchmark/benchmark_prefix_translation.py
@@ -0,0 +1,300 @@
+import argparse
+import logging
+import os
+import pickle
+import time
+from typing import List
+
+from tqdm import tqdm
+from transformers import AutoConfig
+
+from benchmark.trace import generate_translation_requests
+from cacheflow.master.simple_frontend import SimpleFrontend
+from cacheflow.master.server import (Server, add_server_arguments,
+                                     initialize_ray_cluster)
+from cacheflow.sampling_params import SamplingParams
+from 
cacheflow.utils import get_gpu_memory, get_cpu_memory + + +logger = logging.getLogger(__name__) + + +def main(args: argparse.Namespace): + assert args.pipeline_parallel_size == 1, ( + 'Pipeline parallelism is not supported yet.') + + (num_nodes, num_devices_per_node, distributed_init_method, + all_stage_devices) = ( + initialize_ray_cluster( + address='local', + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size)) + + # Create a server. + server = Server( + model=args.model, + model_path=args.model_path, + use_dummy_weights=args.use_dummy_weights, + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, + block_size=args.block_size, + dtype=args.dtype, + seed=args.seed, + swap_space=args.swap_space, + max_num_batched_tokens=args.max_num_batched_tokens, + max_num_sequences=args.max_num_sequences, + num_nodes=num_nodes, + num_devices_per_node=num_devices_per_node, + distributed_init_method=distributed_init_method, + all_stage_devices=all_stage_devices, + gpu_memory=get_gpu_memory(), + cpu_memory=get_cpu_memory(), + collect_stats=True, + do_memory_analysis=args.do_memory_analysis, + ) + + # Create a frontend. + frontend = SimpleFrontend( + model_name=args.model, + block_size=args.block_size, + ) + # Generate requests. + prefix_tokens, requests = generate_translation_requests( + model=args.model, + dataset=args.dataset, + num_examples=args.num_prefix_examples, + request_rate=args.request_rate, + duration=args.duration, + seed=args.seed, + block_size=args.block_size, + ) + # Register prefix. + logger.info('Registering prefix.') + frontend._add_query(prefix_tokens, SamplingParams.from_dict({})) + server.register_prefix(frontend.get_inputs()[0][0]) + server.step() + + # Warm up. + logger.info('Warming up.') + num_warmup_requests = 8 + warmup_input_len = 8 + warmup_output_len = 32 + warmup_sampling_params = SamplingParams( + n=1, + temperature=1.0, + top_p=0.99, + max_num_steps=warmup_output_len, + use_beam_search=False, + stop_token_ids=set(), + num_logprobs=0, + context_window_size=None, + prefix_id=None, + ) + for _ in range(num_warmup_requests): + frontend._add_query([0] * warmup_input_len, warmup_sampling_params) + server.add_sequence_groups(frontend.get_inputs()) + while True: + updated_seq_groups = server.step() + if not server.has_unfinished_requests(): + break + + # Start benchmarking. + logger.info('Start benchmarking.') + + # # + # _, input_tokens, sampling_params = requests.pop(3) + # frontend._add_query(input_tokens, sampling_params, arrival_time=0.0) + # server.add_sequence_groups(frontend.get_inputs()) + + # step_cnt = 0 + # while True: + # # print(f'Step {step_cnt}:') + # step_cnt += 1 + # updated_seq_groups = server.step() + # for seq_group in updated_seq_groups: + # if seq_group.is_finished(): + # # Es war ein echter Herzschmerzenstod + # frontend.print_response(seq_group) + # if not server.has_unfinished_requests(): + # break + # return + + # Initialize tqdm. + pbar = tqdm(total=len(requests), desc='Finished requests') + + finished = [] + server.scheduler.reset_stats() + start_time = time.time() + while True: + now = time.time() + if args.timeout is not None and now - start_time > args.timeout: + logger.info('Timeout. 
Stop benchmarking.') + break + + while requests: + if requests[0][0] <= now - start_time: + request_time, input_tokens, sampling_params = requests.pop(0) + frontend._add_query( + input_tokens, sampling_params, arrival_time=start_time + request_time) + else: + break + server.add_sequence_groups(frontend.get_inputs()) + updated_seq_groups = server.step() + + now = time.time() + for seq_group in updated_seq_groups: + if not seq_group.is_finished(): + continue + arrival_time = seq_group.arrival_time + finish_time = now + for seq in seq_group.get_seqs(): + seq_len = seq.get_len() + output_len = seq_len - seq.prompt_len + finished.append({ + 'group_id': seq_group.group_id, + 'seq_id': seq.seq_id, + 'arrival_time': arrival_time, + 'finish_time': finish_time, + 'prompt_len': seq.prompt_len, + 'output_len': output_len, + }) + pbar.update(1) + + if not (requests or server.has_unfinished_requests()): + break + pbar.close() + logger.info('Finish benchmarking. Saving stats.') + server.scheduler.save_stats(args.output_dir) + with open(os.path.join(args.output_dir, 'sequences.pkl'), 'wb') as f: + pickle.dump(finished, f) + logger.info('Done.') + + +def get_model_name(model: str) -> str: + OPT_MODELS = [ + 'opt-125m', + 'opt-350m', + 'opt-1.3b', + 'opt-2.7b', + 'opt-6.7b', + 'opt-13b', + 'opt-30b', + 'opt-66b', + 'opt-175b', + ] + for opt_model in OPT_MODELS: + if opt_model in model: + return opt_model + + config = AutoConfig.from_pretrained(model) + assert config.model_type == 'llama' + hidden_size = config.hidden_size + if hidden_size == 4096: + return 'llama-7b' + elif hidden_size == 5120: + return 'llama-13b' + elif hidden_size == 6656: + return 'llama-30b' + elif hidden_size == 8192: + return 'llama-65b' + else: + raise ValueError(f'Unknown model: {model}') + + +def get_dataset_name(dataset: str) -> str: + if 'sharegpt' in dataset.lower(): + return 'sharegpt' + elif 'alpaca' in dataset.lower(): + return 'alpaca' + else: + raise ValueError(f'Unknown dataset: {dataset}') + + +def get_sampling_dir_name( + n1: float, + n2: float, + n3: float, + n4: float, + n6: float, + n2_beam: float, + n4_beam: float, + n6_beam: float, + n8_beam: float, +) -> str: + method = '' + if n1 > 0.0: + method = 'n1' if n1 == 1.0 else method + f'n1-{n1}-' + if n2 > 0.0: + method = 'n2' if n2 == 1.0 else method + f'n2-{n2}-' + if n3 > 0.0: + method = 'n3' if n3 == 1.0 else method + f'n3-{n3}-' + if n4 > 0.0: + method = 'n4' if n4 == 1.0 else method + f'n4-{n4}-' + if n6 > 0.0: + method = 'n6' if n6 == 1.0 else method + f'n6-{n6}-' + if n2_beam > 0.0: + method = 'n2-beam' if n2_beam == 1.0 else method + f'n2-beam-{n2_beam}-' + if n4_beam > 0.0: + method = 'n4-beam' if n4_beam == 1.0 else method + f'n4-beam-{n4_beam}-' + if n6_beam > 0.0: + method = 'n6-beam' if n6_beam == 1.0 else method + f'n6-beam-{n6_beam}-' + if n8_beam > 0.0: + method = 'n8-beam' if n8_beam == 1.0 else method + f'n8-beam-{n8_beam}-' + return method[:-1] if method.endswith('-') else method + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='CacheFlow simple server.') + parser = add_server_arguments(parser) + parser.add_argument('--output-dir', type=str, help='path to output directory', default=None) + + parser.add_argument('--num-prefix-examples', type=int, help='number of examples to use in prefix', required=True) + parser.add_argument('--dataset', type=str, help='path to dataset', default='wmt16') + parser.add_argument('--request-rate', type=float, help='reqs/sec', required=True) + parser.add_argument('--duration', type=int, 
help='duration in seconds', required=True) + parser.add_argument('--do-memory-analysis', action='store_true', + help='do memory analysis (This will lower the throughput. Use this only for analysis.)') + parser.add_argument('--timeout', type=int, help='time out in seconds', default=None) + + parser.add_argument('--n1', type=float, help='ratio of requests with n=1', default=0.0) + parser.add_argument('--n2', type=float, help='ratio of requests with n=2', default=0.0) + parser.add_argument('--n3', type=float, help='ratio of requests with n=3', default=0.0) + parser.add_argument('--n4', type=float, help='ratio of requests with n=4', default=0.0) + parser.add_argument('--n6', type=float, help='ratio of requests with n=6', default=0.0) + parser.add_argument('--n2-beam', type=float, help='ratio of requests with n=2 & beam search', default=0.0) + parser.add_argument('--n4-beam', type=float, help='ratio of requests with n=4 & beam search', default=0.0) + parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0) + parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0) + args = parser.parse_args() + if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0: + raise ValueError('The ratios of requests must sum to 1.') + + model_name = get_model_name(args.model) + sample_dir = get_sampling_dir_name( + args.n1, args.n2, args.n3, args.n4, args.n6, args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam) + if args.output_dir is None: + args.output_dir = os.path.join( + '../prefix_exp', + f'{args.dataset}-{args.num_prefix_examples}shot', + f'{model_name}-tp{args.tensor_parallel_size}', + sample_dir, + 'cacheflow', + f'req-rate-{args.request_rate}', + f'seed{args.seed}', + f'duration-{args.duration}', + ) + os.makedirs(args.output_dir, exist_ok=True) + + # Set up logging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + handlers=[ + logging.StreamHandler(), + logging.FileHandler(os.path.join(args.output_dir, 'log.txt')), + ], + ) + logger.info(args) + + main(args) diff --git a/benchmark/trace.py b/benchmark/trace.py index 42e2032358434..56f64b4a8be73 100644 --- a/benchmark/trace.py +++ b/benchmark/trace.py @@ -2,7 +2,9 @@ import random from typing import List, Tuple +from datasets import load_dataset import numpy as np +from transformers import AutoTokenizer from cacheflow.sampling_params import SamplingParams @@ -114,3 +116,100 @@ def generate_text_completion_requests( cum_sum += 1 requests.append((timestamp, input_tokens, sampling_params)) return requests + + +def generate_translation_requests( + model: str, + dataset: str, + num_examples: int, + request_rate: float, + duration: int, + seed: int, + block_size: int, + max_seq_len: int = 2048, + time_quantum: int = 10, +) -> Tuple[List[int], List[Tuple[float, List[int], SamplingParams]]]: + tokenizer = AutoTokenizer.from_pretrained(model) + + random.seed(seed) + np.random.seed(seed) + + # Generate timestamps for requests using Poisson distribution. + lam = request_rate * (time_quantum / 1000) + quantums_per_sec = 1000 / time_quantum + arrival_times = np.random.poisson( + lam=lam, size=int(duration * quantums_per_sec)) + timestamps = [] + for i, n in enumerate(arrival_times): + timestamps += [i * (time_quantum / 1000)] * n + + # Load the training dataset and sample examples. 
+ train_set = load_dataset('wmt16', 'de-en', split='train') + train_size = train_set.num_rows + if num_examples > train_size: + raise ValueError( + f'Number of examples ({num_examples}) is greater than the ' + f'number of training examples ({train_size}).') + + # Add instruction first. + prefix = 'Translate English to German:\n' + + # Randomly sample examples from the training dataset and add them to the + # prefix. + indices = np.random.choice(train_size, num_examples, replace=False).tolist() + for i in indices: + pair = train_set[i]['translation'] + en = pair['en'] + de = pair['de'] + example = f'{en} => {de}\n' + prefix += example + prefix_tokens = tokenizer.encode(prefix, add_special_tokens=True) + + # If the prefix length is not a multiple of the block size, truncate it. + prefix_len = len(prefix_tokens) + remainder_tokens = [] + if prefix_len % block_size != 0: + remainder_tokens = prefix_tokens[-(prefix_len % block_size):] + prefix_tokens = prefix_tokens[:-(prefix_len % block_size)] + prefix_len = len(prefix_tokens) + + # Tokenize the test set. + test_set = load_dataset(dataset, 'de-en', split='test') + tokenized = [] + for data in test_set: + en = data['translation']['en'] + ' =>' + # We skip the token because the tokens will be appended to a prefix. + en_tokens = tokenizer.encode(en, add_special_tokens=False) + input_tokens = remainder_tokens + en_tokens + + de = data['translation']['de'] + output_tokens = tokenizer.encode(de, add_special_tokens=False) + + # Filter out too long sequences. + if prefix_len + len(input_tokens) + len(output_tokens) > max_seq_len: + continue + tokenized.append((input_tokens, len(output_tokens))) + + # Generate requests. + num_requests = len(timestamps) + while len(tokenized) < num_requests: + tokenized += tokenized + tokenized = tokenized[:num_requests] + # Shuffle the requests. + random.shuffle(tokenized) + random_sampling_params_dict = { + 'temperature': 0.0, + 'top_p': 1.0, + 'use_beam_search': False, + 'stop_token_ids': set(), + 'num_logprobs': 0, + 'context_window_size': None, + 'prefix_id': 0, # FIXME + } + requests = [] + for timestamp, pair in zip(timestamps, tokenized): + input_tokens, output_len = pair + sampling_params = SamplingParams( + n=1, max_num_steps=output_len, **random_sampling_params_dict) + requests.append((timestamp, input_tokens, sampling_params)) + return prefix_tokens, requests diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index 26dc72ce9be10..f7ff8abae6a35 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -138,7 +138,8 @@ def append(self, seq: Sequence) -> Optional[Tuple[int, int]]: logical_blocks = seq.logical_token_blocks block_table = self.block_tables[seq.seq_id] - if len(block_table) < len(logical_blocks): + # HACK + if logical_blocks[-1].num_tokens == 1: # The sequence has a new logical block. # Allocate a new physical block. 
block = self.gpu_allocator.allocate() diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 08f3672a88104..e49fade314855 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -86,6 +86,12 @@ def register_prefix( raise ValueError( 'The prefix must be a single sequence, ' f'but got {num_seqs} sequences.') + seq = seq_group.seqs[0] + print(f'Registering prefix id: {seq.seq_id}, prefix length: {seq.prompt_len}') + if seq.prompt_len % self.block_size != 0: + raise ValueError( + 'The prefix length must be a multiple of the block size, ' + f'but got {seq.prompt_len} and {self.block_size}.') self.waiting_prefix.append(seq_group) self.sampling_params[seq_group.group_id] = SamplingParams.from_dict({}) @@ -115,7 +121,7 @@ def _schedule( for seq_group in self.waiting_prefix: assert seq_group.num_seqs() == 1 seq = seq_group.seqs[0] - seq.status = SequenceStatus.PREFIX + seq.status = SequenceStatus.RUNNING self.running.append(seq_group) # NOTE(woosuk): The prefix id is the same as the sequence id, diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 3370058703ca3..70e97bad3e872 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -97,6 +97,12 @@ def add_sequence_groups( ): self.scheduler.add_sequence_groups(sequence_groups) + def register_prefix( + self, + sequence_group: SequenceGroup, + ) -> None: + self.scheduler.register_prefix(sequence_group) + def step(self): return self.scheduler.step() diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index ad0940696e5b3..a8f36f7e049c6 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List from flash_attn.flash_attn_interface import _flash_attn_forward import torch @@ -17,19 +17,12 @@ def __init__( scale: float, num_heads: int, head_size: int, - kv_buffer_size: int = 2048, ) -> None: super().__init__() self.scale = float(scale) self.num_heads = num_heads self.head_size = head_size - kv_buffer = torch.empty( - size=(kv_buffer_size, 3, num_heads, head_size), - dtype=torch.get_default_dtype(), - ) - self.register_buffer('kv_buffer', kv_buffer, persistent=False) - def multi_query_kv_attention( self, output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] @@ -66,37 +59,47 @@ def multi_query_cached_kv_attention( self, output: torch.Tensor, # [num_prefix_prompt_tokens, num_heads, head_size] query: torch.Tensor, # [num_prefix_prompt_tokens, num_heads, head_size] - key: torch.Tensor, # [num_prefix_prompt_tokens, num_heads, head_size] - value: torch.Tensor, # [num_prefix_prompt_tokens, num_heads, head_size] key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] - slots: torch.Tensor, # [num_prefix_prompt_tokens] - cumulative_query_lens: torch.Tensor, # [num_prompts + 1] - cumulative_context_lens: torch.Tensor, # [num_prompts + 1] - max_query_len: int, - max_context_len: int, + kv_buffer: torch.Tensor, + slots: torch.Tensor, # [] + query_lens: List[int], + kv_lens: List[int], ) -> None: - cache_ops.gather_kv( - key, - value, - key_cache, - value_cache, - slots, - ) - _flash_attn_forward( - query, - key, - value, - output, - cumulative_query_lens, - cumulative_context_lens, - max_query_len, - max_context_len, - dropout_p=0.0, - softmax_scale=self.scale, - causal=True, - return_softmax=False, - ) + _, key_buffer, value_buffer = 
kv_buffer.unbind(dim=1) + + num_pairs = len(query_lens) + cum_query_len = 0 + cum_kv_len = 0 + for i in range(num_pairs): + query_len = query_lens[i] + kv_len = kv_lens[i] + cache_ops.gather_cached_kv( + key_buffer[:kv_len], + value_buffer[:kv_len], + key_cache, + value_cache, + slots[cum_kv_len:cum_kv_len + kv_len], + ) + torch.cuda.synchronize() + _flash_attn_forward( + query[cum_query_len:cum_query_len + query_len], + key_buffer[:kv_len], + value_buffer[:kv_len], + output[cum_query_len:cum_query_len + query_len], + torch.tensor([0, query_len], dtype=torch.int, device=query.device), + torch.tensor([0, kv_len], dtype=torch.int, device=query.device), + query_len, + kv_len, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + return_softmax=False, + ) + torch.cuda.synchronize() + + cum_query_len += query_len + cum_kv_len += kv_len def single_query_cached_kv_attention( self, @@ -133,6 +136,7 @@ def forward( value: torch.Tensor, # [num_tokens, num_heads * head_size] key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + kv_buffer: torch.Tensor, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # [num_tokens, num_heads * head_size] @@ -150,6 +154,7 @@ def forward( # Compute the attention op for prompts. num_prompt_tokens = input_metadata.num_prompt_tokens if num_prompt_tokens > 0: + torch.cuda.synchronize() self.multi_query_kv_attention( output[:num_prompt_tokens], query[:num_prompt_tokens], @@ -158,6 +163,7 @@ def forward( input_metadata.cumulative_prompt_lens, input_metadata.max_prompt_len, ) + torch.cuda.synchronize() # Wait until the cache op is done. if cache_event is not None: @@ -167,6 +173,7 @@ def forward( num_valid_tokens = input_metadata.num_valid_tokens if num_valid_tokens > 0: # The stride is 3 because the key and value are sliced from qkv. + torch.cuda.synchronize() cache_ops.reshape_and_cache( key[:num_valid_tokens], value[:num_valid_tokens], @@ -180,28 +187,30 @@ def forward( if num_query_tokens > 0: start = num_prompt_tokens end = num_prompt_tokens + num_query_tokens + torch.cuda.synchronize() self.multi_query_cached_kv_attention( output[start:end], query[start:end], - key[start:end], - value[start:end], key_cache, value_cache, + kv_buffer, input_metadata.slots_including_prefix, - input_metadata.cumulative_query_lens, - input_metadata.cumulative_context_lens_including_prefix, - input_metadata.max_query_len, - input_metadata.max_context_len_including_prefix, + input_metadata.query_lens, + input_metadata.prefix_context_lens, ) + torch.cuda.synchronize() if input_metadata.num_generation_tokens > 0: # Compute the attention op for generation tokens. + start = num_prompt_tokens + num_query_tokens + end = num_valid_tokens self.single_query_cached_kv_attention( - output[num_prompt_tokens:num_valid_tokens], - query[num_prompt_tokens:num_valid_tokens], + output[start:end], + query[start:end], key_cache, value_cache, input_metadata) + torch.cuda.synchronize() # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. 
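Note: the rewritten multi_query_cached_kv_attention above gathers each request's cached prefix K/V into a staging buffer and then attends with a query that covers only the new prompt tokens while keys/values cover prefix plus prompt. The pure-PyTorch reference below illustrates the masking this path is meant to compute for that unequal-length case; the shapes are made up, and it stands in for (rather than reproduces) the FlashAttention call used in the patch.

import torch

def prefix_causal_attention(query, key, value, scale):
    # query: [query_len, num_heads, head_size] -- only the new prompt tokens.
    # key/value: [kv_len, num_heads, head_size] with kv_len >= query_len; the
    # first (kv_len - query_len) positions are the cached prefix.
    query_len, kv_len = query.shape[0], key.shape[0]
    num_prefix = kv_len - query_len
    scores = torch.einsum('qhd,khd->hqk', query, key) * scale
    # Query token i sits at absolute position num_prefix + i and may attend to
    # every key position up to and including num_prefix + i.
    q_pos = num_prefix + torch.arange(query_len).unsqueeze(1)  # [query_len, 1]
    k_pos = torch.arange(kv_len).unsqueeze(0)                  # [1, kv_len]
    scores = scores.masked_fill(k_pos > q_pos, float('-inf'))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum('hqk,khd->qhd', probs, value)

# Tiny made-up shapes: 5 cached prefix tokens + 3 new prompt tokens, 2 heads.
q = torch.randn(3, 2, 4)
k = torch.randn(8, 2, 4)
v = torch.randn(8, 2, 4)
out = prefix_causal_attention(q, k, v, scale=4 ** -0.5)
print(out.shape)  # torch.Size([3, 2, 4])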
@@ -251,6 +260,7 @@ def forward( value: torch.Tensor, # [num_tokens, num_heads * head_size] key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + kv_buffer: torch.Tensor, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # [num_tokens, num_heads * head_size] @@ -268,6 +278,7 @@ def forward( value, key_cache, value_cache, + kv_buffer, input_metadata, cache_event, ) diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index 1b39eb965e4eb..8789c69bbd58f 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -19,7 +19,7 @@ def __init__( block_tables: torch.Tensor, query_lens: List[int], cumulative_query_lens: torch.Tensor, - max_context_len_including_prefix: int, + prefix_context_lens: List[int], cumulative_context_lens_including_prefix: torch.Tensor, slots_including_prefix: torch.Tensor, ) -> None: @@ -34,7 +34,7 @@ def __init__( self.query_lens = query_lens self.cumulative_query_lens = cumulative_query_lens - self.max_context_len_including_prefix = max_context_len_including_prefix + self.prefix_context_lens = prefix_context_lens self.cumulative_context_lens_including_prefix = cumulative_context_lens_including_prefix self.slots_including_prefix = slots_including_prefix @@ -45,6 +45,7 @@ def __init__( self.num_queries = len(query_lens) self.num_query_tokens = sum(query_lens) self.max_query_len = max(query_lens) if query_lens else 0 + self.max_prefix_context_len = max(prefix_context_lens) if prefix_context_lens else 0 self.num_generation_tokens = context_lens.shape[0] self.num_valid_tokens = slot_mapping.shape[0] diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index eeaf4c94a9ea0..eda8c19a439d6 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -90,6 +90,7 @@ def forward( positions: torch.LongTensor, hidden_states: torch.Tensor, kv_cache: KVCache, + kv_buffer: torch.Tensor, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: @@ -97,7 +98,7 @@ def forward( q, k, v = qkv.chunk(chunks=3, dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn( - positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) + positions, q, k, v, k_cache, v_cache, kv_buffer, input_metadata, cache_event) output, _ = self.o_proj(attn_output) return output @@ -124,6 +125,7 @@ def forward( positions: torch.LongTensor, hidden_states: torch.Tensor, kv_cache: KVCache, + kv_buffer: torch.Tensor, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: @@ -134,6 +136,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, + kv_buffer=kv_buffer, input_metadata=input_metadata, cache_event=cache_event, ) @@ -155,6 +158,17 @@ def __init__(self, config: LlamaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + # Allocate KV buffer. 
+ kv_buffer_size = 2048 + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads // tensor_model_parallel_world_size + head_size = config.hidden_size // config.num_attention_heads + kv_buffer = torch.empty( + size=(kv_buffer_size, 3, num_heads, head_size), + dtype=torch.get_default_dtype(), + ) + self.register_buffer('kv_buffer', kv_buffer, persistent=False) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, perform_initialization=False) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) @@ -179,6 +193,7 @@ def forward( positions, hidden_states, kv_caches[i], + self.kv_buffer, input_metadata, cache_event, ) diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 3a539cca97633..dc2cd791e6df5 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -210,8 +210,12 @@ def _get_max_act_size( # Size of output logits. output_logits = 2 * (max_num_batched_tokens * self.vocab_size) max_act = max(max_act, output_logits) + + # KV buffer size. + kv_buffer = 2048 * 3 * self.hidden_size // self.tensor_parallel_size + dtype_size = get_dtype_size(self.dtype) - return dtype_size * max_act + return dtype_size * (max_act + kv_buffer) def get_cache_block_size(self) -> int: key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index 1e358c7e5278e..64ee232df4db9 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -59,7 +59,7 @@ def _prune_hidden_states( ) -> torch.Tensor: start_idx = 0 last_token_indicies: List[int] = [] - for prompt_len in input_metadata.prompt_lens: + for prompt_len in input_metadata.prompt_lens + input_metadata.query_lens: last_token_indicies.append(start_idx + prompt_len - 1) start_idx += prompt_len last_token_indicies.extend( @@ -81,7 +81,7 @@ def _get_temperatures( # Set the temperature to 1 to avoid division by zero. temperature = 1.0 - if i < input_metadata.num_prompts: + if i < input_metadata.num_prompts + input_metadata.num_query_tokens: # A prompt input. temperatures.append(temperature) else: @@ -96,7 +96,7 @@ def _get_top_ps( top_ps: List[float] = [] for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group - if i < input_metadata.num_prompts: + if i < input_metadata.num_prompts + input_metadata.num_query_tokens: # A prompt input. top_ps.append(sampling_params.top_p) else: @@ -234,7 +234,7 @@ def _sample( idx = 0 for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group - if i < input_metadata.num_prompts: + if i < input_metadata.num_prompts + input_metadata.num_query_tokens: # Generate the next tokens for a prompt input. assert len(seq_ids) == sampling_params.n prob = probs[idx] diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 66e0d2eeaf75d..75e92499d6510 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -7,7 +7,6 @@ class SequenceStatus(enum.Enum): - PREFIX = enum.auto() WAITING = enum.auto() RUNNING = enum.auto() SWAPPED = enum.auto() diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 3377e1f59da48..4117a494e3151 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -174,9 +174,8 @@ def prepare_inputs( # Compute the slot mapping. 
block_table = input_seq_group.block_tables[seq_id] - block_table = block_table[num_prefix_blocks:] for i in range(prompt_len): - block_number = block_table[i // self.block_size] + block_number = block_table[num_prefix_blocks + (i // self.block_size)] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) @@ -274,6 +273,7 @@ def prepare_inputs( block_tables=block_tables_tensor, query_lens=query_lens, cumulative_query_lens=cumulative_query_lens_tensor, + prefix_context_lens=context_lens_including_prefix, cumulative_context_lens_including_prefix=cumulative_context_lens_including_prefix_tensor, slots_including_prefix=slots_including_prefix_tensor, ) From 092be5d469f631c06564ac0b40a592147c83a632 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 13 Apr 2023 23:37:05 +0000 Subject: [PATCH 3/4] Remove synchronization --- cacheflow/models/attention.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index a8f36f7e049c6..36d49f8ba153a 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -81,7 +81,6 @@ def multi_query_cached_kv_attention( value_cache, slots[cum_kv_len:cum_kv_len + kv_len], ) - torch.cuda.synchronize() _flash_attn_forward( query[cum_query_len:cum_query_len + query_len], key_buffer[:kv_len], @@ -96,7 +95,6 @@ def multi_query_cached_kv_attention( causal=True, return_softmax=False, ) - torch.cuda.synchronize() cum_query_len += query_len cum_kv_len += kv_len @@ -154,7 +152,6 @@ def forward( # Compute the attention op for prompts. num_prompt_tokens = input_metadata.num_prompt_tokens if num_prompt_tokens > 0: - torch.cuda.synchronize() self.multi_query_kv_attention( output[:num_prompt_tokens], query[:num_prompt_tokens], @@ -163,7 +160,6 @@ def forward( input_metadata.cumulative_prompt_lens, input_metadata.max_prompt_len, ) - torch.cuda.synchronize() # Wait until the cache op is done. if cache_event is not None: @@ -173,7 +169,6 @@ def forward( num_valid_tokens = input_metadata.num_valid_tokens if num_valid_tokens > 0: # The stride is 3 because the key and value are sliced from qkv. - torch.cuda.synchronize() cache_ops.reshape_and_cache( key[:num_valid_tokens], value[:num_valid_tokens], @@ -187,7 +182,6 @@ def forward( if num_query_tokens > 0: start = num_prompt_tokens end = num_prompt_tokens + num_query_tokens - torch.cuda.synchronize() self.multi_query_cached_kv_attention( output[start:end], query[start:end], @@ -198,7 +192,6 @@ def forward( input_metadata.query_lens, input_metadata.prefix_context_lens, ) - torch.cuda.synchronize() if input_metadata.num_generation_tokens > 0: # Compute the attention op for generation tokens. @@ -210,7 +203,6 @@ def forward( key_cache, value_cache, input_metadata) - torch.cuda.synchronize() # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. 
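Note: the worker-side slot mapping introduced in PATCH 1/4 and corrected in PATCH 2/4 above indexes past the shared prefix blocks for the new prompt tokens, while slots_including_prefix walks the full block table so the cached prefix K/V can be gathered. A small sketch of that arithmetic follows; the block table contents, block size, and token counts are made-up example values, and slots index the paged KV cache as block_number * block_size + offset.

block_size = 4
num_prefix_tokens = 8                     # must be a multiple of block_size
num_prefix_blocks = num_prefix_tokens // block_size
prompt_len = 5                            # new (non-prefix) prompt tokens
block_table = [7, 3, 9, 2]                # first two blocks hold the prefix

# Slots for the new prompt tokens only: skip the shared prefix blocks.
slot_mapping = []
for i in range(prompt_len):
    block_number = block_table[num_prefix_blocks + i // block_size]
    slot_mapping.append(block_number * block_size + i % block_size)

# Slots for the full context (prefix + prompt), used to gather cached K/V.
slots_including_prefix = []
for i in range(num_prefix_tokens + prompt_len):
    block_number = block_table[i // block_size]
    slots_including_prefix.append(block_number * block_size + i % block_size)

print(slot_mapping)            # [36, 37, 38, 39, 8]
print(slots_including_prefix)  # [28, 29, 30, 31, 12, 13, 14, 15, 36, 37, 38, 39, 8]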
From e81b8f7c5709ed9d3487685be22bb2eb5433b9ea Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 14 Apr 2023 23:18:43 +0000 Subject: [PATCH 4/4] Fix --- benchmark/benchmark_prefix_translation.py | 3 +++ cacheflow/models/attention.py | 6 ++++-- cacheflow/worker/worker.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/benchmark/benchmark_prefix_translation.py b/benchmark/benchmark_prefix_translation.py index 34a2a975b7297..8aebf42e00d91 100644 --- a/benchmark/benchmark_prefix_translation.py +++ b/benchmark/benchmark_prefix_translation.py @@ -145,6 +145,9 @@ def main(args: argparse.Namespace): for seq_group in updated_seq_groups: if not seq_group.is_finished(): continue + # Print outputs. + # frontend.print_response(seq_group) + arrival_time = seq_group.arrival_time finish_time = now for seq in seq_group.get_seqs(): diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 36d49f8ba153a..6d9fc111bdd8d 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -66,7 +66,7 @@ def multi_query_cached_kv_attention( query_lens: List[int], kv_lens: List[int], ) -> None: - _, key_buffer, value_buffer = kv_buffer.unbind(dim=1) + query_buffer, key_buffer, value_buffer = kv_buffer.unbind(dim=1) num_pairs = len(query_lens) cum_query_len = 0 @@ -81,8 +81,10 @@ def multi_query_cached_kv_attention( value_cache, slots[cum_kv_len:cum_kv_len + kv_len], ) + q_buffer = query_buffer[:query_len] + q_buffer.copy_(query[cum_query_len:cum_query_len + query_len]) _flash_attn_forward( - query[cum_query_len:cum_query_len + query_len], + q_buffer, key_buffer[:kv_len], value_buffer[:kv_len], output[cum_query_len:cum_query_len + query_len], diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 4117a494e3151..b71d4bdad781b 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -170,7 +170,7 @@ def prepare_inputs( num_prefix_blocks = num_prefix_tokens // self.block_size input_tokens.extend(prompt_tokens) - input_positions.extend(range(num_prefix_tokens + prompt_len)) + input_positions.extend(range(num_prefix_tokens, num_prefix_tokens + prompt_len)) # Compute the slot mapping. block_table = input_seq_group.block_tables[seq_id]
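Note: the position-id change above is the crux of PATCH 4/4: a prompt that extends a cached prefix must be given positions starting at num_prefix_tokens so its rotary or learned position embeddings line up with the prefix already in the KV cache; the q_buffer copy added in attention.py presumably stages the query slice contiguously before the FlashAttention call. A toy illustration of the position fix, with made-up token counts:

num_prefix_tokens = 8
prompt_len = 5

# Before the fix: one position id per prefix-plus-prompt token, starting at 0,
# even though only the new prompt tokens are fed to the model in this step.
old_positions = list(range(num_prefix_tokens + prompt_len))
# After the fix: exactly prompt_len ids, offset past the cached prefix.
new_positions = list(range(num_prefix_tokens, num_prefix_tokens + prompt_len))

print(len(old_positions), old_positions[:4])  # 13 [0, 1, 2, 3]
print(new_positions)                          # [8, 9, 10, 11, 12]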