Merge branch 'main' into min_p
hnyls2002 authored Aug 21, 2024
2 parents 9df818f + d6aeb9f commit 99a6013
Showing 24 changed files with 695 additions and 195 deletions.
15 changes: 7 additions & 8 deletions .github/pull_request_template.md
@@ -1,16 +1,15 @@
<!-- Thank you for your contribution, we really appreciate it. The following instructions will help improve your pull request and make it easier to receive feedback. If there are any items you don't understand, don't worry. Just submit the pull request and ask the maintainers for help. -->
<!-- Thank you for your contribution! We appreciate it. The following guidelines will help improve your pull request and facilitate feedback. If anything is unclear, don't hesitate to submit your pull request and ask the maintainers for assistance. -->

## Motivation

<!-- Please explain the motivation behind this PR and the goal you aim to achieve with it. -->
<!-- Explain the purpose of this PR and the goals it aims to achieve. -->

## Modification
## Modifications

<!-- Briefly describe the changes made in this PR. -->
<!-- Describe the changes made in this PR. -->

## Checklist

- [ ] Before submitting a PR for review, make sure it has passed verification in your local development environment **at least**.
- [ ] Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues.
- [ ] Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
- [ ] Modify documentation as needed, such as docstrings or example tutorials.
- [ ] Format your code according to the [Contributor Guide](https://github.com/sgl-project/sglang/blob/main/docs/en/contributor_guide.md).
- [ ] Add unit tests as outlined in the [Contributor Guide](https://github.com/sgl-project/sglang/blob/main/docs/en/contributor_guide.md).
- [ ] Update documentation as needed, including docstrings or example tutorials.
6 changes: 4 additions & 2 deletions README.md
@@ -81,14 +81,17 @@ docker run --gpus all \

### Method 4: Using docker compose

<details>
> This method is recommended if you plan to serve it as a service.
> A better approach is to use the [k8s-sglang-service.yaml](./docker/k8s-sglang-service.yaml).
1. Copy the [compose.yml](./docker/compose.yaml) to your local machine
2. Execute the command `docker compose up -d` in your terminal.
</details>

### Method 5: Run on Kubernetes or Clouds with SkyPilot

<details>
To deploy on Kubernetes or 12+ clouds, you can use [SkyPilot](https://github.com/skypilot-org/skypilot).

1. Install SkyPilot and set up Kubernetes cluster or cloud access: see [SkyPilot's documentation](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
@@ -113,7 +116,6 @@ run: |
--host 0.0.0.0 \
--port 30000
```
</details>
```bash
@@ -124,7 +126,7 @@ HF_TOKEN=<secret> sky launch -c sglang --env HF_TOKEN sglang.yaml
sky status --endpoint 30000 sglang
```
3. To further scale up your deployment with autoscaling and failure recovery, check out the [SkyServe + SGLang guide](https://github.com/skypilot-org/skypilot/tree/master/llm/sglang#serving-llama-2-with-sglang-for-more-traffic-using-skyserve).

</details>


### Common Notes
15 changes: 15 additions & 0 deletions python/sglang/check_env.py
@@ -170,6 +170,17 @@ def get_gpu_topology():
return None


def get_hypervisor_vendor():
try:
output = subprocess.check_output(["lscpu"], text=True)
for line in output.split("\n"):
if "Hypervisor vendor:" in line:
return line.split(":")[1].strip()
return None
except:
return None


def check_env():
"""
Check and print environment information.
@@ -184,6 +195,10 @@ def check_env():
if gpu_topo:
env_info["NVIDIA Topology"] = gpu_topo

hypervisor_vendor = get_hypervisor_vendor()
if hypervisor_vendor:
env_info["Hypervisor vendor"] = hypervisor_vendor

ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
env_info["ulimit soft"] = ulimit_soft

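For reference, `get_hypervisor_vendor` keys off the `Hypervisor vendor:` row that `lscpu` prints on virtualized hosts and omits on bare metal. A minimal standalone sketch of the same parsing, using a made-up `lscpu` excerpt instead of a live subprocess call:

```python
# Illustrative only: hypothetical lscpu output for a KVM guest.
sample_lscpu = """\
Architecture:        x86_64
Hypervisor vendor:   KVM
Virtualization type: full
"""

def parse_hypervisor_vendor(output: str):
    # Same logic as get_hypervisor_vendor above, minus the subprocess call.
    for line in output.split("\n"):
        if "Hypervisor vendor:" in line:
            return line.split(":")[1].strip()
    return None

print(parse_hypervisor_vendor(sample_lscpu))  # -> "KVM"
```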
9 changes: 8 additions & 1 deletion python/sglang/launch_server.py
@@ -1,14 +1,21 @@
"""Launch the inference server."""

import argparse
import os

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process

if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)

launch_server(server_args)
try:
launch_server(server_args)
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
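The `finally` block delegates cleanup to `kill_child_process` from `sglang.srt.utils`, whose body is not part of this diff. As a rough mental model only, and not the actual sglang implementation, such a helper typically walks the process tree with `psutil`:

```python
# Hypothetical sketch of a child-process cleanup helper; the real function may differ.
import psutil

def kill_child_process_sketch(pid: int, including_parent: bool = True) -> None:
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    # Terminate all descendants first, then optionally the parent itself.
    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    if including_parent:
        parent.kill()
```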
29 changes: 17 additions & 12 deletions python/sglang/srt/hf_transformers_utils.py
@@ -147,13 +147,12 @@ def get_tokenizer(
and kwargs.get("use_fast", True)
and tokenizer_name != _FAST_LLAMA_TOKENIZER
):
pass
# warnings.warn(
# "For some LLaMA V1 models, initializing the fast tokenizer may "
# "take a long time. To reduce the initialization time, consider "
# f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
# "tokenizer."
# )
warnings.warn(
"For some LLaMA V1 models, initializing the fast tokenizer may "
"take a long time. To reduce the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer."
)
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
@@ -234,6 +233,8 @@ def __init__(self, tokenizer_path):
}
assert tok_dict["word_split"] == "V1"

default_allowed_special = None

kwargs = {
"name": name,
"pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -247,14 +248,18 @@
for bytes_list in tok_dict["default_allowed_special"]
]
)
else:
default_allowed_special = None
if "vocab_size" in tok_dict:
kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"

DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}

tokenizer = tiktoken.Encoding(**kwargs)
tokenizer._default_allowed_special = default_allowed_special or set()
tokenizer._default_allowed_special |= {"<|separator|>"}
tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

def encode_patched(
self,
@@ -271,14 +276,14 @@ def encode_patched(
self,
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
disallowed_special=(),
)

tokenizer.encode = functools.partial(encode_patched, tokenizer)

# Convert to HF interface
self.tokenizer = tokenizer
self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
self.eos_token_id = tokenizer._special_tokens[EOS]
self.vocab_size = tokenizer.n_vocab
self.chat_template = Template(
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
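The patched `encode` above hard-codes `disallowed_special=()`, so input text that happens to contain special-token strings is tokenized as plain text instead of raising. The effect can be reproduced with a public tiktoken encoding (shown with `cl100k_base` purely for illustration; the repo loads its own tokenizer file):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

# Default behaviour (disallowed_special="all"): special-token text raises ValueError.
# enc.encode("hello <|endoftext|>")

# With disallowed_special=(), the same text is encoded as ordinary characters.
ids = enc.encode("hello <|endoftext|>", disallowed_special=())
print(enc.decode(ids))  # "hello <|endoftext|>"
```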
7 changes: 7 additions & 0 deletions python/sglang/srt/layers/activation.py
@@ -20,11 +20,18 @@


class SiluAndMul(CustomOp):
def __init__(self, **kwargs):
super().__init__()
self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
if self.is_lower_sm80:
return self.forward_native(x)

d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
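`forward_native` implements the standard gated-SiLU activation: the last dimension is split in half, SiLU is applied to the first half, and the result gates the second half. A small CPU-only reproduction of that math (after this change the fused CUDA kernel is only used on SM80+ devices, with pre-SM80 GPUs falling back to this path):

```python
import torch
import torch.nn.functional as F

def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    # Mirrors SiluAndMul.forward_native: SiLU on the first half of the last
    # dimension, multiplied elementwise by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(2, 8)           # last dimension must be even
out = silu_and_mul_native(x)    # shape: (2, 4)
print(out.shape)
```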
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/layernorm.py
@@ -32,12 +32,15 @@ def __init__(
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8

def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.is_lower_sm80:
return self.forward_native(x, residual)

if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
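`forward_native` itself is not shown in this hunk; presumably it is the un-fused RMSNorm math that the `fused_add_rmsnorm` path above replaces on SM80+ GPUs. A sketch of that reference computation, under that assumption:

```python
import torch

def rmsnorm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6,
                   residual: torch.Tensor = None):
    # Optional residual add, then RMS normalization scaled by the learned weight.
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    out = x * weight
    return out if residual is None else (out, residual)
```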
2 changes: 0 additions & 2 deletions python/sglang/srt/managers/controller_multi.py
@@ -212,6 +212,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
finally:
for w in controller.workers:
os.kill(w.proc.pid, 9)
kill_parent_process()
2 changes: 0 additions & 2 deletions python/sglang/srt/managers/controller_single.py
@@ -167,6 +167,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
for t in controller.tp_procs:
os.kill(t.pid, 9)
kill_parent_process()
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -28,6 +28,7 @@
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -84,6 +85,10 @@ async def handle_loop(self):
)
continue

if isinstance(recv_obj, UpdateWeightReqOutput):
self.send_to_tokenizer.send_pyobj(recv_obj)
continue

assert isinstance(recv_obj, BatchTokenIDOut)
bs = len(recv_obj.rids)

14 changes: 14 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -278,6 +278,20 @@ class FlushCacheReq:
pass


@dataclass
class UpdateWeightReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
load_format: Optional[str] = None


@dataclass
class UpdateWeightReqOutput:
success: bool
message: str


@dataclass
class AbortReq:
# The request id
20 changes: 5 additions & 15 deletions python/sglang/srt/managers/schedule_batch.py
@@ -16,7 +16,6 @@
"""Meta data for requests and batches"""

import logging
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union

@@ -275,7 +274,7 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state):

if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
warnings.warn(
logger.warning(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False
@@ -764,7 +763,7 @@ def merge(self, other: "ScheduleBatch"):
)
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
def sample(self, logits: torch.Tensor):
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
# Post process logits
logits = logits.contiguous()
@@ -809,7 +808,7 @@ def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
)

if not torch.all(success):
warnings.warn("Sampling failed, fallback to top_k=1 strategy")
logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
@@ -826,16 +825,6 @@

self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

if is_multi_node_tp:
# If the tensor parallelism spans across multiple nodes, there is some indeterminism
# that can cause the TP workers to generate different tokens, so we need to
# sync here
torch.distributed.all_reduce(
batch_next_token_ids,
op=dist.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group,
)

return batch_next_token_ids


@@ -858,7 +847,8 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError:
except RuntimeError as e:
logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros(
(probs_sort.shape[0],), dtype=torch.int32, device=probs.device
)
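Since this branch adds min-p sampling, a brief reminder of the filtering rule behind functions like `top_k_top_p_min_p_sampling_from_probs_torch`: tokens whose probability falls below `min_p` times the row's maximum probability are masked out before sampling. An illustrative torch-only sketch (not the exact routine above, which also applies top-k/top-p on sorted probabilities):

```python
import torch

def min_p_filter(probs: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # probs: (batch, vocab) normalized probabilities; min_p: (batch,) in [0, 1].
    max_prob = probs.max(dim=-1, keepdim=True).values
    mask = probs < min_p.unsqueeze(-1) * max_prob
    filtered = probs.masked_fill(mask, 0.0)
    return filtered / filtered.sum(dim=-1, keepdim=True)

probs = torch.softmax(torch.randn(2, 10), dim=-1)
min_p = torch.tensor([0.10, 0.05])
next_ids = torch.multinomial(min_p_filter(probs, min_p), num_samples=1)
```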