Skip to content

Commit

Permalink
Align nanollava input with original model (#1132)
Browse files (browse the repository at this point in the history)
* align nanollava input with original model

* update test refs

* properly fix quantization

* Update tests/openvino/test_modeling.py
  • Loading branch information
eaidova authored Jan 29, 2025
1 parent a59bb41 commit 68cacea
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
9 changes: 8 additions & 1 deletion optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,10 @@ def forward(
image_grid_thw=None,
video_grid_thw=None,
rope_deltas=None,
images=None,
**kwargs,
):
pixel_values = pixel_values if pixel_values is not None else images
inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
input_ids,
pixel_values,
Expand Down Expand Up @@ -794,6 +796,9 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if pixel_values is None:
pixel_values = kwargs.get("images")

model_inputs.update(
{
"position_ids": position_ids,
Expand Down Expand Up @@ -1733,6 +1738,8 @@ def get_multimodal_embeddings(
vision_embeds = None
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
if pixel_values is None and "images" in kwargs:
pixel_values = kwargs["images"]
if pixel_values is not None:
vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs)
if vision_embeds is None:
Expand Down Expand Up @@ -1907,7 +1914,7 @@ def preprocess_inputs(
attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
result = {"input_ids": input_ids, "attention_mask": attention_mask}
if image is not None:
- result["pixel_values"] = processor(images=[image], return_tensors="pt")["pixel_values"]
+ result["images"] = processor(images=[image], return_tensors="pt")["pixel_values"]
return result


Expand Down
6 changes: 1 addition & 5 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2182,11 +2182,7 @@ def test_compare_to_transformers(self, model_arch):
ov_model.clear_requests()
self._check_device_and_request(ov_model, test_device, False)

- # nanollava pixel_values input named as images
- if model_arch == "nanollava":
- pixel_values = transformers_inputs.pop("pixel_values", None)
- transformers_inputs["images"] = pixel_values
- # pytorch minicpmv is not designed to be used via forward
+ # pytorch minicpmv and internvl2 are not designed to be used via forward
if model_arch not in ["minicpmv", "internvl2"]:
set_seed(SEED)
ov_outputs = ov_model(**inputs)
Expand Down

0 comments on commit 68cacea

Please sign in to comment.