Skip to content

Commit

Permalink
support qwen2vl model
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Feb 26, 2025
1 parent b4c13ce commit 51f0d5a
Show file tree
Hide file tree
Showing 22 changed files with 793 additions and 103 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,7 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt
*.swp

# ckpt
*.lock
*.lock

# data
*.parquet
83 changes: 83 additions & 0 deletions examples/data_preprocess/geo3k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the Geometry3k dataset to parquet format
"""

import os
import datasets

from verl.utils.hdfs_io import copy, makedirs
import argparse

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='~/data/geo3k')
parser.add_argument('--hdfs_dir', default=None)

args = parser.parse_args()

data_source = 'hiyouga/geometry3k'

dataset = datasets.load_dataset(data_source)

train_dataset = dataset['train']
test_dataset = dataset['test']

instruction_following = r"Please reason step by step, and put your final answer within \boxed{}."

# add a row to each data item that represents a unique id
def make_map_fn(split):

def process_fn(example, idx):
problem = example.pop('problem')
prompt = problem + ' ' + instruction_following
answer = example.pop('answer')
images = example.pop('images')

data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": prompt,
}],
"images": images,
"ability": "math",
"reward_model": {
"style": "rule",
"ground_truth": answer
},
"extra_info": {
'split': split,
'index': idx,
'answer': answer,
"question": problem,
}
}
return data

return process_fn

train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

local_dir = args.local_dir
hdfs_dir = args.hdfs_dir

train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
42 changes: 42 additions & 0 deletions examples/grpo_trainer/run_qwen2_5_vl-7b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
set -x

export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/geo3k/train.parquet \
data.val_files=$HOME/data/geo3k/test.parquet \
data.train_batch_size=512 \
data.max_prompt_length=1536 \
data.max_response_length=1536 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen2_5_vl_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
4 changes: 3 additions & 1 deletion scripts/model_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import torch
import argparse
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
from concurrent.futures import ThreadPoolExecutor
from torch.distributed._tensor import DTensor, Shard, Placement

Expand Down Expand Up @@ -140,6 +140,8 @@ def process_one_shard(rank):
auto_model = AutoModelForTokenClassification
elif 'ForCausalLM' in config.architectures[0]:
auto_model = AutoModelForCausalLM
elif 'ForConditionalGeneration' in config.architectures[0]:
auto_model = AutoModelForVision2Seq
else:
raise NotImplementedError(f'Unknown architecture {config["architectures"]}')

Expand Down
16 changes: 16 additions & 0 deletions verl/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@

_REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config}

try:
from transformers import Qwen2VLConfig, Qwen2_5_VLConfig

_REOVEPAD_MODELS.update({'qwen2_vl': Qwen2VLConfig, 'qwen2_5_vl': Qwen2_5_VLConfig})
except ImportError:
pass


def check_model_support_rmpad(model_type: str):
assert isinstance(model_type, str)
Expand All @@ -31,6 +38,15 @@ def check_model_support_rmpad(model_type: str):
f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}."
f"Please set `use_remove_padding=False` in the model config.")

if model_type in ("qwen2_vl", "qwen2_5_vl"): # patch remove padding for qwen2vl mrope
from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2

Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward
Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward
print("Qwen2vl patch applied!")


# Supported models in Megatron-LM
# Architecture -> (module, class).
Expand Down
Loading

0 comments on commit 51f0d5a

Please sign in to comment.