From e13e67d11a8b31de48bccb2af9dba1b89026c977 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 16 Feb 2025 01:46:45 +0800 Subject: [PATCH 1/5] refactor patcher --- swift/llm/__init__.py | 21 ++++--------- swift/llm/dataset/__init__.py | 3 +- swift/llm/model/patcher.py | 52 +++++++++++++++++++++++++++++++-- swift/llm/model/register.py | 4 +-- swift/llm/train/__init__.py | 1 - swift/llm/train/patcher.py | 55 ----------------------------------- swift/utils/env.py | 5 +--- 7 files changed, 59 insertions(+), 82 deletions(-) delete mode 100644 swift/llm/train/patcher.py diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py index c8f093c8e0..b0d10dafc5 100644 --- a/swift/llm/__init__.py +++ b/swift/llm/__init__.py @@ -25,7 +25,7 @@ from .dataset import (AlpacaPreprocessor, ResponsePreprocessor, MessagesPreprocessor, AutoPreprocessor, DATASET_MAPPING, MediaResource, register_dataset, register_dataset_info, EncodePreprocessor, LazyLLMDataset, ConstantLengthDataset, load_dataset, DATASET_TYPE, sample_dataset, - RowPreprocessor, DatasetMeta) + RowPreprocessor, DatasetMeta, HfDataset, SubsetDataset) from .utils import (deep_getattr, to_device, History, Messages, history_to_messages, messages_to_history, Processor, save_checkpoint, ProcessorMixin, get_temporary_cache_files_directory, get_cache_dir) from .base import SwiftPipeline @@ -59,21 +59,10 @@ 'load_by_unsloth', 'git_clone_github', 'get_matched_model_meta' ], 'dataset': [ - 'AlpacaPreprocessor', - 'MessagesPreprocessor', - 'DATASET_MAPPING', - 'MediaResource', - 'register_dataset', - 'register_dataset_info', - 'EncodePreprocessor', - 'LazyLLMDataset', - 'ConstantLengthDataset', - 'load_dataset', - 'DATASET_TYPE', - 'sample_dataset', - 'RowPreprocessor', - 'ResponsePreprocessor', - 'DatasetMeta', + 'AlpacaPreprocessor', 'MessagesPreprocessor', 'DATASET_MAPPING', 'MediaResource', 'register_dataset', + 'register_dataset_info', 'EncodePreprocessor', 'LazyLLMDataset', 'ConstantLengthDataset', 'load_dataset', + 'DATASET_TYPE', 'sample_dataset', 'RowPreprocessor', 'ResponsePreprocessor', 'DatasetMeta', 'HfDataset', + 'SubsetDataset' ], 'utils': [ 'deep_getattr', 'to_device', 'History', 'Messages', 'history_to_messages', 'messages_to_history', diff --git a/swift/llm/dataset/__init__.py b/swift/llm/dataset/__init__.py index ce32d93910..617948c953 100644 --- a/swift/llm/dataset/__init__.py +++ b/swift/llm/dataset/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import datasets.fingerprint +from datasets import Dataset as HfDataset from datasets import disable_caching from swift.utils.torch_utils import _find_local_mac @@ -9,7 +10,7 @@ from .media import MediaResource from .preprocessor import (AlpacaPreprocessor, AutoPreprocessor, MessagesPreprocessor, ResponsePreprocessor, RowPreprocessor) -from .register import DATASET_MAPPING, DatasetMeta, register_dataset, register_dataset_info +from .register import DATASET_MAPPING, DatasetMeta, SubsetDataset, register_dataset, register_dataset_info from .utils import (ConstantLengthDataset, EncodePreprocessor, GetLengthPreprocessor, LazyLLMDataset, PackingPreprocessor, sample_dataset) diff --git a/swift/llm/model/patcher.py b/swift/llm/model/patcher.py index be84a3d067..7c723a7d4b 100644 --- a/swift/llm/model/patcher.py +++ b/swift/llm/model/patcher.py @@ -2,18 +2,22 @@ from contextlib import contextmanager from functools import wraps from types import MethodType -from typing import List +from typing import Dict, List, Optional, Union +import accelerate import torch import torch.nn as nn import torch.nn.functional as F +import transformers from accelerate.utils import find_device from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers import PreTrainedModel +from torch.nn.parallel import DistributedDataParallel as DDP +from transformers import PreTrainedModel, trainer from transformers.modeling_outputs import SequenceClassifierOutputWithPast from swift.llm import to_device -from swift.utils import get_logger +from swift.utils import get_dist_setting, get_logger, is_mp_ddp, use_torchacc +from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count from .model_arch import get_model_arch from .utils import HfConfigFactory @@ -234,3 +238,45 @@ def _new_from_pretrained(cls, *args, **kwargs): yield finally: PreTrainedModel.from_pretrained = classmethod(from_pretrained) + + +_mp_ddp_patched = False + + +def patch_mp_ddp(): + """Patch ddp with device_map. + After patching, the ddp can run with the device_map. + This should be called before any training starts. + """ + global _mp_ddp_patched + if is_mp_ddp() and not _mp_ddp_patched: + _mp_ddp_patched = True + from accelerate.utils.modeling import get_balanced_memory, infer_auto_device_map + + @wraps(infer_auto_device_map) + def _infer_auto_device_map_patch(model: nn.Module, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + **kwargs) -> Dict[str, Union[int, str, torch.device]]: + """The auxiliary function for supports MP + DDP. Monkey Patching. + add feat in accelerate to support MP + DDP""" + verbose = kwargs.pop('verbose', False) + n_gpu = get_device_count() + _, local_rank, _, local_world_size = get_dist_setting() + device_ids = list(range(local_rank, n_gpu, local_world_size)) + max_memory = _get_max_memory(device_ids) + max_memory = _sync_max_memory(max_memory) + max_memory = get_balanced_memory(model, max_memory, low_zero=False, **kwargs) + max_memory = {k: v for k, v in max_memory.items() if v > 0} + return infer_auto_device_map(model, max_memory, verbose=verbose, **kwargs) + + _old_ddp_init = DDP.__init__ + accelerate.accelerator.torch.nn.parallel.DistributedDataParallel.__init__ = ( + lambda self, model, device_ids, output_device, *args, **kwargs: _old_ddp_init(self, model, *args, **kwargs)) + transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: None + transformers.modeling_utils.infer_auto_device_map = _infer_auto_device_map_patch + + if is_mp_ddp() or use_torchacc(): + _old_accelerator_init = trainer.Accelerator.__init__ + trainer.Accelerator.__init__ = (lambda self, device_placement=False, *args, **kwargs: _old_accelerator_init( + self, device_placement=device_placement, *args, **kwargs)) + trainer.Accelerator.verify_device_map = lambda *args, **kwargs: False diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py index 9d7e0d8b41..bc9734fac8 100644 --- a/swift/llm/model/register.py +++ b/swift/llm/model/register.py @@ -21,7 +21,7 @@ from swift.utils import get_dist_setting, get_logger, is_mp, is_unsloth_available, patch_getattr, use_torchacc from .constant import ModelType -from .patcher import patch_automodel_for_awq, patch_automodel_for_sequence_classification +from .patcher import patch_automodel_for_awq, patch_automodel_for_sequence_classification, patch_mp_ddp from .utils import AttnImpl, HfConfigFactory, ModelInfo, safe_snapshot_download GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]] @@ -468,7 +468,7 @@ def get_model_tokenizer( If set to None : It will be automatically selected between sdpa and eager. download_model: Whether to download the model weights. If `None`, it will be selected based on load_model. """ - + patch_mp_ddp() if model_kwargs is None: model_kwargs = {} if download_model is None: diff --git a/swift/llm/train/__init__.py b/swift/llm/train/__init__.py index 1d379ddf38..24b51f5444 100644 --- a/swift/llm/train/__init__.py +++ b/swift/llm/train/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . import patcher from .pt import SwiftPt, pt_main from .rlhf import SwiftRLHF, rlhf_main from .sft import SwiftSft, sft_main diff --git a/swift/llm/train/patcher.py b/swift/llm/train/patcher.py deleted file mode 100644 index 09ef711566..0000000000 --- a/swift/llm/train/patcher.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from functools import wraps -from typing import Dict, Optional, Union - -import accelerate -import torch -import transformers -from torch.nn import Module -from torch.nn.parallel import DistributedDataParallel as DDP -from transformers import trainer - -from swift.utils import get_dist_setting, get_logger, is_mp_ddp, use_torchacc -from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count - -logger = get_logger() - - -def patch_mp_ddp(): - """Patch ddp with device_map. - After patching, the ddp can run with the device_map. - This should be called before any training starts. - """ - if is_mp_ddp(): - from accelerate.utils.modeling import get_balanced_memory, infer_auto_device_map - - @wraps(infer_auto_device_map) - def _infer_auto_device_map_patch(model: Module, - max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, - **kwargs) -> Dict[str, Union[int, str, torch.device]]: - """The auxiliary function for supports MP + DDP. Monkey Patching. - add feat in accelerate to support MP + DDP""" - verbose = kwargs.pop('verbose', False) - n_gpu = get_device_count() - _, local_rank, _, local_world_size = get_dist_setting() - device_ids = list(range(local_rank, n_gpu, local_world_size)) - max_memory = _get_max_memory(device_ids) - max_memory = _sync_max_memory(max_memory) - max_memory = get_balanced_memory(model, max_memory, low_zero=False, **kwargs) - max_memory = {k: v for k, v in max_memory.items() if v > 0} - return infer_auto_device_map(model, max_memory, verbose=verbose, **kwargs) - - _old_ddp_init = DDP.__init__ - accelerate.accelerator.torch.nn.parallel.DistributedDataParallel.__init__ = ( - lambda self, model, device_ids, output_device, *args, **kwargs: _old_ddp_init(self, model, *args, **kwargs)) - transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: None - transformers.modeling_utils.infer_auto_device_map = _infer_auto_device_map_patch - - if is_mp_ddp() or use_torchacc(): - _old_accelerator_init = trainer.Accelerator.__init__ - trainer.Accelerator.__init__ = (lambda self, device_placement=False, *args, **kwargs: _old_accelerator_init( - self, device_placement=device_placement, *args, **kwargs)) - trainer.Accelerator.verify_device_map = lambda *args, **kwargs: False - - -patch_mp_ddp() diff --git a/swift/utils/env.py b/swift/utils/env.py index 6d655a5505..a0d4629247 100644 --- a/swift/utils/env.py +++ b/swift/utils/env.py @@ -71,10 +71,7 @@ def is_mp() -> bool: def is_mp_ddp() -> bool: # patch_mp_ddp will occur when `import swift`. - from swift.utils import get_device_count - n_gpu = get_device_count() - local_world_size = get_dist_setting()[3] - if is_dist() and n_gpu != local_world_size + 1 and is_mp(): # fix grpo + if is_dist() and is_mp(): # fix grpo logger.info('Using MP + DDP(device_map)') return True return False From a981242e7dcd4ea58030761789d7c9c4ae201916 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 16 Feb 2025 01:48:37 +0800 Subject: [PATCH 2/5] update --- swift/utils/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/utils/env.py b/swift/utils/env.py index a0d4629247..5c243a7a4e 100644 --- a/swift/utils/env.py +++ b/swift/utils/env.py @@ -71,7 +71,7 @@ def is_mp() -> bool: def is_mp_ddp() -> bool: # patch_mp_ddp will occur when `import swift`. - if is_dist() and is_mp(): # fix grpo + if is_dist() and is_mp(): logger.info('Using MP + DDP(device_map)') return True return False From da84f163c969331b5550d444cde090a820f0a8b6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 16 Feb 2025 02:08:47 +0800 Subject: [PATCH 3/5] update --- examples/train/multi-gpu/ddp/train.sh | 13 ++++++++++--- .../train/multi-gpu/ddp_device_map/train.sh | 14 ++++++++++---- .../train/multi-gpu/deepspeed/train_zero2.sh | 9 ++++++++- .../train/multi-gpu/deepspeed/train_zero3.sh | 13 ++++++++++--- examples/train/multi-gpu/fsdp_qlora/train.sh | 17 ++++++++++++----- 5 files changed, 50 insertions(+), 16 deletions(-) diff --git a/examples/train/multi-gpu/ddp/train.sh b/examples/train/multi-gpu/ddp/train.sh index 6ce56701b5..3acc260c48 100644 --- a/examples/train/multi-gpu/ddp/train.sh +++ b/examples/train/multi-gpu/ddp/train.sh @@ -10,14 +10,21 @@ swift sft \ --dataset 'swift/self-cognition#1000' \ --num_train_epochs 1 \ --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ --lora_rank 8 \ --lora_alpha 32 \ - --learning_rate 1e-4 \ + --target_modules all-linear \ --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \ --eval_steps 100 \ --save_steps 100 \ --save_total_limit 2 \ --logging_steps 5 \ - --gradient_checkpointing_kwargs '{"use_reentrant": false}' \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ --model_author swift \ - --model_name swift-robot + --model_name swift-robot \ + --gradient_checkpointing_kwargs '{"use_reentrant": false}' diff --git a/examples/train/multi-gpu/ddp_device_map/train.sh b/examples/train/multi-gpu/ddp_device_map/train.sh index 3949ae766d..e0bc2bd299 100644 --- a/examples/train/multi-gpu/ddp_device_map/train.sh +++ b/examples/train/multi-gpu/ddp_device_map/train.sh @@ -10,15 +10,21 @@ swift sft \ --torch_dtype bfloat16 \ --num_train_epochs 1 \ --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ --lora_rank 8 \ --lora_alpha 32 \ - --weight_decay 0.1 \ - --learning_rate 1e-4 \ + --target_modules all-linear \ --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \ --eval_steps 100 \ --save_steps 100 \ --save_total_limit 2 \ --logging_steps 5 \ - --gradient_checkpointing_kwargs '{"use_reentrant": false}' \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ --model_author swift \ - --model_name swift-robot + --model_name swift-robot \ + --gradient_checkpointing_kwargs '{"use_reentrant": false}' diff --git a/examples/train/multi-gpu/deepspeed/train_zero2.sh b/examples/train/multi-gpu/deepspeed/train_zero2.sh index 61b92e6fd3..deaea2afa0 100644 --- a/examples/train/multi-gpu/deepspeed/train_zero2.sh +++ b/examples/train/multi-gpu/deepspeed/train_zero2.sh @@ -10,14 +10,21 @@ swift sft \ --torch_dtype bfloat16 \ --num_train_epochs 1 \ --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ --lora_rank 8 \ --lora_alpha 32 \ - --learning_rate 1e-4 \ + --target_modules all-linear \ --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \ --eval_steps 100 \ --save_steps 100 \ --save_total_limit 2 \ --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ --model_author swift \ --model_name swift-robot \ --deepspeed zero2 diff --git a/examples/train/multi-gpu/deepspeed/train_zero3.sh b/examples/train/multi-gpu/deepspeed/train_zero3.sh index 5bed97bf52..04a32a9da7 100644 --- a/examples/train/multi-gpu/deepspeed/train_zero3.sh +++ b/examples/train/multi-gpu/deepspeed/train_zero3.sh @@ -7,17 +7,24 @@ swift sft \ --model Qwen/Qwen2.5-7B-Instruct \ --train_type lora \ --dataset 'swift/self-cognition#1000' \ + --torch_dtype bfloat16 \ --num_train_epochs 1 \ --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ --lora_rank 8 \ --lora_alpha 32 \ - --learning_rate 1e-4 \ + --target_modules all-linear \ --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \ --eval_steps 100 \ --save_steps 100 \ --save_total_limit 2 \ --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ --model_author swift \ --model_name swift-robot \ - --deepspeed zero3 \ - --max_length 1024 + --deepspeed zero3 diff --git a/examples/train/multi-gpu/fsdp_qlora/train.sh b/examples/train/multi-gpu/fsdp_qlora/train.sh index 8b10a78b51..bf2609dc9b 100644 --- a/examples/train/multi-gpu/fsdp_qlora/train.sh +++ b/examples/train/multi-gpu/fsdp_qlora/train.sh @@ -7,21 +7,28 @@ accelerate launch --config_file "./examples/train/fsdp_qlora/fsdp_offload.json" --model Qwen/Qwen2.5-7B-Instruct \ --train_type lora \ --dataset 'swift/self-cognition#1000' \ + --torch_dtype bfloat16 \ --num_train_epochs 1 \ --per_device_train_batch_size 1 \ - --max_length 2048 \ + --per_device_eval_batch_size 1 \ --quant_bits 4 \ --bnb_4bit_compute_dtype bfloat16 \ --bnb_4bit_quant_storage bfloat16 \ + --learning_rate 1e-4 \ --lora_rank 8 \ --lora_alpha 32 \ --gradient_checkpointing true \ --weight_decay 0.1 \ - --learning_rate 1e-4 \ + --target_modules all-linear \ --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \ - --eval_steps 50 \ - --save_steps 50 \ + --eval_steps 100 \ + --save_steps 100 \ --save_total_limit 2 \ - --logging_steps 10 \ + --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ --model_author swift \ --model_name swift-robot From 5bd8bdf1b8b40dee3052e70209a6b52555632b9b Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 16 Feb 2025 03:48:22 +0800 Subject: [PATCH 4/5] compat ppo --- swift/llm/train/rlhf.py | 2 +- swift/trainers/rlhf_trainer/ppo_trainer.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index d7e0b15869..3cf000689f 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -54,7 +54,7 @@ def _prepare_model_tokenizer(self): self.train_msg['value_model_parameter_info'] = model_parameter_info logger.info(f'value_model_parameter_info: {model_parameter_info}') setattr(self, f'{origin_key}_model', model) - if origin_key == 'reward': + if origin_key == 'reward' and args.rlhf_type == 'grpo': reward_template = self.args.get_template(processor) if reward_template.use_model: reward_template.model = model diff --git a/swift/trainers/rlhf_trainer/ppo_trainer.py b/swift/trainers/rlhf_trainer/ppo_trainer.py index 0b194dc506..5fc20c882b 100644 --- a/swift/trainers/rlhf_trainer/ppo_trainer.py +++ b/swift/trainers/rlhf_trainer/ppo_trainer.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import inspect from contextlib import contextmanager import transformers @@ -39,8 +40,16 @@ def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, * for k, v in kwargs.items() if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset'] } - ppo_trainer_init( - self, config=kwargs['args'], tokenizer=self.tokenizer, model=model, ref_model=ref_model, **new_kwargs) + parameters = inspect.signature(ppo_trainer_init).parameters + if 'config' in parameters: + new_kwargs['config'] = kwargs['args'] + else: + new_kwargs['args'] = kwargs['args'] + if 'processing_class' in parameters: + new_kwargs['processing_class'] = self.tokenizer + else: + new_kwargs['tokenizer'] = self.tokenizer + ppo_trainer_init(self, model=model, ref_model=ref_model, **new_kwargs) unwrap_model = self.accelerator.unwrap_model(self.model) patch_getattr(unwrap_model.__class__, 'policy') From 2771a3993686e8006a56cea5d403f99e9f14f0a8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 16 Feb 2025 11:46:11 +0800 Subject: [PATCH 5/5] support qwen2.5-vl awq --- ...213\345\222\214\346\225\260\346\215\256\351\233\206.md" | 3 +++ .../source_en/Instruction/Supported-models-and-datasets.md | 3 +++ swift/llm/model/model/qwen.py | 7 ++++++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index 57c4383680..1e4ebc85a6 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -529,6 +529,9 @@ |[Qwen/Qwen2.5-VL-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)| |[Qwen/Qwen2.5-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)| |[Qwen/Qwen2.5-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-72B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct)| +|[Qwen/Qwen2.5-VL-3B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct-AWQ)| +|[Qwen/Qwen2.5-VL-7B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct-AWQ)| +|[Qwen/Qwen2.5-VL-72B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-72B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct-AWQ)| |[Qwen/Qwen2-Audio-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-Audio-7B-Instruct)|qwen2_audio|qwen2_audio|transformers>=4.45, librosa|audio|[Qwen/Qwen2-Audio-7B-Instruct](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)| |[Qwen/Qwen2-Audio-7B](https://modelscope.cn/models/Qwen/Qwen2-Audio-7B)|qwen2_audio|qwen2_audio|transformers>=4.45, librosa|audio|[Qwen/Qwen2-Audio-7B](https://huggingface.co/Qwen/Qwen2-Audio-7B)| |[Qwen/QVQ-72B-Preview](https://modelscope.cn/models/Qwen/QVQ-72B-Preview)|qvq|qvq|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/QVQ-72B-Preview](https://huggingface.co/Qwen/QVQ-72B-Preview)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 4ea7573759..ca518a186c 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -529,6 +529,9 @@ The table below introduces the models integrated with ms-swift: |[Qwen/Qwen2.5-VL-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)| |[Qwen/Qwen2.5-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)| |[Qwen/Qwen2.5-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-72B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct)| +|[Qwen/Qwen2.5-VL-3B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct-AWQ)| +|[Qwen/Qwen2.5-VL-7B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct-AWQ)| +|[Qwen/Qwen2.5-VL-72B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-72B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/Qwen2.5-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct-AWQ)| |[Qwen/Qwen2-Audio-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-Audio-7B-Instruct)|qwen2_audio|qwen2_audio|transformers>=4.45, librosa|audio|[Qwen/Qwen2-Audio-7B-Instruct](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)| |[Qwen/Qwen2-Audio-7B](https://modelscope.cn/models/Qwen/Qwen2-Audio-7B)|qwen2_audio|qwen2_audio|transformers>=4.45, librosa|audio|[Qwen/Qwen2-Audio-7B](https://huggingface.co/Qwen/Qwen2-Audio-7B)| |[Qwen/QVQ-72B-Preview](https://modelscope.cn/models/Qwen/QVQ-72B-Preview)|qvq|qvq|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|vision, video|[Qwen/QVQ-72B-Preview](https://huggingface.co/Qwen/QVQ-72B-Preview)| diff --git a/swift/llm/model/model/qwen.py b/swift/llm/model/model/qwen.py index 0ef1c8bbbb..1f3ad7cb43 100644 --- a/swift/llm/model/model/qwen.py +++ b/swift/llm/model/model/qwen.py @@ -592,7 +592,12 @@ def get_model_tokenizer_qwen2_5_vl(*args, **kwargs): Model('Qwen/Qwen2.5-VL-3B-Instruct', 'Qwen/Qwen2.5-VL-3B-Instruct'), Model('Qwen/Qwen2.5-VL-7B-Instruct', 'Qwen/Qwen2.5-VL-7B-Instruct'), Model('Qwen/Qwen2.5-VL-72B-Instruct', 'Qwen/Qwen2.5-VL-72B-Instruct'), - ]) + ]), + ModelGroup([ + Model('Qwen/Qwen2.5-VL-3B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-3B-Instruct-AWQ'), + Model('Qwen/Qwen2.5-VL-7B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-7B-Instruct-AWQ'), + Model('Qwen/Qwen2.5-VL-72B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-72B-Instruct-AWQ'), + ]), ], TemplateType.qwen2_5_vl, get_model_tokenizer_qwen2_5_vl,