From 5d6e9467d4624a66ca64b0714042cb032df72695 Mon Sep 17 00:00:00 2001
From: Zhiqiang Xie
Date: Fri, 10 Jan 2025 20:22:01 -0800
Subject: [PATCH] Cache controller for hierarchical caching (#2804)

---
 .../sglang/srt/managers/cache_controller.py | 307 ++++++++++++++++++
 1 file changed, 307 insertions(+)
 create mode 100644 python/sglang/srt/managers/cache_controller.py

diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
new file mode 100644
index 00000000000..4560a270870
--- /dev/null
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import threading
+from queue import PriorityQueue, Queue
+from typing import Optional
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+
+logger = logging.getLogger(__name__)
+
+
+class CacheOperation:
+
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        device_indices: torch.Tensor,
+        node_id: int,
+        priority: Optional[int] = None,
+    ):
+        self.host_indices = host_indices
+        self.device_indices = device_indices
+        self.node_ids = [node_id]
+        self.data = None
+
+        self.id = CacheOperation.counter
+        CacheOperation.counter += 1
+        # default priority is the order of creation
+        self.priority = priority if priority is not None else self.id
+
+    def merge(self, other: "CacheOperation") -> None:
+        # multiple operations can be merged into a single operation for batch processing
+        self.host_indices = torch.cat([self.host_indices, other.host_indices])
+        self.device_indices = torch.cat([self.device_indices, other.device_indices])
+        self.priority = min(self.priority, other.priority)
+        self.node_ids.extend(other.node_ids)
+
+    def __lt__(self, other: "CacheOperation"):
+        return self.priority < other.priority
+
+
+class TransferBuffer:
+    """
+    Overlapping buffer preparation and transfer operations to improve throughput.
+    """
+
+    def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
+        self.buffers = Queue(maxsize=buffer_count)
+        # todo: adjust the buffer size based on throughput profile of the system
+        self.max_buffer_size = max_buffer_size
+
+    def full(self) -> bool:
+        return self.buffers.full()
+
+    def empty(self) -> bool:
+        return self.buffers.empty()
+
+    def put(self, item, block=True) -> None:
+        self.buffers.put(item, block=block)
+
+    def get(self, block=True) -> Optional[CacheOperation]:
+        try:
+            return self.buffers.get(block=block)
+        except Exception as e:
+            logger.error(e)
+
+
+class HiCacheController:
+
+    def __init__(
+        self,
+        mem_pool_device: MHATokenToKVPool,
+        mem_pool_host: MLATokenToKVPoolHost,
+        write_policy: str = "write_through_selective",
+    ):
+
+        self.mem_pool_device = mem_pool_device
+        self.mem_pool_host = mem_pool_host
+        self.write_policy = write_policy
+
+        if write_policy not in [
+            "write_through",
+            "write_through_selective",
+            "write_back",
+        ]:
+            raise ValueError(f"Invalid write policy: {write_policy}")
+
+        self.write_queue = PriorityQueue()
+        self.load_queue = PriorityQueue()
+
+        self.ack_write_queue = Queue()
+        self.ack_load_queue = Queue()
+
+        self.write_buffer = TransferBuffer()
+        self.load_buffer = TransferBuffer()
+
+        self.write_stream = torch.cuda.Stream()
+        self.load_stream = torch.cuda.Stream()
+
+        self.write_thread = threading.Thread(
+            target=self.write_thread_func_buffer, daemon=True
+        )
+        self.load_thread = threading.Thread(
+            target=self.load_thread_func_buffer, daemon=True
+        )
+        self.write_thread.start()
+        self.load_thread.start()
+
+    def write(
+        self,
+        device_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Back up KV caches from device memory to host memory.
+        """
+        host_indices = self.mem_pool_host.alloc(len(device_indices))
+        if host_indices is None:
+            return None
+        self.write_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_write(host_indices)
+        return host_indices
+
+    def load(
+        self,
+        host_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Load KV caches from host memory to device memory.
+        """
+        device_indices = self.mem_pool_device.alloc(len(host_indices))
+        if device_indices is None:
+            return None
+        self.load_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_load(host_indices)
+        return device_indices
+
+    def write_thread_func_direct(self):
+        """
+        Directly write through KV caches to host memory without buffering.
+        """
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                try:
+                    operation = self.write_queue.get(block=True)
+                    operation.data = self.mem_pool_device.get_flat_data(
+                        operation.device_indices
+                    )
+                    self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_write_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def load_thread_func_direct(self):
+        """
+        Directly load KV caches from host memory to device memory without buffering.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                try:
+                    operation = self.load_queue.get(block=True)
+                    operation.data = self.mem_pool_host.get_flat_data(
+                        operation.host_indices
+                    )
+                    self.mem_pool_device.transfer(
+                        operation.device_indices, operation.data
+                    )
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_load_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def write_aux_func(self, no_wait=False):
+        """
+        Auxiliary function to prepare the buffer for write operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.write_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    assert (
+                        buffer.device_indices.is_cuda
+                    ), "Device indices should be on GPU"
+                    buffer.data = self.mem_pool_device.get_flat_data(
+                        buffer.device_indices
+                    ).contiguous()
+                    self.write_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def load_aux_func(self):
+        """
+        Auxiliary function to prepare the buffer for load operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.load_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
+                    or self.load_queue.empty()
+                    or self.load_buffer.empty()
+                ):
+                    buffer.data = (
+                        self.mem_pool_host.get_flat_data(buffer.host_indices)
+                        .contiguous()
+                        .pin_memory()
+                    )
+                    self.load_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def write_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                operation = self.write_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_write_queue.put(node_id)
+
+    def load_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                operation = self.load_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_load_queue.put(node_id)
+
+    def evict_device(
+        self, device_indices: torch.Tensor, host_indices: torch.Tensor
+    ) -> int:
+        if self.mem_pool_host.is_synced(host_indices):
+            self.mem_pool_device.free(device_indices)
+            self.mem_pool_host.update_backup(host_indices)
+            return len(device_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
+
+    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
+        if not backup_only:
+            raise ValueError("Other eviction policies are not supported yet.")
+
+        if self.mem_pool_host.is_backup(host_indices):
+            self.mem_pool_host.free(host_indices)
+            return len(host_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )