added offline examples
Signed-off-by: Shanshan Wang <[email protected]>
cooleel committed Oct 28, 2024
1 parent 819f608 commit 4f36fcf
Showing 3 changed files with 75 additions and 32 deletions.
26 changes: 26 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -176,6 +176,31 @@ def run_minicpmv(question: str, modality: str):
    return llm, prompt, stop_token_ids


# H2OVL-Mississippi
def run_h2ovl(question: str, modality: str):
    assert modality == "image"

    model_name = "h2oai/h2ovl-mississippi-2b"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    stop_token_ids = [tokenizer.eos_token_id]
    return llm, prompt, stop_token_ids
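
For reference, the script's driver consumes these run_* helpers roughly as below. This is a minimal sketch, not part of the diff: the question, the local image path, and the sampling settings are invented placeholders, and the example's actual main() also handles asset download and modality selection.

from PIL import Image
from vllm import SamplingParams

llm, prompt, stop_token_ids = run_h2ovl("What is in this image?", "image")
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
# "example.jpg" is a hypothetical local file standing in for the demo asset.
image = Image.open("example.jpg")
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    sampling_params)
print(outputs[0].outputs[0].text)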


# InternVL
def run_internvl(question: str, modality: str):
    assert modality == "image"
@@ -364,6 +389,7 @@ def run_glm4v(question: str, modality: str):
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "h2ovl_chat": run_h2ovl,
    "internvl_chat": run_internvl,
    "NVLM_D": run_nvlm_d,
    "qwen_vl": run_qwen_vl,
34 changes: 34 additions & 0 deletions examples/offline_inference_vision_language_multi_image.py
@@ -106,6 +106,38 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
        chat_template=None,
    )

def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "h2oai/h2ovl-mississippi-2b"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    stop_token_ids = [tokenizer.eos_token_id]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )
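
A minimal way to exercise this loader (sketch only, not part of the diff; the URLs are invented placeholders, and run_generate further down does essentially the same thing):

from vllm import SamplingParams

# Hypothetical image URLs standing in for the script's demo assets.
urls = ["https://example.com/cat.jpg", "https://example.com/dog.jpg"]
req = load_h2onvl("Describe the two images.", urls)
sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=128,
                                 stop_token_ids=req.stop_token_ids)
outputs = req.llm.generate({"prompt": req.prompt,
                            "multi_modal_data": {"image": req.image_data}},
                           sampling_params)
print(outputs[0].outputs[0].text)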

def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "OpenGVLab/InternVL2-2B"

@@ -258,6 +290,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
    "phi3_v": load_phi3v,
    "h2ovl_chat": load_h2onvl,
    "internvl_chat": load_internvl,
    "NVLM_D": load_nvlm_d,
    "qwen2_vl": load_qwen2_vl,

@@ -285,6 +318,7 @@ def run_generate(model, question: str, image_urls: List[str]):
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)



def run_chat(model: str, question: str, image_urls: List[str]):
47 changes: 15 additions & 32 deletions vllm/model_executor/models/h2ovl.py
@@ -1,6 +1,7 @@
# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
+# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
-# H2OVL
+# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
@@ -30,7 +31,7 @@
                         IMG_START, IMG_END, IMG_CONTEXT)


-# Modified to include blocks generated in second pass
+# modified to include blocks generated in second pass
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int, image_size: int,
                         use_thumbnail: bool,
@@ -43,7 +44,7 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

-    # If prior_aspect_ratio is provided, filter the target ratios
+    # if prior_aspect_ratio is provided, filter the target ratios
    if prior_aspect_ratio is not None:
        target_ratios = [ratio for ratio in target_ratios if
                         prior_aspect_ratio[0] % ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0]
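
To make the filter above concrete, a standalone sketch with invented values: given a prior aspect ratio of (2, 2) from the first pass, any candidate whose width or height divides 2 is dropped, so the second pass only produces tilings the first pass could not already cover.

prior = (2, 2)  # aspect ratio chosen by the first pass (hypothetical)
candidates = [(1, 1), (1, 2), (2, 1), (3, 1), (1, 3), (3, 5)]
kept = [r for r in candidates
        if prior[0] % r[0] != 0 and prior[1] % r[1] != 0]
print(kept)  # [(3, 5)]; every other candidate divides the prior ratio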
@@ -95,7 +96,7 @@ def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
    return processed_images, target_aspect_ratio


-# New dynamic_preprocess2 with prior_aspect_ratio
+# new dynamic_preprocess2 with prior_aspect_ratio
def dynamic_preprocess2(image: Image.Image, min_num: int, max_num: int,
                        image_size: int, use_thumbnail: bool, prior_aspect_ratio: Tuple[int, int]) -> List[Image.Image]:
    orig_width, orig_height = image.size
@@ -148,7 +149,7 @@ def image_to_pixel_values(image: Image.Image,
                          input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool,
                          use_MSAC: bool) -> torch.Tensor:
-    # When MSAC is turned on, we need to preprocess the image twice
+    # when MSAC is turned on, we need to process the image twice
    if use_MSAC:
        pixel_values, target_aspect_ratio = load_image1(image, input_size=input_size, min_num=min_num, max_num=max_num)
        pixel_values2 = load_image2(image, input_size=input_size, min_num=min_num, max_num=max_num, target_aspect_ratio=target_aspect_ratio)
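
For intuition, the MSAC branch tiles the image twice and concatenates both passes, which is why the token budget roughly doubles when use_MSAC is on. A shape-only sketch (tile counts invented; the exact concatenation order is in the rest of this function, below the visible hunk):

import torch

pixel_values = torch.randn(7, 3, 448, 448)   # pass 1: dynamic tiles plus thumbnail
pixel_values2 = torch.randn(4, 3, 448, 448)  # pass 2: complementary tiles
combined = torch.cat([pixel_values, pixel_values2], dim=0)
print(combined.shape)  # torch.Size([11, 3, 448, 448])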
@@ -185,37 +186,21 @@ def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
def get_max_internvl_image_tokens(ctx: InputContext,
                                  *,
                                  max_dynamic_patch: Optional[int] = None):
+    """
+    Calculate the maximum number of tokens with/without MSAC and thumbnail
+    """
    hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config


    use_thumbnail = hf_config.use_thumbnail
-    max_dynamic_patch = hf_config.max_dynamic_patch
    use_MSAC = hf_config.use_msac

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

-    # calculate the actual max_dy
-    print('The max_dynamic_patch is:', max_dynamic_patch)

-    image_size = vision_config.image_size
    num_patches = get_internvl_num_patches(hf_config)
-    # return num_patches * max_dynamic_patch

-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch

-    # Assuming we're calculating for a dummy image with maximum size
-    max_image_width, max_image_height = get_max_internvl_image_size(ctx, max_dynamic_patch=max_dynamic_patch)
-    dummy_image = Image.new('RGB', (max_image_width, max_image_height))

-    # Calculate num_blocks based on the dummy image's size
-    num_blocks = image_to_pixel_values(dummy_image,
-                                       image_size,
-                                       min_num,
-                                       max_num,
-                                       use_thumbnail=use_thumbnail,
-                                       use_MSAC=use_MSAC).shape[0]

-    # Return the final token count: num_blocks * num_patches
+    coefficient = 2 if use_MSAC else 1
+    num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)

    return num_blocks * num_patches
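
A quick sanity check on the simplified count (the 448 px tile, 14 px patch, and 0.5 downsample figures are typical InternVL-style defaults assumed for illustration, not read from this diff):

# Assumed vision config: 448px tiles, 14px patches, 0.5 downsample ratio,
# giving int(448 // 14 * 0.5) ** 2 = 256 tokens per block.
num_patches = int(448 // 14 * 0.5) ** 2
use_MSAC, use_thumbnail, max_dynamic_patch = True, True, 6
coefficient = 2 if use_MSAC else 1
num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)
print(num_blocks, num_blocks * num_patches)  # 13 3328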


@@ -337,8 +322,6 @@ def _init_vision_model(
    else:
        num_hidden_layers = vision_feature_layer + 1

-    # We added additional dummy heads to the original num of heads to
-    # make the number of heads divisible by 8.
    return InternVisionModel(
        config.vision_config,
        quant_config=quant_config,
