[Feature] several features for veRL integration #2736

Open
2 tasks done
PeterSH6 opened this issue Jan 5, 2025 · 3 comments
Labels: enhancement (New feature or request)
PeterSH6 commented Jan 5, 2025

Motivation

TL;DR: This issue proposes several features that would be beneficial for integrating SGLang into veRL and may also benefit other post-training frameworks.

Provide an inference script that is started by torchrun (supporting SPMD)

Currently, the offline inference script is launched via sgl.Engine, which internally spawns multiple Schedulers.
With torchrun, each Scheduler is launched by torchrun and the tp_rank can be obtained from the environment.
In veRL, the Data Parallel dimension is managed by our WorkerGroup, so the dp_rank of each Scheduler should be None.
More specifically, if the current WorkerGroup has 8 GPUs and we set the rollout TP size to 2, all the GPUs in this WorkerGroup will build the distributed world, and the generation engine and training engine will each construct their own TP/PP groups. veRL's data_protocol will partition and dispatch the prompts to each TP/PP group without the generation engine being aware of the DP dimension.
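
For illustration only (not part of the proposal), here is a minimal sketch of how a torchrun-launched worker could derive its tp_rank from the environment; the rank-to-TP-group mapping below assumes a plain TP-only layout:

import os

# torchrun sets these environment variables for every spawned process
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])

tp_size = 2                  # rollout TP size chosen by the WorkerGroup
tp_rank = rank % tp_size     # this process's rank inside its TP group
# dp_rank is intentionally not derived here: veRL's WorkerGroup owns the DP
# dimension, so from SGLang's point of view it stays None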

A general picture of a torchrun script that can simulate the HybridEngine behavior:

# assumed imports (not in the original sketch); helpers such as
# initialize_global_process_group, actor_model, and SGLEngine come from veRL / the proposed API
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, ShardedStateDictConfig

# build distributed world
local_rank, rank, world_size = initialize_global_process_group()

# build device mesh for the training engine
device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
fsdp_model = FSDP(actor_model,
                  ...,
                  device_mesh=device_mesh)
FSDP.set_state_dict_type(fsdp_model,
                         state_dict_type=StateDictType.SHARDED_STATE_DICT,
                         state_dict_config=ShardedStateDictConfig())
# get the sharded model state dict
state_dict = fsdp_model.state_dict()

# [Optional] build device mesh for the inference engine
gen_device_mesh = init_device_mesh('cuda', mesh_shape=(2, 4), mesh_dim_names=['dp', 'tp'])
# build the inference engine
inference_engine = SGLEngine(model_hf_config=actor_model_config,
                             tensor_parallel_size=tensor_model_parallel_size,
                             pipeline_parallel_size=pipeline_parallel_size,  # if any
                             enforce_eager=False,  # use CUDA Graph with offloaded KV cache and weights
                             dtype='bfloat16',
                             load_format='dummy_dtensor',  # initialize dummy weights
                             gpu_memory_utilization=0.1,
                             trust_remote_code=True)

# [Optional] update parallel state in SGLang for 3D-HybridEngine
inference_engine.update_parallel_state(TP=gen_device_mesh["tp"])

# sync weights between actor and rollout; supports several formats: DTensor and Megatron (sharded)
inference_engine.sync_model_weights(actor_weights=state_dict, load_format='dtensor')

# generate sequences; it would be better if the output were a list of Tensors, not a list of list[str]
outputs = inference_engine.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)

# offload the KV cache after generation
inference_engine.free_kvcache()  # inference_engine.init_kvcache()

# offload the model
inference_engine.offload_model_weights()  # inference_engine.load_model_weights(); we can simply re-init them

Expose an API that can load weights in TP/PP format

inference_engine.sync_model_weights(actor_weights=state_dict, load_format='dtensor') in the code above.
We may need two different load formats with different weight loaders:

  • dtensor: The SGLang model weights are sharded and our state_dict is sharded in a different way, so we gather it layer by layer and feed it into the SGLang weight loader for synchronization (see the sketch below).
  • megatron sharded: The SGLang model weights are sharded and the veRL hybrid engine prepares a state_dict whose sharding is identical to SGLang's, so the SGLang model can directly copy the weights in place without any further resharding.
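
A minimal sketch of the dtensor path, assuming a recent PyTorch where DTensor.full_tensor() is available and a hypothetical per-tensor update entry point on the engine (update_weights is an assumed name, not an existing SGLang API):

import torch
from torch.distributed._tensor import DTensor

def sync_model_weights_dtensor(inference_engine, actor_weights):
    for name, param in actor_weights.items():
        # gather this layer's DTensor shards into a full tensor; doing it
        # layer by layer keeps peak memory bounded by one parameter at a time
        tensor = param.full_tensor() if isinstance(param, DTensor) else param
        # hypothetical per-tensor loader on the engine
        inference_engine.update_weights([(name, tensor)])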

Expose an API that can free/re-init kv cache, and offload/load model weights

inference_engine.free_kvcache() and inference_engine.init_kvcache() ; inference_engine.offload_model_weights() and inference_engine.load_model_weights()
It would be better to keep CUDA Graph support even when the KV cache and model weights are offloaded. Reference: #2542
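
To make the intended lifecycle concrete, here is a hedged sketch of how veRL could alternate rollout and training with these proposed calls (train_one_step and num_steps are placeholders, not veRL or SGLang APIs):

for step in range(num_steps):
    # rollout phase: bring the inference engine back up
    inference_engine.load_model_weights()
    inference_engine.init_kvcache()
    inference_engine.sync_model_weights(actor_weights=fsdp_model.state_dict(), load_format='dtensor')
    outputs = inference_engine.generate(prompt_token_ids=idx_list, sampling_params=sampling_params)

    # training phase: release inference memory so FSDP can use the GPUs
    inference_engine.free_kvcache()
    inference_engine.offload_model_weights()
    train_one_step(fsdp_model, outputs)  # placeholder for the actor update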

Disable detokenization during generation.

In RL training, we only need token_ids in most scenarios, and we can perform batch detokenization when we actually need the text. We don't care about the ITL (inter-token latency) metric.
After detokenization is disabled, we can check whether there are further opportunities to improve throughput.
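
As a rough illustration (the skip_detokenize flag and the output_ids field name are assumptions, not existing SGLang options), the consumer side could batch-detokenize only when text is needed:

# generation returns token ids only; 'skip_detokenize' is a hypothetical flag
outputs = inference_engine.generate(prompt_token_ids=idx_list,
                                    sampling_params=sampling_params,
                                    skip_detokenize=True)

# batch detokenization in veRL, only when text is actually needed (e.g. logging)
texts = hf_tokenizer.batch_decode([out["output_ids"] for out in outputs],
                                  skip_special_tokens=True)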

3D-HybridEngine parallel state construction (TP/PP group generation logic should be different from Megatron-LM when using 3D-HybridEngine)

With our 3D-HybridEngine design (in the paper and code), the TP/PP grouping strategy used by SGLang must take the training framework's TP/PP sizes into account, and it differs from Megatron-LM's default grouping.
However, we think SGLang itself does not necessarily need to be aware of the training framework's TP/PP sizes.
So, we can build the TP/PP groups for SGLang before SGLang initialization and then pass these TP/PP groups to the SGLEngine; see the [Optional] lines in the code above.
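
A minimal sketch of that pre-construction, assuming an 8-GPU WorkerGroup with 2-way DP and 4-way TP and a recent PyTorch with named device-mesh slicing (update_parallel_state is the proposed API from the code above, not an existing SGLang call):

from torch.distributed.device_mesh import init_device_mesh

# the training side decides the layout: 2 DP replicas x 4-way TP on 8 GPUs
gen_device_mesh = init_device_mesh('cuda', mesh_shape=(2, 4), mesh_dim_names=['dp', 'tp'])

tp_mesh = gen_device_mesh['tp']   # this rank's TP sub-mesh
tp_group = tp_mesh.get_group()    # the corresponding torch.distributed ProcessGroup

# hand the pre-built TP group/mesh to SGLang before it initializes its own parallel state
inference_engine.update_parallel_state(TP=tp_mesh)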

Output post-process to torch.Tensor (token_ids).

A small feature; if it is not supported, we can implement the post-processing in veRL. No worries.
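
If this stays on the veRL side, here is a small sketch of the fallback, assuming each output exposes its generated token ids (the output_ids field name and the pad_token_id variable are assumptions):

import torch
from torch.nn.utils.rnn import pad_sequence

# pad the variable-length responses into one [batch, max_len] tensor of token ids
response_ids = [torch.tensor(out["output_ids"]) for out in outputs]
response_tensor = pad_sequence(response_ids, batch_first=True, padding_value=pad_token_id)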

Related resources

No response

fzyzcjy (Collaborator) commented Jan 6, 2025

Quick update: A POC that can run with TP=4 on 4 GPU cards.

The code is super hacky - I will rigorously refactor it in SGLang later.

Code: https://github.com/fzyzcjy/sglang/tree/feat/add_verl, more specifically https://github.com/fzyzcjy/sglang/blob/feat/add_verl/examples/runtime/engine/offline_batch_inference_torchrun.py

Experiment: run Llama 70B on 4 GPUs. (If the code were buggy such that TP was not actually enabled, we would see an OOM.)

Output:

hf_tokenizer.decode(out.decode_ids[0])=' sunny day and I like the sun. It seems as if everything in my life'

fzyzcjy (Collaborator) commented Jan 6, 2025

Quick update:

Refactor in progress: #2747

Question: Does the following API look good?

As we know, users originally create one sgl.Engine(). Now, users create multiple sgl.EngineFragment() (name to be determined), one instance per TP rank.
After that, users should use EngineFragment in the same way as Engine, e.g. call generate() and get exactly the same output type and content.
In other words, the only things that change are the class name and a few extra arguments (e.g. tp_rank).

An alternative API would allow exposing different output types or doing hacky conversions, for example by directly exposing the Scheduler class with some kind of thin wrapper. That would be faster to implement but, in my humble opinion, a bit uglier, so I personally prefer the proposal above.
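
A hedged usage sketch of the proposal, assuming one torchrun process per TP rank (the EngineFragment name and its tp_size/tp_rank arguments are the tentative API from this comment, not a released interface):

import os
import sglang as sgl

# one process per TP rank, launched by torchrun
tp_rank = int(os.environ["RANK"])
model_path = "path/to/model"  # placeholder

fragment = sgl.EngineFragment(model_path=model_path,  # same arguments as sgl.Engine ...
                              tp_size=4,              # ... plus the fragment-specific ones
                              tp_rank=tp_rank)

# used exactly like sgl.Engine: same output type and content
outputs = fragment.generate(prompt_token_ids=idx_list, sampling_params=sampling_params)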

fzyzcjy (Collaborator) commented Jan 6, 2025

Quick update: PR to SGLang that seems to support TP.

Code: https://github.com/fzyzcjy/sglang/tree/feat/process_coordinator, more specifically https://github.com/fzyzcjy/sglang/blob/feat/process_coordinator/examples/runtime/engine/offline_batch_inference_torchrun.py

Draft PR: #2749

It seems to work now, but I will need to do more checks later. Llama 70B on 4xH100 outputs something seemingly reasonable:

[screenshot of the output omitted]
