[V1] Update interface for idefics3 #10680

Merged · 1 commit · Nov 27, 2024
vllm/model_executor/models/idefics3.py (73 changes: 46 additions & 27 deletions)
@@ -39,6 +39,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.image import cached_get_image_processor
+from vllm.multimodal.inputs import NestedTensors
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import is_list_of
@@ -597,33 +598,21 @@ def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
         image_features = self._process_image_pixels(image_input)
         return self.connector(image_features)
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.text_model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-        **kwargs: object,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        if intermediate_tensors is not None:
-            input_ids = None
-            inputs_embeds = None
-        else:
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.text_model.get_input_embeddings(input_ids)
-
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.image_token_id)
-            else:
-                inputs_embeds = self.text_model.get_input_embeddings(input_ids)
-            input_ids = None
 
         hidden_states = self.text_model(
             input_ids,
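
The hunk above removes the fused image-handling path from Idefics3Model.forward; merge_multimodal_embeddings (retained by the top-level class in the next hunk) scatters the vision features into the rows of the text embeddings that hold the image placeholder token. A minimal sketch of that idea in plain PyTorch — toy shapes and a made-up placeholder id, not vLLM's actual implementation:

import torch

def merge_sketch(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                 vision_embeds: torch.Tensor, image_token_id: int) -> torch.Tensor:
    # Rows whose token is the image placeholder are overwritten, in order,
    # by the rows of the vision features.
    out = inputs_embeds.clone()
    out[input_ids == image_token_id] = vision_embeds
    return out

ids = torch.tensor([5, 7, 7, 3])        # 7 = hypothetical image token id
text = torch.zeros(4, 2)                # toy text embeddings
vis = torch.ones(2, 2)                  # one row per placeholder token
print(merge_sketch(ids, text, vis, 7))  # rows 1 and 2 become ones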
@@ -718,23 +707,53 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
         self.sampler = Sampler()
 
+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        image_input = self.model._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self.model._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.config.image_token_id)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(
-            input_ids,
-            positions,
-            kv_caches,
-            attn_metadata,
-            intermediate_tensors,
-            **kwargs,
-        )
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
+
+        hidden_states = self.model.text_model(input_ids,
+                                              positions,
+                                              kv_caches,
+                                              attn_metadata,
+                                              intermediate_tensors,
+                                              inputs_embeds=inputs_embeds)
+
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
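
Taken together, the diff splits the old fused forward into three pieces the V1 model runner can call separately: get_multimodal_embeddings, get_input_embeddings, and a forward that accepts precomputed inputs_embeds. A self-contained toy in plain PyTorch — ToyVLM, IMAGE_TOKEN_ID, and the layer sizes are all hypothetical stand-ins, not vLLM's classes — showing how the two call styles exercise the same interface:

from typing import Optional

import torch
import torch.nn as nn

IMAGE_TOKEN_ID = 9  # hypothetical placeholder-token id


class ToyVLM(nn.Module):

    def __init__(self, vocab_size: int = 16, hidden: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.vision = nn.Linear(8, hidden)  # stand-in for the vision tower

    def get_multimodal_embeddings(self, **kwargs) -> Optional[torch.Tensor]:
        # Mirrors the PR: parse the image inputs, return None when absent.
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is None:
            return None
        return self.vision(pixel_values)  # (num_image_tokens, hidden)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed(input_ids)
        if multimodal_embeddings is not None:
            # Same idea as merge_multimodal_embeddings: overwrite the
            # placeholder rows with the vision features.
            inputs_embeds[input_ids == IMAGE_TOKEN_ID] = multimodal_embeddings
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        # v0 compatibility: build inputs_embeds here if the runner didn't.
        if inputs_embeds is None:
            mm = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, mm)
        return inputs_embeds  # a real model would run the language model here


with torch.no_grad():
    model = ToyVLM()
    ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2])
    pix = torch.randn(2, 8)  # one feature row per image placeholder token

    # V1 style: the runner materializes embeddings before calling forward.
    mm = model.get_multimodal_embeddings(pixel_values=pix)
    out_v1 = model(input_ids=None,
                   inputs_embeds=model.get_input_embeddings(ids, mm))

    # V0 style: forward receives raw kwargs and falls back internally.
    out_v0 = model(input_ids=ids, pixel_values=pix)
    assert torch.allclose(out_v0, out_v1)

The assert holds because both paths compute the same merge; in vLLM the V1 runner owns the first path, which is why the inner model's forward no longer needs **kwargs.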
Expand Down