Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

推理响应的数据是一串json数据 #668

Open
hackerhaiJu opened this issue Jan 23, 2025 · 0 comments
Open

推理响应的数据是一串json数据 #668

hackerhaiJu opened this issue Jan 23, 2025 · 0 comments

Comments

@hackerhaiJu
Copy link

import os  # used for the model-path existence check below
import torch
from modelscope import snapshot_download
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# Restrict visible GPUs; must be set before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

def run_model(model_path: str):
    """Run one image+text inference pass with a local Qwen2-VL checkpoint.

    Loads the model and processor from ``model_path``, builds a single-turn
    chat request containing one image and a text instruction, generates up
    to 128 new tokens, and prints the decoded completion.

    Args:
        model_path: Filesystem path to the Qwen2-VL-Instruct checkpoint.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist on disk.
    """
    # Fail fast with a clear message instead of letting from_pretrained error out.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path {model_path} does not exist.")

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_path, device_map="auto", torch_dtype='auto'
    )
    processor = AutoProcessor.from_pretrained(model_path)

    # Single-turn conversation: one local image plus a text instruction
    # (asks the model to extract the form fields as key:value pairs).
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "file:///app/qwen/images/demo3.png",
                },
                {"type": "text", "text": "识别图中的表单信息,并且以key:value的形式返回给我"},
            ],
        }
    ]

    # Render the chat template to a prompt string and collect the vision inputs.
    prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(conversation)

    model_inputs = processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    # Generate, then drop the prompt tokens so only the new completion remains.
    generated_ids = model.generate(**model_inputs, max_new_tokens=128)
    trimmed = [
        full[len(prompt_ids):]
        for prompt_ids, full in zip(model_inputs.input_ids, generated_ids)
    ]
    completion = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(completion)


# Script entry point: run a single inference against the local checkpoint.
if __name__ == '__main__':
    run_model('/app/qwen/Qwen2-VL-7B-Instruct')

为什么推理出的是一串json数据呢?

Image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant