Update sp layout (#3)
* update layout

* bug fix
ZYHowell authored Jul 28, 2024
1 parent a11bc61 commit 4b8203a
Showing 2 changed files with 97 additions and 78 deletions.
171 changes: 95 additions & 76 deletions python/sglang/srt/managers/controller/infer_batch.py
@@ -334,8 +334,9 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
# Handle prefix
# Note: The flatten input ids with Sequence Parallel is in form of:
# [req_0_sp_0, req_1_sp_0, ... req_n_sp_0,
# req_0_sp_1, ..., req_n_sp_1, padding_sp_1,
# req_0_sp_1, req_1_sp_1, ..., req_n_sp_1,
# ...
# req_0_sp_m, req_0_padding, req_1_sp_m, req_1_padding, ...]
# ]
# The padding is for collective primitives, which need every rank to
# have the same size. Since we don't expect too many requests in SP,
@@ -354,11 +355,21 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
)

req_pool_indices_cpu = req_pool_indices.cpu().numpy()
num_padding_tokens = _get_num_padding_tokens(
self.sp_size, np.asarray([len(ids) for ids in input_ids])
)
for i in range(bs):
for sp_rank in range(self.sp_size):
ids = input_ids[i]
local_slice = _get_local_token_slices(sp_rank, self.sp_size, len(ids))
flatten_input_ids[sp_rank].extend(ids[local_slice])
local_slice = _get_local_token_slices_new(
sp_rank, self.sp_size, len(ids)
)
try:
flatten_input_ids[sp_rank].extend(ids[local_slice])
except TypeError as e:
print(local_slice, sp_rank, self.sp_size, len(ids))
raise e
flatten_input_ids[-1].extend([0] * num_padding_tokens[i])
extend_lens.append(len(input_ids[i]))

if len(prefix_indices[i]) == 0:
@@ -370,11 +381,8 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
] = prefix_indices[i]

seq_lens.append(prefix_lens[-1] + extend_lens[-1])
# For sequence parallel, add padding zeros for each rank.
padded_sp_len = max(len(ids) for ids in flatten_input_ids)
for flatten_ids in flatten_input_ids:
if len(flatten_ids) < padded_sp_len:
flatten_ids.extend([0] * (padded_sp_len - len(flatten_ids)))
# Already padded at the last shard of each request.
padded_sp_len = len(flatten_input_ids[0])
flatten_input_ids = list(itertools.chain(*flatten_input_ids))
self.padded_sp_len = padded_sp_len

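For reference, here is a standalone sketch (plain Python, toy token ids, not from the repo) of the layout the loop above now produces for two requests of lengths 5 and 4 with sp_size = 2: each rank gets a contiguous ceil(len / sp_size) slice per request, and padding is appended on the last rank only:

```python
import math

sp_size = 2
input_ids = [[10, 11, 12, 13, 14], [20, 21, 22, 23]]  # toy requests

flatten_input_ids = [[] for _ in range(sp_size)]
for ids in input_ids:
    shard = math.ceil(len(ids) / sp_size)  # slots per rank for this request
    for sp_rank in range(sp_size):
        flatten_input_ids[sp_rank].extend(ids[shard * sp_rank : shard * (sp_rank + 1)])
    # top the last rank up so every rank holds the same number of slots
    flatten_input_ids[-1].extend([0] * (sp_size * shard - len(ids)))

print(flatten_input_ids)
# [[10, 11, 12, 20, 21], [13, 14, 0, 22, 23]] -> padded_sp_len == 5 on every rank
```

Because the shortfall is absorbed per request on the last rank, the old post-loop padding pass is no longer needed, which is what the "Already padded at the last shard of each request" comment records.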
@@ -386,7 +394,7 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
if self.sp_size > 1:
extend_seq_lens = seq_lens - prefix_lens
# FIXME(yonghao): _extend_num_tokens -> extend_num_tokens once kv cache store is ready for SP
extend_local_token_nums = _get_local_token_nums(
extend_local_token_nums = _get_local_token_nums_new(
self.sp_rank, self.sp_size, extend_seq_lens
)
_extend_num_tokens = int(np.sum(extend_local_token_nums))
@@ -804,7 +812,6 @@ class InputMetadata:
# For Sequence Parallel
sp_rank: int = None
sp_size: int = None
local_token_indices: np.ndarray = None
sp_to_normal_indices: np.ndarray = None
sp_local_token_length: int = None
_debug_normal_to_sp_metadata: Optional[List[np.ndarray]] = None
@@ -877,28 +884,29 @@ def create(
prefix_lens = prefix_lens if prefix_lens is not None else 0
extend_seq_lens_cpu = (seq_lens - prefix_lens).cpu().numpy()
if forward_mode == ForwardMode.DECODE:
local_token_indices = get_decode_indices(
sp_rank, sp_size, extend_seq_lens_cpu
)
sp_to_normal_indices = sp_to_normal_indices_decode(
sp_size, extend_seq_lens_cpu, padded_sp_len
)
else:
extend_start_loc_cpu = extend_start_loc.cpu().numpy()
local_token_indices = get_prefill_indices(
sp_rank, sp_size, extend_seq_lens_cpu, extend_start_loc_cpu
_debug_normal_to_sp_metadata = _debug_normal_to_sp_indices_decode(
sp_size, extend_seq_lens_cpu
)
sp_local_token_length = get_decode_indices(
sp_rank, sp_size, extend_seq_lens_cpu
).size
else:
sp_to_normal_indices = sp_to_normal_indices_prefill(
sp_size, extend_seq_lens_cpu, padded_sp_len
sp_size, extend_seq_lens_cpu
)
_debug_normal_to_sp_metadata = _debug_normal_to_sp_indices_prefill(
sp_size, extend_seq_lens_cpu
)
sp_local_token_length = _get_local_token_nums_new(
sp_rank, sp_size, extend_seq_lens_cpu
)
_debug_normal_to_sp_metadata = _debug_normal_to_sp_indices(
forward_mode, sp_size, extend_seq_lens_cpu, padded_sp_len
)
else:
local_token_indices = np.arange(positions.numel())
sp_to_normal_indices = np.arange(positions.numel())
_debug_normal_to_sp_metadata = None
sp_local_token_length = len(local_token_indices)
sp_local_token_length = positions.numel()

ret = cls(
forward_mode=forward_mode,
Expand All @@ -920,7 +928,6 @@ def create(
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
sp_rank=sp_rank,
sp_size=sp_size,
local_token_indices=local_token_indices,
sp_to_normal_indices=sp_to_normal_indices,
sp_local_token_length=sp_local_token_length,
_debug_normal_to_sp_metadata=_debug_normal_to_sp_metadata,
@@ -1026,65 +1033,54 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
return max_seq_len, max_extend_len, start_loc, prefix_lens


def get_prefill_indices(
sp_rank, sp_size, extend_seq_lens: np.ndarray, extend_start_loc
def _get_local_token_nums_new(
sp_rank, sp_size, extend_seq_lens: Union[int, np.ndarray]
):
"""
Get indices from the normal layout to the sequence parallel layout of all
requests.
"""
# For the first few ranks, they have one more token to compute
sp_req_lens = _get_local_token_nums(sp_rank, sp_size, extend_seq_lens)
# the offset of each request in the batch. Only the first few ranks may get
# 1 more token (each). For sp_rank=r, there are r peers ahead (0-based),
# each of which will get one extra token.
sp_in_req_offset = extend_seq_lens // sp_size * sp_rank + np.clip(
extend_seq_lens % sp_size, a_min=None, a_max=sp_rank
)
sp_req_start = extend_start_loc + sp_in_req_offset
sp_indices = np.concatenate(
[np.arange(s, s + l) for s, l in zip(sp_req_start, sp_req_lens)]
"""Get the number of tokens in this SP. Padding is not considered."""
padded_size = np.ceil(extend_seq_lens / sp_size).astype(np.int32)
return (
padded_size
if sp_rank != sp_size - 1
else extend_seq_lens - (sp_size - 1) * padded_size
)
return sp_indices


def _get_local_token_nums(sp_rank, sp_size, extend_seq_lens: Union[int, np.ndarray]):
"""Get the number of tokens in this SP. Padding is not considered."""
has_remainder = (extend_seq_lens % sp_size) > sp_rank
return extend_seq_lens // sp_size + has_remainder
def _get_num_padding_tokens(sp_size, extend_seq_lens: np.ndarray):
"""Get the number of tokens padded for SP."""
padded_size = np.ceil(extend_seq_lens / sp_size).astype(np.int32)
return sp_size * padded_size - extend_seq_lens

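The two new helpers are complementary by construction: every rank except the last holds ceil(len / sp_size) tokens, and the padding tops the last rank up to that same size. A quick numpy check with toy lengths:

```python
import numpy as np

sp_size = 4
lens = np.asarray([7, 8])  # extend lengths of two toy requests
shard = np.ceil(lens / sp_size).astype(np.int32)  # ceil(len / sp_size): [2, 2]
for sp_rank in range(sp_size):
    n = shard if sp_rank != sp_size - 1 else lens - (sp_size - 1) * shard
    print(sp_rank, n)  # ranks 0-2 hold [2, 2] tokens each; rank 3 holds [1, 2]
print(sp_size * shard - lens)  # padding per request: [1, 0]
```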

def get_decode_indices(sp_rank, sp_size, seq_lens: np.ndarray, offset=0):
def get_decode_indices(sp_rank, sp_size, seq_lens: np.ndarray):
"""Get Indices from the normal layout to the sequence parallel layout."""
return np.nonzero((seq_lens % sp_size) == sp_rank)[0] + offset
return np.nonzero((seq_lens % sp_size) == sp_rank)[0]

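In decode mode every running request extends by exactly one token, and the rewritten helper simply buckets request i onto rank seq_lens[i] % sp_size (toy lengths):

```python
import numpy as np

sp_size = 2
seq_lens = np.asarray([9, 12, 7, 10])  # toy decode batch
for sp_rank in range(sp_size):
    print(sp_rank, np.nonzero((seq_lens % sp_size) == sp_rank)[0])
# rank 0 serves requests [1, 3] (even lengths); rank 1 serves [0, 2] (odd lengths)
```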

def _get_local_token_slices(sp_rank, sp_size, seq_len: int):
def _get_local_token_slices_new(sp_rank, sp_size, seq_len: int):
"""Get the SP local slice for a single request's extended input ids."""
start = seq_len // sp_size * sp_rank + min(seq_len % sp_size, sp_rank)
length = _get_local_token_nums(sp_rank, sp_size, seq_len)
start = int(np.ceil(seq_len / sp_size) * sp_rank)
length = _get_local_token_nums_new(sp_rank, sp_size, seq_len)
return slice(start, start + length)

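The rewritten slice helper trades the old remainder-spreading split for fixed ceil(seq_len / sp_size) chunks, with any shortfall absorbed by the last rank. A standalone sketch of the new behavior:

```python
import math

def new_slice(sp_rank, sp_size, seq_len):  # mirrors _get_local_token_slices_new
    shard = math.ceil(seq_len / sp_size)
    length = shard if sp_rank != sp_size - 1 else seq_len - (sp_size - 1) * shard
    return slice(shard * sp_rank, shard * sp_rank + length)

print([new_slice(r, 4, 10) for r in range(4)])
# [slice(0, 3), slice(3, 6), slice(6, 9), slice(9, 10)]
```

Under the old helper the 10 tokens would split as 3, 3, 2, 2; the new scheme gives 3, 3, 3, 1, so every rank but the last shares one uniform chunk size.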

def sp_to_normal_indices_prefill(
sp_size, extend_seq_lens: np.ndarray, padded_sp_len: int
):
def sp_to_normal_indices_prefill(sp_size, extend_seq_lens: np.ndarray):
"""
Indices from the Sequence Parallel layout (padded) to the normal layout.
"""
sp_seq_lens = np.ceil(extend_seq_lens / sp_size).astype(np.int32)
sp_len = np.sum(sp_seq_lens)
sp_seq_offset = np.concatenate(
[np.asarray([0], dtype=np.int32), np.cumsum(sp_seq_lens[:-1])]
)
sp_arange = np.arange(sp_size).reshape(-1, 1)
indices = []
sp_offset = [padded_sp_len * sp_rank for sp_rank in range(sp_size)]
sp_local_token_nums = [
_get_local_token_nums(sp_rank, sp_size, extend_seq_lens)
for sp_rank in range(sp_size)
]
for req_id in range(len(extend_seq_lens)):
for sp_rank in range(sp_size):
sp_len = int(sp_local_token_nums[sp_rank][req_id])
sp_my_offset = sp_offset[sp_rank]
indices.extend(range(sp_my_offset, sp_my_offset + sp_len))
sp_offset[sp_rank] += sp_len
return np.asarray(indices)
for i in range(len(extend_seq_lens)):
sp_idx = np.arange(sp_seq_lens[i]).reshape(1, -1).repeat(sp_size, axis=0)
sp_idx = (sp_idx + sp_seq_offset[i] + sp_len * sp_arange).reshape(-1)
sp_idx = sp_idx[: extend_seq_lens[i]]
indices.append(sp_idx)
indices = np.concatenate(indices)
return indices

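To see that the new prefill mapping inverts the layout built in prepare_for_extend, here is a standalone numpy round trip mirroring sp_to_normal_indices_prefill on the same toy batch (lengths [5, 4], sp_size = 2):

```python
import numpy as np

sp_size, lens = 2, np.asarray([5, 4])
# SP-flattened ids of the toy batch (rank 0's tokens first, then rank 1's)
sp_flat = np.asarray([10, 11, 12, 20, 21, 13, 14, 0, 22, 23])

sp_seq_lens = np.ceil(lens / sp_size).astype(np.int32)  # [3, 2]
sp_len = np.sum(sp_seq_lens)  # 5 slots per rank, padding included
sp_seq_offset = np.concatenate([[0], np.cumsum(sp_seq_lens[:-1])])
sp_arange = np.arange(sp_size).reshape(-1, 1)
indices = []
for i in range(len(lens)):
    idx = np.arange(sp_seq_lens[i]).reshape(1, -1).repeat(sp_size, axis=0)
    idx = (idx + sp_seq_offset[i] + sp_len * sp_arange).reshape(-1)
    indices.append(idx[: lens[i]])  # the padding slots fall off the tail
print(sp_flat[np.concatenate(indices)])  # [10 11 12 13 14 20 21 22 23]
```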

def sp_to_normal_indices_decode(sp_size, seq_lens_cpu: np.ndarray, padded_sp_len: int):
@@ -1100,19 +1096,42 @@ def sp_to_normal_indices_decode(sp_size, seq_lens_cpu: np.ndarray, padded_sp_len
return req_sp_offset


def _debug_normal_to_sp_indices(mode, sp_size, seq_lens, sp_padded_len):
def _debug_normal_to_sp_indices_decode(sp_size, seq_lens):
"""(Debug only) Indices from normal layout to the SP layout (padded)."""
get_indices_fn = (
get_decode_indices if mode == ForwardMode.DECODE else get_prefill_indices
)
offset = (
0
if mode == ForwardMode.DECODE
else np.concatenate([np.asarray([0], dtype=np.int32), np.cumsum(seq_lens[:-1])])
)
indices = [
get_indices_fn(sp_rank, sp_size, seq_lens, offset) for sp_rank in range(sp_size)
get_decode_indices(sp_rank, sp_size, seq_lens) for sp_rank in range(sp_size)
]
indices = [(np.arange(len(idxs)), idxs) for idxs in indices]
return indices


def _debug_normal_to_sp_indices_prefill(sp_size, seq_lens):
"""(Debug only) Indices from normal layout to the SP layout (padded)."""
indices = []
sp_seq_lens = np.ceil(seq_lens / sp_size).astype(np.int32)
seq_offset = np.concatenate(
[np.asarray([0], dtype=np.int32), np.cumsum(seq_lens[:-1])]
)
sp_seq_offset = np.concatenate(
[np.asarray([0], dtype=np.int32), np.cumsum(sp_seq_lens[:-1])]
)
for sp_rank in range(sp_size):
start_idx = seq_offset + sp_seq_lens * sp_rank
end_idx = np.minimum(seq_offset + sp_seq_lens * (sp_rank + 1), seq_offset + seq_lens)
normal_layout_idx = np.concatenate(
[np.arange(start_idx[i], end_idx[i]) for i in range(len(seq_lens))]
)
if sp_rank == sp_size - 1:
length = end_idx - start_idx
target_layout_idx = np.concatenate(
[
np.arange(sp_seq_offset[i], sp_seq_offset[i] + length[i])
for i in range(len(seq_lens))
]
)
else:
target_layout_idx = np.arange(len(normal_layout_idx))
indices.append((target_layout_idx, normal_layout_idx))
return indices

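For the same toy batch, the prefill debug helper yields one (target, source) pair per rank, and scattering a normal-layout tensor through the pairs reproduces the padded SP layout (standalone sketch; the pairs below are hand-derived for lengths [5, 4] and sp_size = 2):

```python
import numpy as np

normal = np.asarray([10, 11, 12, 13, 14, 20, 21, 22, 23])  # lengths [5, 4]
pairs = [  # (target, source) per rank, hand-derived for this batch
    (np.arange(5), np.asarray([0, 1, 2, 5, 6])),
    (np.asarray([0, 1, 3, 4]), np.asarray([3, 4, 7, 8])),
]
sp_flat = np.zeros((2, 5), dtype=normal.dtype)  # 5 slots per rank
for sp_rank, (tgt, src) in enumerate(pairs):
    sp_flat[sp_rank, tgt] = normal[src]
print(sp_flat)  # [[10 11 12 20 21] [13 14  0 22 23]]; slot (1, 2) is padding
```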

4 changes: 2 additions & 2 deletions python/sglang/srt/models/llama2.py
@@ -158,8 +158,8 @@ def forward(
input_metadata.sp_size, -1, *ori_hidden_states.shape[1:]
)
for sp_rank, idxs in enumerate(input_metadata._debug_normal_to_sp_metadata):
sp_real_vals = output[idxs]
output_sp[sp_rank][: sp_real_vals.shape[0]] = sp_real_vals
tgt_idx, src_idx = idxs
output_sp[sp_rank][tgt_idx] = output[src_idx]
output = output_sp.reshape(ori_hidden_states.shape).contiguous()
return output

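The model-side scatter now writes each rank's real rows to explicit target slots instead of assuming they form a contiguous prefix. A minimal torch sketch of the new assignment (toy shapes; the index pairs are the hand-derived ones from the numpy sketch above, standing in for _debug_normal_to_sp_metadata):

```python
import torch

sp_size, padded_len, hidden = 2, 5, 8
output = torch.randn(9, hidden)  # 9 real tokens in the normal layout
output_sp = output.new_zeros(sp_size, padded_len, hidden)

metadata = [  # (target, source) row indices per rank for lengths [5, 4]
    (torch.arange(5), torch.tensor([0, 1, 2, 5, 6])),
    (torch.tensor([0, 1, 3, 4]), torch.tensor([3, 4, 7, 8])),
]
for sp_rank, (tgt_idx, src_idx) in enumerate(metadata):
    output_sp[sp_rank][tgt_idx] = output[src_idx]  # copy whole hidden rows
# output_sp[1, 2] stays zero: the padding slot of request 0 on the last rank
```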