Optimize data movement (#20)
WoosukKwon authored Apr 2, 2023
1 parent 1f01a18 commit 897cb2a
Showing 17 changed files with 275 additions and 135 deletions.
20 changes: 20 additions & 0 deletions cacheflow/models/activation.py
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn

from cacheflow import activation_ops


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,  # (num_tokens, 2 * d)
    ) -> torch.Tensor:    # (num_tokens, d)
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out
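The fused activation_ops.silu_and_mul kernel (added under csrc/ below) computes SiLU(gate) * up over a tensor that stores the gate and up projections concatenated along the last dimension. A minimal pure-PyTorch sketch of the same computation, for reference only; the silu_and_mul_reference name and the commented sanity check are illustrative, not part of this commit:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d). The first half along the last dimension is the
    # gate projection, the second half is the up projection.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

# Hypothetical sanity check against the fused CUDA op:
# x = torch.randn(4, 2 * 11008, dtype=torch.half, device='cuda')
# out = torch.empty(4, 11008, dtype=x.dtype, device=x.device)
# activation_ops.silu_and_mul(out, x)
# torch.testing.assert_close(out, silu_and_mul_reference(x))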
92 changes: 46 additions & 46 deletions cacheflow/models/attention.py
@@ -1,6 +1,6 @@
from typing import List, Optional
from typing import Optional

from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn

@@ -16,40 +16,38 @@ def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

        self.flash_attn = FlashAttention(softmax_scale=self.scale)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
        prompt_lens: List[int],
        output: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
        max_prompt_len: int,
    ) -> None:
        if query.dtype == torch.float:
            raise ValueError('The float data type is not supported by '
                             'FlashAttention. Use the half data type instead.')
        head_size = query.shape[2]
        head_size = query.shape[-1]
        if head_size > 128:
            raise ValueError('FlashAttention does not support head_size > 128.')

        device = query.device
        prefix_sum = [0]
        for prompt_len in prompt_lens:
            prefix_sum.append(prefix_sum[-1] + prompt_len)
        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
        max_prompt_len = max(prompt_lens)

        # FIXME(woosuk): Unnecessary copy. Optimize this.
        qkv = torch.stack([query, key, value], dim=1)
        out = self.flash_attn(
            qkv,
            cu_seqlens=prefix_sum,
            max_s=max_prompt_len,
        # Directly call FlashAttention's internal function to avoid allocating
        # a new tensor for the output.
        _flash_attn_forward(
            query,
            key,
            value,
            output,
            cumulative_prompt_lens,
            cumulative_prompt_lens,
            max_prompt_len,
            max_prompt_len,
            dropout_p=0.0,
            softmax_scale=self.scale,
            causal=True,
        )[0]
        # FIXME(woosuk): Unnecessary copy. Optimize this.
        output.copy_(out, non_blocking=True)
            return_softmax=False,
        )

    def single_query_cached_kv_attention(
        self,
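cumulative_prompt_lens plays the role of FlashAttention's cu_seqlens: the prompts are packed back to back along a single token dimension, and the prefix sums mark where each prompt starts and ends. The tensor is passed twice above because the packed queries and the packed keys/values share the same boundaries. A small sketch of the packing, with hypothetical values not taken from this commit:

import torch

# Hypothetical batch of three prompts packed along one token dimension.
prompt_lens = [3, 5, 2]
cumulative_prompt_lens = torch.tensor([0, 3, 8, 10], dtype=torch.int)  # [num_prompts + 1]

packed_tokens = torch.arange(sum(prompt_lens))  # stands in for the packed query/key/value
for i, prompt_len in enumerate(prompt_lens):
    start = int(cumulative_prompt_lens[i])
    end = int(cumulative_prompt_lens[i + 1])
    assert end - start == prompt_len
    print(f"prompt {i}: tokens {packed_tokens[start:end].tolist()}")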
@@ -90,21 +88,18 @@ def forward(
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Prune out paddings if any.
        query = query[:input_metadata.num_valid_tokens]
        key = key[:input_metadata.num_valid_tokens]
        value = value[:input_metadata.num_valid_tokens]
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the input tensors.
        # Reshape the query, key, and value tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)
        output = output.view(-1, num_heads, head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
@@ -114,22 +109,31 @@ def forward(
            query[:num_prompt_tokens],
            key[:num_prompt_tokens],
            value[:num_prompt_tokens],
            input_metadata.prompt_lens,
            input_metadata.cumulative_prompt_lens,
            input_metadata.max_prompt_len,
        )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)
        num_valid_tokens = input_metadata.num_valid_tokens
        if num_valid_tokens > 0:
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:],
                query[num_prompt_tokens:],
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata)
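With this change, query/key/value arrive as slices of one fused qkv projection instead of contiguous copies, which is what the "stride is 3" comment refers to: the cache kernel has to step over the neighbouring q/k/v slices when walking from one token to the next. A minimal sketch of that layout; shapes and variable names are illustrative, not from this commit:

import torch

num_tokens, num_heads, head_size = 8, 4, 16
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)

# chunk() returns views into the same storage; no data is copied.
q, k, v = qkv.chunk(chunks=3, dim=-1)
assert q.data_ptr() == qkv.data_ptr()

# Reshaping the key view keeps the row stride of the fused buffer: advancing
# to the next token skips over that token's q, k, and v slices (stride 3).
k = k.view(-1, num_heads, head_size)
assert k.stride(0) == 3 * num_heads * head_size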
@@ -186,19 +190,15 @@ def forward(
    ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        out_query = torch.empty_like(query)
        out_key = torch.empty_like(key)
        pos_encoding_ops.rotary_embedding_neox(
            out_query,
            out_key,
            positions,
            query,
            key,
            self.cos_sin_cache,
        )
        return super().forward(
            out_query,
            out_key,
            query,
            key,
            value,
            key_cache,
            value_cache,
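rotary_embedding_neox now rotates the query and key in place instead of writing into freshly allocated out_query/out_key tensors and copying them around. For orientation, an out-of-place PyTorch sketch of a GPT-NeoX-style rotary embedding; the [cos | sin] cache layout and the rotation over the full head size are assumptions made for illustration, not read off this diff:

import torch

def rotary_embedding_neox_reference(
    positions: torch.Tensor,      # [num_tokens]
    x: torch.Tensor,              # [num_tokens, num_heads, head_size]
    cos_sin_cache: torch.Tensor,  # [max_position, head_size], assumed layout [cos | sin]
) -> torch.Tensor:
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [num_tokens, head_size // 2]
    cos = cos.unsqueeze(1)  # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    x1, x2 = x.chunk(2, dim=-1)  # GPT-NeoX style: split the head dim in half
    # The fused kernel applies this transform to both query and key in place.
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)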
5 changes: 5 additions & 0 deletions cacheflow/models/input_metadata.py
@@ -12,6 +12,7 @@ def __init__(
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
        prompt_lens: List[int],
        cumulative_prompt_lens: torch.Tensor,
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
@@ -20,13 +21,15 @@
        self.seq_groups = seq_groups
        self.seq_logprobs = seq_logprobs
        self.prompt_lens = prompt_lens
        self.cumulative_prompt_lens = cumulative_prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables

        self.num_prompts = len(prompt_lens)
        self.num_prompt_tokens = sum(prompt_lens)
        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
        self.num_generation_tokens = context_lens.shape[0]
        self.num_valid_tokens = slot_mapping.shape[0]
        if block_tables.numel() > 0:
@@ -40,11 +43,13 @@ def __repr__(self) -> str:
        return (f'InputMetadata('
                f'num_prompts={self.num_prompts}, '
                f'num_prompt_tokens={self.num_prompt_tokens}, '
                f'max_prompt_len={self.max_prompt_len}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'num_valid_tokens={self.num_valid_tokens}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                f'max_context_len={self.max_context_len}), '
                f'prompt_lens={self.prompt_lens}, '
                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
                f'slot_mapping={self.slot_mapping}, '
                f'context_lens={self.context_lens}, '
                f'block_tables={self.block_tables})')
19 changes: 7 additions & 12 deletions cacheflow/models/llama.py
@@ -11,6 +11,7 @@
from transformers import LlamaConfig

from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
@@ -39,16 +40,14 @@ def __init__(
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           bias=False, input_is_parallel=True,
                                           perform_initialization=False)
        assert hidden_act == 'silu'
        self.act_fn = nn.SiLU()
        if hidden_act != 'silu':
            raise ValueError(f'Unsupported activation: {hidden_act}. '
                             'Only silu is supported for now.')
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
        gate, up = torch.split(gate_up, 1, dim=-2)
        gate = gate.squeeze(dim=-2).contiguous()
        up = up.squeeze(dim=-2).contiguous()
        x = self.act_fn(gate) * up
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

@@ -94,11 +93,7 @@ def forward(
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
        q, k, v = torch.split(qkv, 1, dim=-2)
        q = q.squeeze(dim=-2).contiguous()
        k = k.squeeze(dim=-2).contiguous()
        v = v.squeeze(dim=-2).contiguous()
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(
            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
7 changes: 2 additions & 5 deletions cacheflow/models/opt.py
@@ -69,17 +69,14 @@ def forward(
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
        q, k, v = torch.split(qkv, 1, dim=-2)
        q = q.squeeze(dim=-2).contiguous()
        k = k.squeeze(dim=-2).contiguous()
        v = v.squeeze(dim=-2).contiguous()
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):

    def __init__(self, config: OPTConfig):
8 changes: 8 additions & 0 deletions cacheflow/worker/worker.py
@@ -128,6 +128,11 @@ def prepare_inputs(
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        cumulative_prompt_lens: List[int] = [0]
        for prompt_len in prompt_lens:
            cumulative_prompt_lens.append(
                cumulative_prompt_lens[-1] + prompt_len)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
@@ -183,11 +188,14 @@ def prepare_inputs(
            for block_table in generation_block_tables]
        block_tables_tensor = torch.tensor(
            padded_block_tables, dtype=torch.int, device='cuda')
        cumulative_prompt_lens_tensor = torch.tensor(
            cumulative_prompt_lens, dtype=torch.int, device='cuda')

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_logprobs=seq_logprobs,
            prompt_lens=prompt_lens,
            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
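prepare_inputs now builds the prefix sums once per batch on the host and ships them to the GPU as a single int tensor, so the attention path no longer reconstructs them on every forward pass. The loop above is a plain prefix sum; an equivalent formulation, shown only as a sketch:

import itertools
from typing import List

import torch

def make_cumulative_prompt_lens(prompt_lens: List[int]) -> torch.Tensor:
    # [0, l0, l0 + l1, ...] -- one entry per prompt plus a leading zero.
    cumulative = [0, *itertools.accumulate(prompt_lens)]
    return torch.tensor(cumulative, dtype=torch.int, device='cuda')

# make_cumulative_prompt_lens([3, 5, 2]) -> tensor([0, 3, 8, 10], dtype=torch.int32)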
12 changes: 12 additions & 0 deletions csrc/activation.cpp
@@ -0,0 +1,12 @@
#include <torch/extension.h>

void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
}
46 changes: 46 additions & 0 deletions csrc/activation_kernels.cu
@@ -0,0 +1,46 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,          // [num_tokens, d]
  const scalar_t* __restrict__ input,  // [num_tokens, 2, d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

} // namespace cacheflow

void silu_and_mul(
  torch::Tensor& out,    // [num_tokens, d]
  torch::Tensor& input)  // [num_tokens, 2 * d]
{
  int num_tokens = input.size(0);
  int d = input.size(1) / 2;

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
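The launch configuration maps one thread block per token (grid = num_tokens) with up to 1024 threads striding over the hidden size d. A pure-Python mirror of the kernel's flat indexing, included only to spell out the memory layout; it is not part of the extension:

import math

def silu_and_mul_host_mirror(inp: list, num_tokens: int, d: int) -> list:
    # inp is laid out row-major as [num_tokens, 2 * d]; out as [num_tokens, d].
    out = [0.0] * (num_tokens * d)
    for token_idx in range(num_tokens):        # one CUDA block per token
        for idx in range(d):                   # threads stride over d
            x = inp[token_idx * 2 * d + idx]       # first half: gate
            y = inp[token_idx * 2 * d + d + idx]   # second half: up projection
            out[token_idx * d + idx] = x / (1.0 + math.exp(-x)) * y
    return out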