diff --git a/examples/grpo_trainer/run_qwen2-7b_npu.sh b/examples/grpo_trainer/run_qwen2-7b_npu.sh
new file mode 100644
index 00000000..e9a2d8f3
--- /dev/null
+++ b/examples/grpo_trainer/run_qwen2-7b_npu.sh
@@ -0,0 +1,40 @@
+set -x
+
+
+python3 -m verl.trainer.main_ppo \
+    algorithm.adv_estimator=grpo \
+    data.train_files=$HOME/data/gsm8k/train.parquet \
+    data.val_files=$HOME/data/gsm8k/test.parquet \
+    data.train_batch_size=32 \
+    data.val_batch_size=1312 \
+    data.max_prompt_length=64 \
+    data.max_response_length=128 \
+    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=False \
+    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \
+    actor_rollout_ref.actor.use_kl_loss=True \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.rollout.n=5 \
+    actor_rollout_ref.rollout.enable_chunked_prefill=False \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name='verl_grpo_example_gsm8k' \
+    trainer.experiment_name='qwen2_7b_function_rm' \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=5 \
+    trainer.total_epochs=15 $@
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 3a447e05..8f154335 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,6 @@ dependencies = [
   "ray>=2.10",
   "tensordict<0.6",
   "transformers",
-  "vllm<=0.6.3",
   'wandb',
 ]
 
diff --git a/requirements-npu.txt b/requirements-npu.txt
new file mode 100644
index 00000000..0ad7f301
--- /dev/null
+++ b/requirements-npu.txt
@@ -0,0 +1,18 @@
+# requirements-npu.txt records the full set of dependencies for development on Ascend NPU
+accelerate
+codetiming
+datasets
+dill
+hydra-core
+numpy
+pandas
+peft
+pyarrow>=15.0.0
+pybind11
+pylatexenc
+ray
+tensordict<0.6
+transformers
+wandb
+vllm
+vllm-ascend
diff --git a/verl/bert_padding.py b/verl/bert_padding.py
new file mode 100644
index 00000000..d7584beb
--- /dev/null
+++ b/verl/bert_padding.py
@@ -0,0 +1,220 @@
+# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+class IndexFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = other_shape.numel()
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        # return input[indices]
+        return torch.gather(
+            rearrange(input, "b ... 
-> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros( + first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. + # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. 
+ max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): + """ + Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). + The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). + + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: + ``` + [ + [2, 3, 0, 0, 0, 0], + [3, 2, 0, 0, 0, 0], + [6, 0, 0, 0, 0, 0] + ] + ``` + , which refers to the 3D-attention mask: + ``` + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1] + ] + ] + ```. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. 
+ max_seqlen_in_batch: int + """ + length = attention_mask_in_length.sum(dim=-1) + seqlen = attention_mask_in_length.size(-1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), + seqlen) < length.unsqueeze( + 1) + real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() + seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] + indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) \ No newline at end of file diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 4763a20d..d2c81eff 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -22,6 +22,7 @@ from ray.experimental.state.api import get_actor from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker +from verl.utils.device import is_cuda_available __all__ = ['Worker'] @@ -68,9 +69,10 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None): pg_name_prefix = name if name else \ f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" # print(f"pg_name_prefix = {pg_name_prefix}") + device_name = "GPU" if is_cuda_available else "NPU" pg_scheme = [[{ "CPU": self.max_collocate_count, - "GPU": 1 + device_name: 1 } if self.use_gpu else { "CPU": self.max_collocate_count } for _ in range(process_count)] for process_count in self._store] @@ -160,8 +162,10 @@ def __call__(self, } options.update(self._options) - if use_gpu: + if use_gpu and is_cuda_available: options["num_gpus"] = num_gpus + if use_gpu and not is_cuda_available: + options["resources"] = {"NPU": num_gpus} if len(self._additional_resource) > 1: for k, v in self._additional_resource.items(): @@ -379,7 +383,7 @@ def world_size(self): def _bind_workers_method_to_parent(cls, key, user_defined_cls): """ - Binds the methods of each worker to the WorkerDict. + Binds the methods of each worker to the WorkerDict. 
Note that we only bind public methods that are decorated by register """ for method_name in dir(user_defined_cls): @@ -419,7 +423,7 @@ def _unwrap_ray_remote(cls): def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): """ - This function should return a class instance that delegates the calls to every + This function should return a class instance that delegates the calls to every cls in cls_dict """ cls_dict = {} diff --git a/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py index a3042cab..2fb9f893 100644 --- a/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py +++ b/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py @@ -19,6 +19,7 @@ from torch.distributed._tensor import DTensor from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import is_pp_missing_parameter +from verl.utils.device import get_device_name def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: @@ -365,7 +366,7 @@ def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): weight_loader(actor_weights, vllm_model) # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu # after init, and we need this after sync model weights for in first iter. - vllm_model = vllm_model.cuda() + vllm_model = vllm_model.to(get_device_name()) def _get_model_weight_loader(arch: str): diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index b715c8cd..036dc593 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -35,13 +35,14 @@ from verl.utils.torch_functional import get_cosine_schedule_with_warmup from tensordict import TensorDict from torch.utils.data import DataLoader, DistributedSampler -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.bert_padding import pad_input, unpad_input, rearrange, index_first_axis from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager from verl.utils.dataset import SFTDataset from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.tracking import Tracking from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group +from verl.utils.device import get_device_name, is_cuda_available from torch.distributed.device_mesh import DeviceMesh import verl.utils.hdfs_io as hdfs_io @@ -106,6 +107,7 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM # TODO: add checkpoint manager if self.device_mesh.get_rank() == 0: print(self.config) + self.device = get_device_name() def _normalize_config_bsz(self): dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) @@ -210,7 +212,8 @@ def _build_model_optimizer(self): self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path, config=config, torch_dtype=torch.float32, - attn_implementation='flash_attention_2', + attn_implementation='flash_attention_2' + if is_cuda_available else 'eager', trust_remote_code=trust_remote_code) # Apply Liger kernel if use_liger is enabled @@ -257,7 +260,8 @@ def _build_model_optimizer(self): mixed_precision=mixed_precision, device_mesh=self.device_mesh, sync_module_states=True, - device_id=torch.cuda.current_device(), + device_id=torch.cuda.current_device() if is_cuda_available else + torch.npu.current_device(), 
cpu_offload=cpu_offload, use_orig_params=False) @@ -289,16 +293,16 @@ def _compute_loss_and_backward(self, batch, do_backward=True): use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 # Move inputs to GPU and prepare loss mask - input_ids = batch['input_ids'].cuda() - attention_mask = batch['attention_mask'].cuda() - position_ids = batch['position_ids'].cuda() - loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda() + input_ids = batch['input_ids'].to(self.device) + attention_mask = batch['attention_mask'].to(self.device) + position_ids = batch['position_ids'].to(self.device) + loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).to(self.device) loss_fct = nn.CrossEntropyLoss(reduction='none') # Context manager for sequence parallel if needed context = self.sharding_manager if use_sp else nullcontext() with context: - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device, dtype=torch.bfloat16): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() @@ -412,7 +416,7 @@ def training_step(self, batch: TensorDict): log_gpu_memory_usage('After offload weights', logger=logger) - step_loss = torch.tensor(step_loss).cuda() + step_loss = torch.tensor(step_loss).to(self.device) torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} @@ -468,7 +472,7 @@ def fit(self): for data in tqdm(self.train_dataloader, total=self.steps_per_epoch, desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"): - data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() + data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device) metric = self.training_step(data) if rank == 0: tracking.log(data=metric, step=global_step) @@ -479,7 +483,7 @@ def fit(self): # Perform final validation val_losses = [] for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device) val_loss = self.validation_step(val_data) val_losses.append(val_loss) if rank == 0: @@ -495,7 +499,7 @@ def fit(self): # validation val_losses = [] for data in self.val_dataloader: - data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device) val_loss = self.validation_step(data) val_losses.append(val_loss) if rank == 0: @@ -518,11 +522,12 @@ def fit(self): @hydra.main(config_path='config', config_name='sft_trainer', version_base=None) def main(config): + device_name = get_device_name() local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) + device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type='cuda', + ulysses_device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=('dp', 'sp')) trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) diff --git a/verl/utils/device.py 
b/verl/utils/device.py new file mode 100644 index 00000000..55344e5d --- /dev/null +++ b/verl/utils/device.py @@ -0,0 +1,75 @@ +# This code is inspired by the torchtune. +# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py + +import os +import logging +from enum import Enum +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + import torch_npu # noqa: F401 + + return torch.npu.is_available() + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = "cuda" + elif is_npu_available: + device = "npu" + else: + device = "cpu" + return device + + +def get_device(device_name: Optional[str] = None) -> torch.device: + """Function that takes an optional device string, verifies it's correct and available given the machine and + distributed settings, and returns a :func:`~torch.device`. If device string is not provided, this function will + infer the device based on the environment. + If CUDA-like is available and being used, this function also sets the CUDA-like device. + Args: + device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu". + Example: + >>> device = get_device("cuda") + >>> device + device(type='cuda', index=0) + Returns: + torch.device: Device + """ + if device_name is None: + device_name = get_device_name() + device = torch.device(device_name) + return device + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning( + f"Device namespace '{device_name}' not found in torch, try to load torch.cuda." + ) + return torch.cuda \ No newline at end of file diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 6fea5a29..85446b93 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -13,16 +13,18 @@ # limitations under the License. 
"""Utilities for distributed training.""" import os +from verl.utils.device import is_cuda_available def initialize_global_process_group(timeout_second=36000): import torch.distributed from datetime import timedelta - torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second)) + torch.distributed.init_process_group('nccl' if is_cuda_available else 'hccl', + timeout=timedelta(seconds=timeout_second)) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) + torch.cuda.set_device(local_rank) if is_cuda_available else torch.npu.set_device(local_rank) return local_rank, rank, world_size diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py index 3c5ac1a9..0b668a03 100644 --- a/verl/utils/flops_counter.py +++ b/verl/utils/flops_counter.py @@ -14,6 +14,7 @@ import torch from transformers import PretrainedConfig, Qwen2Config, LlamaConfig +from verl.utils.device import is_cuda_available VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig) @@ -30,7 +31,7 @@ def unit_convert(number, level): ptr += 1 return number - device_name = torch.cuda.get_device_name() + device_name = torch.cuda.get_device_name() if is_cuda_available else torch.npu.get_device_name() flops = float("inf") # INF flops for unkown gpu type if "H100" in device_name or "H800" in device_name: flops = 989e12 diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 26b7dbd5..7120418d 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -26,12 +26,14 @@ import torch import torch.nn as nn import torch.distributed as dist +from verl.utils.device import is_cuda_available def init_fn(x: torch.nn.Module): if not torch.distributed.get_rank() == 0: - x = x.to_empty(device=torch.cuda.current_device(), recurse=False) - torch.cuda.empty_cache() + x = x.to_empty(device=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device() if is_cuda_available else torch.npu.current_device(), + recurse=False) + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache() return x @@ -49,7 +51,7 @@ def get_init_weight_context_manager(use_meta_tensor=True): # Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py def get_fsdp_wrap_policy(module, config=None, is_lora=False): """Get FSDP wrap policy for the module. 
- + Args: module: The module to get wrap policy for config: Configuration for wrap policy @@ -118,14 +120,14 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): continue flat_param = handle.flat_param assert flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() and \ - id(flat_param.data) != id(flat_param._local_shard) and \ - flat_param.data.size() == flat_param._local_shard.size() + id(flat_param.data) != id(flat_param._local_shard) and \ + flat_param.data.size() == flat_param._local_shard.size() handle.flat_param_to(torch.device("cpu"), non_blocking=True) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data assert id(flat_param._local_shard) != id(flat_param.data) if empty_cache: - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache() @torch.no_grad() @@ -134,7 +136,7 @@ def load_fsdp_model_to_gpu(model: FSDP): # lazy init FSDP model _lazy_init(model, model) assert model._is_root, f"Only support root model loading to GPU" - device_id = torch.cuda.current_device() + device_id = torch.cuda.current_device() if is_cuda_available else torch.npu.current_device() for handle in model._all_handles: if handle._offload_params: continue @@ -241,7 +243,7 @@ def parallel_load_safetensors(filepath): ckpt_chunks = [ckpt_chunks[rank * size:rank * size + size] for rank in range(world_size)] shard_states = {} - device = torch.cuda.current_device() + device = torch.cuda.current_device() if is_cuda_available else torch.npu.current_device() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: @@ -280,7 +282,7 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" - device = torch.cuda.current_device() + device = torch.cuda.current_device() if is_cuda_available else torch.npu.current_device() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer @@ -335,4 +337,4 @@ def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): # if len(shard_states) == 0: print("clear") return sub_mod - return init_fn + return init_fn \ No newline at end of file diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 4db326a1..b64e2903 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -30,8 +30,9 @@ from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F +from verl.utils.device import get_device_name -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.bert_padding import pad_input, unpad_input, rearrange, index_first_axis __all__ = ['DataParallelPPOActor'] @@ -54,6 +55,7 @@ def __init__( self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) + self.device = get_device_name() def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -62,7 +64,7 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, log_probs: # (bs, response_len) """ response_length = 
micro_batch['responses'].size(-1) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device, dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch_size, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] @@ -230,7 +232,7 @@ def update_policy(self, data: DataProto): self.actor_optimizer.zero_grad() for data in micro_batches: - data = data.cuda() # actor device is cpu when using offload + data = data.to(self.device) # actor device is cpu when using offload responses = data['responses'] response_length = responses.size(1) attention_mask = data['attention_mask'] diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index f2eb44c2..88b6bf1b 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -30,8 +30,9 @@ from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx +from verl.utils.device import get_device_name -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.bert_padding import pad_input, unpad_input, rearrange, index_first_axis __all__ = ['DataParallelPPOCritic'] @@ -46,10 +47,11 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt print(f'Critic use_remove_padding={self.use_remove_padding}') self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.device = get_device_name() def _forward_micro_batch(self, micro_batch): response_length = micro_batch['responses'].size(-1) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device, dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] @@ -164,7 +166,7 @@ def update_critic(self, data: DataProto): self.critic_optimizer.zero_grad() for data in micro_batches: - data = data.cuda() # critic device is cpu when using offload + data = data.to(self.device) # critic device is cpu when using offload input_ids = data['input_ids'] responses = data['responses'] attention_mask = data['attention_mask'] diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 1660bd06..f579310c 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -38,21 +38,24 @@ from verl.utils.flops_counter import FlopsCounter from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager +from verl.utils.device import get_device_name, is_cuda_available from codetiming import Timer logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +DEVICE = get_device_name() + def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + device_mesh = init_device_mesh(DEVICE, mesh_shape=(world_size,), mesh_dim_names=['fsdp']) else: raise ValueError( 'HSDP is not supported yet because it produces incorrect results for now. 
Please set fsdp_size=-1') assert world_size % fsdp_size == 0 - device_mesh = init_device_mesh('cuda', + device_mesh = init_device_mesh(DEVICE, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=['ddp', 'fsdp']) return device_mesh @@ -80,7 +83,7 @@ def __init__(self, config: DictConfig, role: str): self.config = config import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -92,7 +95,7 @@ def __init__(self, config: DictConfig, role: str): self.ulysses_sequence_parallel_size = self.config.actor.get('ulysses_sequence_parallel_size', 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', + self.ulysses_device_mesh = init_device_mesh(DEVICE, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=['dp', 'sp']) @@ -106,9 +109,11 @@ def __init__(self, config: DictConfig, role: str): self._is_ref = self.role in ['ref', 'actor_rollout_ref'] self._is_offload_param = False + self._is_offload_optimizer = False if self._is_actor: self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False) + self._is_offload_grad = self.config.actor.fsdp_config.get('grad_offload', False) self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False) elif self._is_ref: # TODO: it seems that manual offload is slowly than FSDP offload @@ -148,7 +153,8 @@ def _build_model_optimizer(self, from verl.utils.model import print_model_size, update_model_config, get_generation_config from verl.utils.torch_dtypes import PrecisionType from transformers import AutoModelForCausalLM, AutoConfig - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \ + CPUOffload from torch import optim assert role in ['actor', 'ref'] @@ -197,7 +203,8 @@ def _build_model_optimizer(self, actor_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=actor_model_config, - attn_implementation='flash_attention_2', + attn_implementation='flash_attention_2' + if is_cuda_available else 'eager', trust_remote_code=trust_remote_code) # Apply Liger kernel to the model if use_liger is set to True if use_liger: @@ -250,7 +257,7 @@ def _build_model_optimizer(self, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, @@ -289,7 +296,7 @@ def _build_rollout(self): infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}' - rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp']) + rollout_device_mesh = init_device_mesh(DEVICE, mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp']) if self.config.rollout.name == 'hf': from verl.workers.rollout import HFRollout @@ -397,19 
+404,20 @@ def init_model(self): lr_scheduler=self.actor_lr_scheduler, tokenizer=self.tokenizer) - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): - data = data.to('cuda') + data = data.to(DEVICE) assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device() + if is_cuda_available else torch.npu.current_device()) - data.batch = data.batch.cuda() + data.batch = data.batch.to(DEVICE) log_gpu_memory_usage('Before update policy', logger=logger) @@ -439,18 +447,18 @@ def update_actor(self, data: DataProto): offload_fsdp_model_to_cpu(self.actor_module_fsdp) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): - prompts = prompts.to('cuda') + prompts = prompts.to(DEVICE) assert self._is_rollout if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) - prompts.batch = prompts.batch.cuda() + prompts.batch = prompts.batch.to(DEVICE) meta_info = { 'eos_token_id': self.generation_config.eos_token_id @@ -480,7 +488,7 @@ def generate_sequences(self, prompts: DataProto): output = output.to('cpu') # clear kv cache - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() log_gpu_memory_usage('After recompute log prob', logger=logger) return output @@ -489,7 +497,7 @@ def compute_log_prob(self, data: DataProto): assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) - data = data.to('cuda') + data = data.to(DEVICE) # we should always recompute old_log_probs when it is HybridEngine data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -511,10 +519,11 @@ def compute_log_prob(self, data: DataProto): self.actor.actor_module._handle.reshard(True) if self._is_offload_param: + # NOTE(sgm): the grad is already in CPU, only offload param here offload_fsdp_model_to_cpu(self.actor_module_fsdp) # clear kv cache - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() log_gpu_memory_usage('After compute_log_prob', logger=logger) return output @@ -522,7 +531,7 @@ def compute_log_prob(self, data: DataProto): def compute_ref_log_prob(self, data: DataProto): assert self._is_ref - data = data.to('cuda') + data = data.to(DEVICE) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info['micro_batch_size'] = micro_batch_size @@ -542,7 +551,7 @@ def compute_ref_log_prob(self, data: DataProto): if self.world_size > 1: self.ref_policy.actor_module._handle.reshard(True) - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) @@ -579,7 +588,7 @@ def __init__(self, config): super().__init__() import torch.distributed if not torch.distributed.is_initialized(): - 
torch.distributed.init_process_group(backend="nccl")
+            torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl")
         self.config = config
 
         # build device mesh for Ulysses Sequence Parallel
@@ -593,7 +602,7 @@ def __init__(self, config):
         self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
         dp = world_size // self.ulysses_sequence_parallel_size
         if self.ulysses_sequence_parallel_size > 1:
-            self.ulysses_device_mesh = init_device_mesh('cuda',
+            self.ulysses_device_mesh = init_device_mesh(DEVICE,
                                                         mesh_shape=(dp, self.ulysses_sequence_parallel_size),
                                                         mesh_dim_names=['dp', 'sp'])
 
@@ -601,6 +610,7 @@ def __init__(self, config):
 
         # set FSDP offload params
         self._is_offload_param = self.config.model.fsdp_config.param_offload
+        self._is_offload_grad = self.config.model.fsdp_config.grad_offload
         self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
 
         # normalize config
@@ -666,7 +676,8 @@ def _build_critic_model_optimizer(self, config):
             critic_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                             torch_dtype=torch_dtype,
                                                                             config=critic_model_config,
-                                                                            attn_implementation='flash_attention_2',
+                                                                            attn_implementation='flash_attention_2'
+                                                                            if is_cuda_available else 'eager',
                                                                             trust_remote_code=trust_remote_code)
 
             # some parameters may not in torch_dtype
@@ -704,7 +715,8 @@ def _build_critic_model_optimizer(self, config):
                              param_init_fn=init_fn,
                              use_orig_params=False,
                              auto_wrap_policy=auto_wrap_policy,
-                             device_id=torch.cuda.current_device(),
+                             device_id=torch.cuda.current_device()
+                             if is_cuda_available else torch.npu.current_device(),
                              sharding_strategy=sharding_strategy,
                              mixed_precision=mixed_precision,
                              sync_module_states=True,
@@ -755,11 +767,11 @@ def init_model(self):
                                                         lr_scheduler=self.critic_lr_scheduler,
                                                         tokenizer=self.tokenizer)
 
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_values(self, data: DataProto):
-        data = data.to('cuda')
+        data = data.to(DEVICE)
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.critic_module)
@@ -776,16 +788,18 @@ def compute_values(self, data: DataProto):
 
         output = output.to('cpu')
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.critic_module)
+        torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()
         return output
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_critic(self, data: DataProto):
-        data = data.to('cuda')
+        data = data.to(DEVICE)
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.critic_module)
         if self._is_offload_optimizer:
-            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device())
+            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()
+                                if is_cuda_available else torch.npu.current_device())
 
         # perform forward computation
         with self.ulysses_sharding_manager:
@@ -810,7 +824,7 @@ def update_critic(self, data: DataProto):
             offload_fsdp_model_to_cpu(self.critic_module)
         if self._is_offload_optimizer:
             offload_fsdp_optimizer(optimizer=self.critic_optimizer)
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()
 
         output = output.to('cpu')
         return output
@@ -839,7 +853,7 @@ def load_checkpoint(self, path, del_local_after_load=True):
 
         torch.distributed.barrier()
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.critic_module)
 
 
 # TODO(sgm): we may need to extract it to dp_reward_model.py
@@ -852,7 +866,7 @@ def __init__(self, config):
         super().__init__()
         import torch.distributed
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group(backend="nccl")
+            torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl")
         self.config = config
 
         # build device mesh for Ulysses Sequence Parallel
@@ -866,7 +880,7 @@ def __init__(self, config):
         self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
         dp = world_size // self.ulysses_sequence_parallel_size
         if self.ulysses_sequence_parallel_size > 1:
-            self.ulysses_device_mesh = init_device_mesh('cuda',
+            self.ulysses_device_mesh = init_device_mesh(DEVICE,
                                                         mesh_shape=(dp, self.ulysses_sequence_parallel_size),
                                                         mesh_dim_names=['dp', 'sp'])
 
@@ -918,7 +932,8 @@ def _build_model(self, config):
             reward_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                             config=model_config,
                                                                             torch_dtype=torch.bfloat16,
-                                                                            attn_implementation='flash_attention_2',
+                                                                            attn_implementation='flash_attention_2'
+                                                                            if is_cuda_available else 'eager',
                                                                             trust_remote_code=trust_remote_code)
             reward_module.to(torch.bfloat16)
         auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)
@@ -931,7 +946,7 @@ def _build_model(self, config):
                             param_init_fn=init_fn,
                             use_orig_params=False,
                             auto_wrap_policy=auto_wrap_policy,
-                            device_id=torch.cuda.current_device(),
+                            device_id=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device(),
                             sharding_strategy=sharding_strategy,  # zero3
                             sync_module_states=True,
                             cpu_offload=CPUOffload(offload_params=True),
@@ -945,13 +960,13 @@ def init_model(self):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get('external_lib', None))
         self.reward_module = self._build_model(config=self.config)
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()
 
     def _forward_micro_batch(self, micro_batch):
-        from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange
+        from verl.bert_padding import pad_input, unpad_input, index_first_axis, rearrange
         from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
 
-        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+        with torch.no_grad(), torch.autocast(device_type=DEVICE, dtype=torch.bfloat16):
             input_ids = micro_batch['input_ids']
             batch_size, seqlen = input_ids.shape
             attention_mask = micro_batch['attention_mask']
@@ -969,8 +984,8 @@ def _forward_micro_batch(self, micro_batch):
                 # pad and slice the inputs if sp > 1
                 if self.ulysses_sequence_parallel_size > 1:
                     input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \
-                                                                                                position_ids_rmpad, \
-                                                                                                sp_size=self.ulysses_sequence_parallel_size)
+                                                                                                 position_ids_rmpad, \
+                                                                                                 sp_size=self.ulysses_sequence_parallel_size)
 
                 # only pass input_ids and position_ids to enable flash_attn_varlen
                 output = self.reward_module(input_ids=input_ids_rmpad,
@@ -1077,11 +1092,11 @@ def _switch_chat_template(self, data: DataProto):
     def compute_rm_score(self, data: DataProto):
         import itertools
         from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
-        data = data.to('cuda')
+        data = data.to(DEVICE)
         if 
self._do_switch_chat_template: rm_data = self._switch_chat_template(data) - rm_data.batch = rm_data.batch.cuda() + rm_data.batch = rm_data.batch.to(DEVICE) # perform forward computation with self.ulysses_sharding_manager: @@ -1116,5 +1131,5 @@ def compute_rm_score(self, data: DataProto): self.reward_module._handle.reshard(True) output = output.to('cpu') - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() return output diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index bcee3544..88bfc174 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -38,6 +38,7 @@ from vllm.distributed import parallel_state as vllm_ps from vllm import LLM, SamplingParams from verl.third_party.vllm import vllm_version +from verl.utils.device import is_cuda_available # TODO # 1. support pp in vllm @@ -91,7 +92,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf self.inference_engine = LLM( model=model_path, - enable_sleep_mode=True, + enable_sleep_mode=True if is_cuda_available else False, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend="external_launcher", dtype=config.dtype, @@ -106,7 +107,8 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf ) # Offload vllm model to reduce peak memory usage - self.inference_engine.sleep(level=1) + if is_cuda_available: + self.inference_engine.sleep(level=1) kwargs = dict( n=1, diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index c79d3031..31092980 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -25,6 +25,7 @@ from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors) from verl.utils.debug import log_gpu_memory_usage from verl.third_party.vllm import vllm_version +from verl.utils.device import is_cuda_available from .base import BaseShardingManager @@ -57,13 +58,15 @@ def __init__(self, state_dict_config=ShardedStateDictConfig()) # Note that torch_random_states may be different on each dp rank - self.torch_random_states = torch.cuda.get_rng_state() + self.torch_random_states = torch.cuda.get_rng_state() if is_cuda_available else torch.npu.get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh['dp'].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + torch.cuda.manual_seed(gen_dp_rank + 1000) if is_cuda_available else \ + torch.npu.manual_seed(gen_dp_rank + 1000)# make sure all tp ranks have the same random states + self.gen_random_states = torch.cuda.get_rng_state() if is_cuda_available else torch.npu.get_rng_state() + torch.cuda.set_rng_state(self.torch_random_states) if is_cuda_available else \ + torch.npu.set_rng_state(self.torch_random_states) else: self.gen_random_states = None @@ -76,7 +79,8 @@ def __enter__(self): if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): self.inference_engine.sync_model_weights(params, load_format=load_format) else: - self.inference_engine.wake_up() + if is_cuda_available: + self.inference_engine.wake_up() # TODO(ZSL): deal with 'hf' format if load_format == 'dtensor': from verl.third_party.vllm 
import load_dtensor_weights @@ -87,26 +91,27 @@ def __enter__(self): log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger) del params - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger) # TODO: offload FSDP model weights # self.module.cpu() - # torch.cuda.empty_cache() + # torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() # if torch.distributed.get_rank() == 0: # print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: - self.torch_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.gen_random_states) + self.torch_random_states = torch.cuda.get_rng_state() if is_cuda_available else torch.npu.get_rng_state() + torch.cuda.set_rng_state(self.gen_random_states) if is_cuda_available else \ + torch.npu.set_rng_state(self.gen_random_states) def __exit__(self, exc_type, exc_value, traceback): log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger) # TODO(ZSL): check this if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): self.inference_engine.offload_model_weights() - else: + elif is_cuda_available: self.inference_engine.sleep(level=1) log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger) @@ -117,12 +122,13 @@ def __exit__(self, exc_type, exc_value, traceback): self.module.train() # add empty cache after each compute - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() # restore random states if self.device_mesh is not None: - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + self.gen_random_states = torch.cuda.get_rng_state() if is_cuda_available else torch.npu.get_rng_state() + torch.cuda.set_rng_state(self.torch_random_states) if is_cuda_available else \ + torch.npu.set_rng_state(self.torch_random_states) def preprocess_data(self, data: DataProto) -> DataProto: # TODO: Current impl doesn't consider FSDP with torch micro-dp @@ -156,4 +162,4 @@ def postprocess_data(self, data: DataProto) -> DataProto: # TODO: shall we build a micro_dp group for vllm when integrating with vLLM? local_prompts = data.chunk(chunks=tp_size) data = local_prompts[dp_rank % tp_size] - return data + return data \ No newline at end of file
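
The following is a minimal usage sketch (not part of the patch) of the helpers introduced in `verl/utils/device.py`. It assumes a CUDA or NPU accelerator is present and that `torch_npu` is installed on Ascend machines, and shows how the per-call `... if is_cuda_available else torch.npu....` branches used throughout the patch map onto `get_device_name()` / `get_torch_device()`.

```python
import os

import torch

from verl.utils.device import get_device_name, get_torch_device

device_name = get_device_name()    # "cuda", "npu", or "cpu"
torch_device = get_torch_device()  # torch.cuda, torch.npu, or torch.cuda as a fallback

if device_name != "cpu":
    # Device-agnostic equivalents of the branches added in this patch, e.g.
    # `torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()`.
    torch_device.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    torch_device.empty_cache()
    torch_device.manual_seed(1000)
    rng_state = torch_device.get_rng_state()

# Plain device strings keep tensor movement and autocast backend-neutral.
x = torch.zeros(4, 4).to(device_name)
with torch.autocast(device_type=device_name, dtype=torch.bfloat16):
    y = x + 1
```

Routing these calls through `get_torch_device()` would also allow the repeated inline conditionals in `fsdp_utils.py`, `fsdp_workers.py`, and `fsdp_vllm.py` to collapse into single calls.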
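A sketch of the device-agnostic initialization pattern that `verl/utils/distributed.py` and `fsdp_sft_trainer.py` converge on above: both the collective backend and the device-mesh type are derived from the detected device instead of being hard-coded to `nccl`/`cuda`. It assumes the script is launched with `torchrun` (so `WORLD_SIZE` and friends are set) and that `torch_npu` provides the `hccl` backend on Ascend.

```python
import os
from datetime import timedelta

import torch.distributed
from torch.distributed.device_mesh import init_device_mesh

from verl.utils.device import get_device_name, is_cuda_available

torch.distributed.init_process_group('nccl' if is_cuda_available else 'hccl',
                                     timeout=timedelta(seconds=36000))

device_name = get_device_name()  # 'cuda' on GPU nodes, 'npu' on Ascend nodes
world_size = int(os.environ["WORLD_SIZE"])

# Same mesh layout as the SFT trainer: a 1-D FSDP mesh over all ranks.
fsdp_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=('fsdp',))
```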
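The changes to `verl/single_controller/ray/base.py` lean on Ray's custom-resource mechanism, since Ray only tracks GPUs natively. The sketch below is illustrative rather than taken from the patch: on CUDA a worker requests `num_gpus`, while on NPU it requests an `"NPU"` custom resource that the Ray node must declare at startup.

```python
import ray

# A production NPU node would be started with: ray start --resources='{"NPU": 8}'
ray.init(resources={"NPU": 8})


@ray.remote
class EchoWorker:
    def ping(self):
        return "ok"


use_gpu, num_gpus, is_cuda_available = True, 1, False  # pretend this is an Ascend machine

options = {"num_cpus": 1}
if use_gpu and is_cuda_available:
    options["num_gpus"] = num_gpus            # native GPU accounting
if use_gpu and not is_cuda_available:
    options["resources"] = {"NPU": num_gpus}  # custom resource, matches the node's declaration

worker = EchoWorker.options(**options).remote()
print(ray.get(worker.ping.remote()))
```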
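Finally, a small round-trip example for the vendored `verl/bert_padding.py`, which this patch imports in place of `flash_attn.bert_padding` because flash-attn is not available on NPU. Shapes follow the docstrings above; the tensor values are made up for illustration.

```python
import torch

from verl.bert_padding import pad_input, unpad_input

batch, seqlen, hidden = 2, 5, 4
hidden_states = torch.randn(batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Drop padded positions: (batch, seqlen, hidden) -> (total_nnz, hidden).
unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)
assert unpadded.shape == (int(attention_mask.sum()), hidden)
assert cu_seqlens.tolist() == [0, 3, 8]  # per-sample boundaries inside the packed tensor

# ... run the dense, padding-free computation here ...

# Restore the padded (batch, seqlen, hidden) layout; padded slots come back as zeros.
repadded = pad_input(unpadded, indices, batch, seqlen)
assert repadded.shape == (batch, seqlen, hidden)
```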