diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ae632376f946..05b066839d51 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -812,6 +812,8 @@ title: CLIPSeg - local: model_doc/clvp title: CLVP + - local: model_doc/colpali + title: ColPali - local: model_doc/data2vec title: Data2Vec - local: model_doc/deplot diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 0a5518fd71c8..2e6fcdf9011f 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -97,6 +97,7 @@ Flax), PyTorch, and/or TensorFlow. | [CodeGen](model_doc/codegen) | ✅ | ❌ | ❌ | | [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ✅ | | [Cohere](model_doc/cohere) | ✅ | ❌ | ❌ | +| [ColPali](model_doc/colpali) | ✅ | ❌ | ❌ | | [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ | | [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ | | [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/colpali.md b/docs/source/en/model_doc/colpali.md new file mode 100644 index 000000000000..1e7d629fa206 --- /dev/null +++ b/docs/source/en/model_doc/colpali.md @@ -0,0 +1,47 @@ + +# ColPali + +## Overview + +The ColPali model was proposed in [ColPali: Efficient Document Retrieval with Vision Language Models](https://arxiv.org/abs/2407.01449) by Manuel Faysse, Hugues Sibille, Tony Wu, Bilel Omrani, Gautier Viaud, Céline Hudelot and Pierre Colombo. + +ColPali is a [PaliGemma](paligemma) variant for visual document retrieval: each document page image (and each text query) is encoded into a sequence of multi-vector token embeddings, and query/page relevance is computed with a ColBERT-style late-interaction (MaxSim) score, removing the need for OCR-based text extraction pipelines. + +Tips: + +- Use [`ColPaliProcessor.process_images`] to prepare page images and [`ColPaliProcessor.process_queries`] to prepare text queries; [`ColPaliProcessor.score`] then computes the MaxSim similarity between the resulting embeddings. +- A blog post detailing ColPali can be found [here](https://huggingface.co/blog/manu/colpali), and fine-tuning cookbooks [here](https://github.com/tonywu71/colpali-cookbooks). + +This model was contributed by [tonywu71](https://huggingface.co/tonywu71). +The original code can be found [here](https://github.com/illuin-tech/colpali). + + +## ColPaliConfig + +[[autodoc]] ColPaliConfig + +## ColPaliProcessor + +[[autodoc]] ColPaliProcessor + +## ColPaliForRetrieval + +[[autodoc]] ColPaliForRetrieval + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 078e4d0e4abd..1518ec4ed517 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -640,6 +640,7 @@ "OwlViTVisionConfig", ], "models.paligemma": ["PaliGemmaConfig"], + "models.colpali": ["ColPaliConfig"], "models.patchtsmixer": ["PatchTSMixerConfig"], "models.patchtst": ["PatchTSTConfig"], "models.pegasus": [ @@ -1735,6 +1736,12 @@ ] ) _import_structure["models.cohere"].extend(["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"]) + _import_structure["models.colpali"].extend( + [ + "ColPaliForRetrieval", + "ColPaliProcessor", + ] + ) _import_structure["models.conditional_detr"].extend( [ "ConditionalDetrForObjectDetection", @@ -5091,6 +5098,9 @@ CodeGenTokenizer, ) from .models.cohere import CohereConfig + from .models.colpali import ( + ColPaliConfig, + ) from .models.conditional_detr import ( ConditionalDetrConfig, ) @@ -6532,6 +6542,10 @@ CohereModel, CoherePreTrainedModel, ) + from .models.colpali import ( + ColPaliForRetrieval, + ColPaliProcessor, + ) from .models.conditional_detr import ( ConditionalDetrForObjectDetection, ConditionalDetrForSegmentation, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index e47a4ed9c342..32328af2e18f 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -51,6 +51,7 @@ code_llama, codegen, cohere, + colpali, conditional_detr, convbert, convnext, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 6d55f87d60ac..90bc6f8eac4c 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -202,6 +202,7 @@ ("owlv2", "Owlv2Config"), ("owlvit", "OwlViTConfig"), ("paligemma", "PaliGemmaConfig"), + ("colpali", "ColPaliConfig"), ("patchtsmixer", "PatchTSMixerConfig"), ("patchtst", 
"PatchTSTConfig"), ("pegasus", "PegasusConfig"), @@ -512,6 +513,8 @@ ("owlv2", "OWLv2"), ("owlvit", "OWL-ViT"), ("paligemma", "PaliGemma"), + ("colpali", "ColPali"), + ("colpali", "ColPali"), ("patchtsmixer", "PatchTSMixer"), ("patchtst", "PatchTST"), ("pegasus", "Pegasus"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 6e730e848db7..2777ddc53ce6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -295,6 +295,7 @@ ("big_bird", "BigBirdForPreTraining"), ("bloom", "BloomForCausalLM"), ("camembert", "CamembertForMaskedLM"), + ("colpali", "ColPaliForRetrieval"), ("ctrl", "CTRLLMHeadModel"), ("data2vec-text", "Data2VecTextForMaskedLM"), ("deberta", "DebertaForMaskedLM"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c894840c6ad2..b4ce89d6a9cb 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -57,6 +57,7 @@ ("clip", "CLIPProcessor"), ("clipseg", "CLIPSegProcessor"), ("clvp", "ClvpProcessor"), + ("colpali", "ColPaliProcessor"), ("flava", "FlavaProcessor"), ("fuyu", "FuyuProcessor"), ("git", "GitProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 6a5cba11f094..cb6f005c9a43 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -146,6 +146,7 @@ ), ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)), ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), + ("colpali", ("PaligemmaTokenizer", "PaligemmaTokenizerFast" if is_tokenizers_available() else None)), ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), ( "cpm", diff --git a/src/transformers/models/colpali/__init__.py b/src/transformers/models/colpali/__init__.py new file mode 100644 index 000000000000..18d787c4e6cd --- /dev/null +++ b/src/transformers/models/colpali/__init__.py @@ -0,0 +1,51 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = {"configuration_colpali": ["ColPaliConfig"]} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_colpali"] = [ + "ColPaliForRetrieval", + "ColPaliPreTrainedModel", + ] + _import_structure["processing_colpali"] = ["ColPaliProcessor"] + + +if TYPE_CHECKING: + from .configuration_colpali import ColPaliConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel + from .processing_colpali import ColPaliProcessor + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/colpali/configuration_colpali.py b/src/transformers/models/colpali/configuration_colpali.py new file mode 100644 index 000000000000..54c61575b9e1 --- /dev/null +++ b/src/transformers/models/colpali/configuration_colpali.py @@ -0,0 +1,42 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..paligemma import ( + PaliGemmaConfig, +) + + +class ColPaliConfig(PaliGemmaConfig): + r""" + This is the configuration class to store the configuration of a [`ColPaliForRetrieval`]. It is used to instantiate a + ColPaliForRetrieval according to the specified arguments, defining the model architecture. + + The ColPali config is strictly equivalent to the PaliGemma config, but with a different model type. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_type = "colpali" + self.is_composition = False diff --git a/src/transformers/models/colpali/convert_colpali_weights_to_hf.py b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py new file mode 100644 index 000000000000..d174fdf84475 --- /dev/null +++ b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py @@ -0,0 +1,345 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ColPali checkpoints from the original repository.""" + +import argparse +import collections + +import torch +from numpy import load + +from transformers import ( + AutoTokenizer, + ColPaliConfig, + ColPaliForRetrieval, + ColPaliProcessor, + GemmaTokenizer, + GemmaTokenizerFast, + SiglipImageProcessor, +) +from transformers.tokenization_utils_base import AddedToken +from transformers.utils import logging + + +device = "cuda" # "cpu" + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# TODO add sequence length variations here + +COLPALI_VARIANTS = ["2b-test", "3b-224px", "3b-448px", "3b-896px"] + + +def get_colpali_config(variant: str, precision: str): + config = { + "image_token_index": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + + image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896} + + if variant in COLPALI_VARIANTS: + image_size = image_sizes[variant] + patch_size = 14 + num_image_tokens = (image_size**2) // (patch_size**2) + + config["image_token_index"] = 257152 if variant != "2b-test" else 256000 + text_config = { + "vocab_size": 257152, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "head_dim": 256, + "torch_dtype": precision, + "hidden_size": 2048, + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": 8, + "intermediate_size": 16384, + "is_encoder_decoder": False, + } + vision_config = { + "torch_dtype": precision, + "image_size": image_size, + "patch_size": patch_size, + "num_image_tokens": num_image_tokens, + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + } + final_config = ColPaliConfig(text_config=text_config, vision_config=vision_config, **config) + else: + raise ValueError(f"Identifier {variant} not supported. Available: {COLPALI_VARIANTS}") + return final_config + + +def slice_state_dict(state_dict, config): + # fmt: off + # patch embeddings + state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop("img/embedding/kernel").transpose( + 3, 2, 0, 1 + ) + state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop("img/embedding/bias") + # positional embeddings + state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop("img/pos_embedding").reshape( + -1, config.vision_config.hidden_size + ) + + # extract vision layers to be sliced at index 0. There are 27 layers in the base model. 
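+    # Note: each "img/Transformer/encoderblock/..." array popped below stacks all 27 vision encoder layers along its leading axis, so the loop that follows slices one layer at a time and transposes the flax (in, out) kernels into PyTorch's (out, in) layout.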
+ encoderblock_layernorm0_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/scale") + encoderblock_layernorm0_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/bias") + encoderblock_layernorm1_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/scale") + encoderblock_layernorm1_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/bias") + + encoderblock_mlp_dense0_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel") + encoderblock_mlp_dense0_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias") + encoderblock_mlp_dense1_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel") + encoderblock_mlp_dense1_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias") + + encoderblock_attention_0_key_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel") + encoderblock_attention_0_key_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias") + encoderblock_attention_0_value_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel") + encoderblock_attention_0_value_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias") + encoderblock_attention_0_query_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel") + encoderblock_attention_0_query_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias") + encoderblock_attention_0_out_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel") + encoderblock_attention_0_out_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias") + + for i in range(config.vision_config.num_hidden_layers): + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] + + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, 
config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + + state_dict["vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop("img/Transformer/encoder_norm/scale").transpose() + state_dict["vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop("img/Transformer/encoder_norm/bias") + + # multimodal projector + + state_dict['multi_modal_projector.linear.weight'] = state_dict.pop("img/head/kernel").transpose() + state_dict['multi_modal_projector.linear.bias'] = state_dict.pop("img/head/bias") + + # text decoder (gemma) + + embedding_vector = state_dict.pop("llm/embedder/input_embedding") + state_dict["language_model.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. + + llm_attention_attn_vec_einsum = state_dict.pop("llm/layers/attn/attn_vec_einsum/w") + llm_attention_kv_einsum = state_dict.pop("llm/layers/attn/kv_einsum/w") + llm_attention_q_einsum = state_dict.pop("llm/layers/attn/q_einsum/w") + + llm_mlp_gating_einsum = state_dict.pop("llm/layers/mlp/gating_einsum") + llm_mlp_linear = state_dict.pop("llm/layers/mlp/linear") + # TODO verify correctness of layer norm loading + + llm_input_layernorm = state_dict.pop("llm/layers/pre_attention_norm/scale") + llm_post_attention_layernorm = state_dict.pop("llm/layers/pre_ffw_norm/scale") + + for i in range(config.text_config.num_hidden_layers): + # llm_attention_q_einsum[i].shape = (8, 2048, 256) + q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + + # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) + k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() + state_dict[f"language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) + v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() + state_dict[f"language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. 
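+        # The flax output-projection weight is stored per head as (num_heads, head_dim, hidden_size); move hidden_size to the front and flatten the head axes to obtain the dense o_proj weight.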
+ + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + + state_dict["language_model.model.norm.weight"] = state_dict.pop("llm/final_norm/scale") + state_dict["language_model.lm_head.weight"] = embedding_vector # weights are tied. + + # fmt: on + for key, value in state_dict.items(): + state_dict[key] = torch.from_numpy(value) + return state_dict + + +def flatten_nested_dict(params, parent_key="", sep="/"): + items = [] + + for k, v in params.items(): + k = k.removeprefix("params/") + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_nested_dict(v, parent_key=new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +@torch.no_grad() +def convert_colpali_checkpoint( + checkpoint_path, + tokenizer_model_file, + pytorch_dump_folder_path, + variant: str, + precision: str, + do_convert_weights=False, +): + """ + Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed. + """ + config = get_colpali_config(variant, precision=precision) + if do_convert_weights: + if variant == "2b-test": + # for the test model, the vocabulary was smaller + tokenizer_id = "google/gemma-2b" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + else: + tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast + tokenizer = tokenizer_class(tokenizer_model_file) + image_token = AddedToken("<image>", normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + + # tokenizer.padding_side = 'right' # uncomment for testing purposes only.
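+    # Reuse SigLIP's image processor, resized to this variant's resolution; `image_seq_length` is the number of image placeholder tokens the processor prepends to the text prompt.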
+ + image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384") + image_processor.size = {"width": config.vision_config.image_size, "height": config.vision_config.image_size} + image_processor.image_seq_length = config.vision_config.num_image_tokens + + processor = ColPaliProcessor(image_processor=image_processor, tokenizer=tokenizer) + data = load(checkpoint_path) + state_dict = flatten_nested_dict(data) + del data + state_dict_transformers = slice_state_dict(state_dict, config) + del state_dict + + model = ColPaliForRetrieval(config).to(device).eval() + model.load_state_dict(state_dict_transformers) + del state_dict_transformers + + else: + processor = ColPaliProcessor.from_pretrained(pytorch_dump_folder_path) + model = ( + ColPaliForRetrieval.from_pretrained(pytorch_dump_folder_path, attn_implementation="sdpa").to(device).eval() + ) + model.config.text_config._attn_implementation = "sdpa" + + # model expansion to get random embeds of image tokens + pad_shape = 64 # for performance reasons + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[257152:].shape[0]))), + dim=0, + ) + model.language_model.lm_head.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[257152:].shape[0]))), + dim=0, + ) + + model.save_pretrained(pytorch_dump_folder_path, max_shard_size="2GB", safe_serialization=True) + processor.save_pretrained(pytorch_dump_folder_path) + + +# + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_path", + required=True, + type=str, + help="Path to the .npz checkpoint", + ) + + parser.add_argument( + "--tokenizer_model_file", + required=True, + type=str, + help="Path to the sentencepiece tokenizer.model file", + ) + + parser.add_argument( + "--pytorch_dump_folder_path", + required=True, + type=str, + help="Path to the output directory where model and processor will be saved.", + ) + + parser.add_argument( + "--precision", + choices=["float32", "bfloat16", "float16"], + type=str, + help="Precision identifier for model conversion - should match the base checkpoint precision.", + ) + + parser.add_argument( + "--variant", + default="2b-test", + choices=COLPALI_VARIANTS, + type=str, + help="String identifier of the colpali variant to convert.", + ) + + parser.add_argument( + "--do_convert_weights", action="store_true", help="Whether or not to reload and convert the weights." 
+ ) + + args = parser.parse_args() + convert_colpali_checkpoint( + checkpoint_path=args.checkpoint_path, + tokenizer_model_file=args.tokenizer_model_file, + pytorch_dump_folder_path=args.pytorch_dump_folder_path, + variant=args.variant, + precision=args.precision, + do_convert_weights=args.do_convert_weights, + ) diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py new file mode 100644 index 000000000000..59fdaa49c4d9 --- /dev/null +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -0,0 +1,191 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import ClassVar, List, Optional, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ..paligemma import ( + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + PaliGemmaPreTrainedModel, +) + + +@dataclass +class ColPaliModelOutput(ModelOutput): + """ + Base class for ColPali embeddings output. + + Args: + embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The embeddings of the model. + """ + + embeddings: torch.Tensor + + +@add_start_docstrings( + """ + ColPali is a PaliGemma variant to produce multi-vector representations from images. + It was introduced in the paper [ColPali: Efficient Document Retrieval with Vision Language Models](https://arxiv.org/abs/2407.01449). + + Resources: + - A blog post detailing ColPali, a vision retrieval model, can be found [here](https://huggingface.co/blog/manu/colpali). 📝 + - The code for training ColPali and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎 + - Cookbooks to fine-tune ColPali (with optional quantization), generate similarity maps, ... can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚 + + Adapted from colpali-engine==0.3.0: https://github.com/illuin-tech/colpali. 
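+ +    ColPali encodes every document page image (and every text query) into a sequence of 128-dimensional token embeddings; query/page relevance is then computed with a ColBERT-style MaxSim late-interaction score over these multi-vector representations (see [`ColPaliProcessor.score`]).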
+ """ +) +class ColPaliForRetrieval(PaliGemmaPreTrainedModel): + main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related + + def __init__(self, config: PaliGemmaConfig): + super().__init__(config=config) + + model = PaliGemmaForConditionalGeneration(config=config) + if model.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys] + self.model = model + + self.dim = 128 + self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + + self.post_init() + + @add_start_docstrings_to_model_forward( + """ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses + [`SiglipImageProcessor`] for processing images). If none, ColPali will only process text (query embeddings). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ """ + ) + @replace_return_docstrings(output_type=ColPaliModelOutput, config_class="ColPaliConfig") + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + num_logits_to_keep: int = 0, + ) -> ColPaliModelOutput: + r""" + Returns: + """ + outputs = self.model( + input_ids, + pixel_values, + attention_mask, + position_ids, + past_key_values, + token_type_ids, + cache_position, + inputs_embeds, + labels, + use_cache, + output_attentions, + num_logits_to_keep, + output_hidden_states=True, + ) + last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size) + proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) + + # L2 normalization + proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + + proj = proj * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + + return ColPaliModelOutput(embeddings=proj) + + def get_input_embeddings(self): + return self.model.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.model.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.model.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.model.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.model.language_model.get_decoder() + + def tie_weights(self): + return self.model.language_model.tie_weights() + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of=None, + ) -> nn.Embedding: + model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + + # Update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.model.vocab_size = model_embeds.num_embeddings + + return model_embeds diff --git a/src/transformers/models/colpali/modular_colpali.py b/src/transformers/models/colpali/modular_colpali.py new file mode 100644 index 000000000000..0075a7402b4e --- /dev/null +++ b/src/transformers/models/colpali/modular_colpali.py @@ -0,0 +1,336 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import ClassVar, List, Optional, Union + +import torch +import torch.utils.checkpoint +from PIL import Image +from torch import nn + +from ...cache_utils import Cache +from ...feature_extraction_utils import BatchFeature +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) +from ..paligemma import ( + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + PaliGemmaPreTrainedModel, + PaliGemmaProcessor, +) + + +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + + +class ColPaliConfig(PaliGemmaConfig): + r""" + This is the configuration class to store the configuration of a [`ColPaliForRetrieval`]. It is used to instantiate a + ColPaliForRetrieval according to the specified arguments, defining the model architecture. + + The ColPali config is strictly equivalent to the PaliGemma config, but with a different model type. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_type = "colpali" + self.is_composition = False + + +class ColPaliProcessor(PaliGemmaProcessor): + r""" + Processor for ColPali. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mock_image = Image.new("RGB", (16, 16), color="black") + + @staticmethod + def get_torch_device(device: str = "auto") -> str: + """ + Returns the device (string) to be used by PyTorch. + + `device` arg defaults to "auto" which will use: + - "cuda:0" if available + - else "mps" if available + - else "cpu". + """ + + if device == "auto": + if torch.cuda.is_available(): + device = "cuda:0" + elif torch.backends.mps.is_available(): # for Apple Silicon + device = "mps" + else: + device = "cpu" + logger.info(f"Using device: {device}") + + return device + + def process_images( + self, + images: List[Image.Image], + ) -> BatchFeature: + """ + Process images for ColPali. + """ + texts_doc = ["Describe the image."] * len(images) + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=texts_doc, + images=images, + return_tensors="pt", + padding="longest", + ) + return batch_doc + + def process_queries( + self, + queries: List[str], + max_length: int = 50, + suffix: Optional[str] = None, + ) -> BatchFeature: + """ + Process queries for ColPali.
+ """ + if suffix is None: + suffix = "" * 10 + texts_query: List[str] = [] + + for query in queries: + query = f"Question: {query}" + query += suffix # add suffix (pad tokens) + texts_query.append(query) + + batch_query = self( + images=[self.mock_image] * len(texts_query), + text=texts_query, + return_tensors="pt", + padding="longest", + max_length=max_length + self.image_seq_length, + ) + + del batch_query["pixel_values"] + + batch_query["input_ids"] = batch_query["input_ids"][..., self.image_seq_length :] + batch_query["attention_mask"] = batch_query["attention_mask"][..., self.image_seq_length :] + + return batch_query + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + batch_size: int = 128, + device: Optional[Union[str, torch.device]] = None, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + device = device or self.get_torch_device("auto") + + if len(qs) == 0: + raise ValueError("No queries provided") + if len(ps) == 0: + raise ValueError("No passages provided") + + scores_list: List[torch.Tensor] = [] + + for i in range(0, len(qs), batch_size): + scores_batch = [] + qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to( + device + ) + for j in range(0, len(ps), batch_size): + ps_batch = torch.nn.utils.rnn.pad_sequence( + ps[j : j + batch_size], batch_first=True, padding_value=0 + ).to(device) + scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)) + scores_batch = torch.cat(scores_batch, dim=1).cpu() + scores_list.append(scores_batch) + + scores = torch.cat(scores_list, dim=0) + assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" + + scores = scores.to(torch.float32) + return scores + + +@dataclass +class ColPaliModelOutput(ModelOutput): + """ + Base class for ColPali embeddings output. + + Args: + embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The embeddings of the model. + """ + + embeddings: torch.Tensor + + +@add_start_docstrings( + """ + ColPali is a PaliGemma variant to produce multi-vector representations from images. + It was introduced in the paper [ColPali: Efficient Document Retrieval with Vision Language Models](https://arxiv.org/abs/2407.01449). + + Resources: + - A blog post detailing ColPali, a vision retrieval model, can be found [here](https://huggingface.co/blog/manu/colpali). 📝 + - The code for training ColPali and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎 + - Cookbooks to fine-tune ColPali (with optional quantization), generate similarity maps, ... can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚 + + Adapted from colpali-engine==0.3.0: https://github.com/illuin-tech/colpali. 
+ """ +) +class ColPaliForRetrieval(PaliGemmaPreTrainedModel): + main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related + + def __init__(self, config: PaliGemmaConfig): + super().__init__(config=config) + + model = PaliGemmaForConditionalGeneration(config=config) + if model.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys] + self.model = model + + self.dim = 128 + self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + + self.post_init() + + @add_start_docstrings_to_model_forward( + """ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses + [`SiglipImageProcessor`] for processing images). If none, ColPali will only process text (query embeddings). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ """ + ) + @replace_return_docstrings(output_type=ColPaliModelOutput, config_class="ColPaliConfig") + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + num_logits_to_keep: int = 0, + ) -> ColPaliModelOutput: + r""" + Returns: + """ + outputs = self.model( + input_ids, + pixel_values, + attention_mask, + position_ids, + past_key_values, + token_type_ids, + cache_position, + inputs_embeds, + labels, + use_cache, + output_attentions, + num_logits_to_keep, + output_hidden_states=True, + ) + last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size) + proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) + + # L2 normalization + proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + + proj = proj * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + + return ColPaliModelOutput(embeddings=proj) + + def get_input_embeddings(self): + return self.model.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.model.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.model.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.model.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.model.language_model.get_decoder() + + def tie_weights(self): + return self.model.language_model.tie_weights() + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of=None, + ) -> nn.Embedding: + model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + + # Update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.model.vocab_size = model_embeds.num_embeddings + + return model_embeds diff --git a/src/transformers/models/colpali/processing_colpali.py b/src/transformers/models/colpali/processing_colpali.py new file mode 100644 index 000000000000..c19f2ff09c5c --- /dev/null +++ b/src/transformers/models/colpali/processing_colpali.py @@ -0,0 +1,152 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +import torch.utils.checkpoint +from PIL import Image + +from ...feature_extraction_utils import BatchFeature +from ...utils import logging +from ..paligemma import ( + PaliGemmaProcessor, +) + + +logger = logging.get_logger(__name__) + + +class ColPaliProcessor(PaliGemmaProcessor): + r""" + Processor for ColPali. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mock_image = Image.new("RGB", (16, 16), color="black") + + @staticmethod + def get_torch_device(device: str = "auto") -> str: + """ + Returns the device (string) to be used by PyTorch. + + `device` arg defaults to "auto" which will use: + - "cuda:0" if available + - else "mps" if available + - else "cpu". + """ + + if device == "auto": + if torch.cuda.is_available(): + device = "cuda:0" + elif torch.backends.mps.is_available(): # for Apple Silicon + device = "mps" + else: + device = "cpu" + logger.info(f"Using device: {device}") + + return device + + def process_images( + self, + images: List[Image.Image], + ) -> BatchFeature: + """ + Process images for ColPali. + """ + texts_doc = ["Describe the image."] * len(images) + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=texts_doc, + images=images, + return_tensors="pt", + padding="longest", + ) + return batch_doc + + def process_queries( + self, + queries: List[str], + max_length: int = 50, + suffix: Optional[str] = None, + ) -> BatchFeature: + """ + Process queries for ColPali. + """ + if suffix is None: + suffix = "<pad>" * 10 + texts_query: List[str] = [] + + for query in queries: + query = f"Question: {query}" + query += suffix # add suffix (pad tokens) + texts_query.append(query) + + batch_query = self( + images=[self.mock_image] * len(texts_query), + text=texts_query, + return_tensors="pt", + padding="longest", + max_length=max_length + self.image_seq_length, + ) + + del batch_query["pixel_values"] + + batch_query["input_ids"] = batch_query["input_ids"][..., self.image_seq_length :] + batch_query["attention_mask"] = batch_query["attention_mask"][..., self.image_seq_length :] + + return batch_query + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + batch_size: int = 128, + device: Optional[Union[str, torch.device]] = None, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+ """ + device = device or self.get_torch_device("auto") + + if len(qs) == 0: + raise ValueError("No queries provided") + if len(ps) == 0: + raise ValueError("No passages provided") + + scores_list: List[torch.Tensor] = [] + + for i in range(0, len(qs), batch_size): + scores_batch = [] + qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to( + device + ) + for j in range(0, len(ps), batch_size): + ps_batch = torch.nn.utils.rnn.pad_sequence( + ps[j : j + batch_size], batch_first=True, padding_value=0 + ).to(device) + scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)) + scores_batch = torch.cat(scores_batch, dim=1).cpu() + scores_list.append(scores_batch) + + scores = torch.cat(scores_list, dim=0) + assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" + + scores = scores.to(torch.float32) + return scores diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index f4e471ee7ab5..72b41c97a7be 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2157,6 +2157,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ColPaliForRetrieval(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ColPaliProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ConditionalDetrForObjectDetection(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/colpali/__init__.py b/tests/models/colpali/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/colpali/test_modeling_colpali.py b/tests/models/colpali/test_modeling_colpali.py new file mode 100644 index 000000000000..1ac8017de5e9 --- /dev/null +++ b/tests/models/colpali/test_modeling_colpali.py @@ -0,0 +1,105 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch ColPali model.""" + +from typing import Generator, cast + +import pytest +import torch +from PIL import Image + +from transformers.models.colpali import ColPaliForRetrieval, ColPaliProcessor +from transformers.models.colpali.processing_colpali import get_torch_device + + +@pytest.fixture(scope="module") +def colpali_model_path() -> str: + return "vidore/colpali-v1.2" + + +@pytest.fixture(scope="module") +def colpali_from_pretrained(colpali_model_path: str) -> Generator[ColPaliForRetrieval, None, None]: + device = get_torch_device("auto") + print(f"Device used: {device}") + + yield cast( + ColPaliForRetrieval, + ColPaliForRetrieval.from_pretrained( + colpali_model_path, + torch_dtype=torch.bfloat16, + device_map="cpu", + ), + ) + + +@pytest.fixture(scope="module") +def processor() -> Generator[ColPaliProcessor, None, None]: + yield cast(ColPaliProcessor, ColPaliProcessor.from_pretrained("google/paligemma-3b-mix-448")) + + +@pytest.mark.slow +def test_load_colpali_from_pretrained(colpali_from_pretrained: ColPaliForRetrieval): + assert isinstance(colpali_from_pretrained, ColPaliForRetrieval) + + +@pytest.mark.slow +def test_colpali_forward_images( + colpali_from_pretrained: ColPaliForRetrieval, + processor: ColPaliProcessor, +): + # Create a batch of dummy images + images = [ + Image.new("RGB", (32, 32), color="white"), + Image.new("RGB", (16, 16), color="black"), + ] + + # Process the image + batch_images = processor.process_images(images).to(colpali_from_pretrained.device) + + # Forward pass + with torch.no_grad(): + outputs = colpali_from_pretrained(**batch_images) + + # Assertions + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_visual_tokens, emb_dim = outputs.shape + assert batch_size == len(images) + assert emb_dim == colpali_from_pretrained.dim + + +@pytest.mark.slow +def test_colpali_forward_queries( + colpali_from_pretrained: ColPaliForRetrieval, + processor: ColPaliProcessor, +): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Process the queries + batch_queries = processor.process_queries(queries).to(colpali_from_pretrained.device) + + # Forward pass + with torch.no_grad(): + outputs = colpali_from_pretrained(**batch_queries) + + # Assertions + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_query_tokens, emb_dim = outputs.shape + assert batch_size == len(queries) + assert emb_dim == colpali_from_pretrained.dim diff --git a/tests/models/colpali/test_processing_colpali.py b/tests/models/colpali/test_processing_colpali.py new file mode 100644 index 000000000000..ef8bde0c6979 --- /dev/null +++ b/tests/models/colpali/test_processing_colpali.py @@ -0,0 +1,49 @@ +from typing import Generator, cast + +import pytest +import torch +from PIL import Image + +from colpali_engine.models import ColPaliProcessor + + +@pytest.fixture(scope="module") +def colpali_processor_path() -> str: + return "google/paligemma-3b-mix-448" + + +@pytest.fixture(scope="module") +def processor_from_pretrained(colpali_processor_path: str) -> Generator[ColPaliProcessor, None, None]: + yield cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(colpali_processor_path)) + + +def test_load_processor_from_pretrained(processor_from_pretrained: ColPaliProcessor): + assert isinstance(processor_from_pretrained, ColPaliProcessor) + + +def test_process_images(processor_from_pretrained: ColPaliProcessor): + # Create a dummy image + image = 
Image.new("RGB", (16, 16), color="black") + images = [image] + + # Process the image + batch_feature = processor_from_pretrained.process_images(images) + + # Assertions + assert "pixel_values" in batch_feature + assert batch_feature["pixel_values"].shape == torch.Size([1, 3, 448, 448]) + + +def test_process_queries(processor_from_pretrained: ColPaliProcessor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Process the queries + batch_encoding = processor_from_pretrained.process_queries(queries) + + # Assertions + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index b6e99fbf5edd..c5bf769f9288 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -502,7 +502,7 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, TYPE_TO_FILE_TYPE = { "Config": "configuration", "Tokenizer": "tokenization", - "Processor": "processor", + "Processor": "processing", "ImageProcessor": "image_processing", "FeatureExtractor": "feature_extractor", }
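For reference, a minimal end-to-end sketch of how the classes added in this patch fit together. It is illustrative only: the checkpoint ids are the ones used in the tests (`vidore/colpali-v1.2` for the model, `google/paligemma-3b-mix-448` for the processor) and assume weights in the converted, HF-compatible format produced by `convert_colpali_weights_to_hf.py`; the blank page stands in for real document images.

```python
import torch
from PIL import Image

from transformers import ColPaliForRetrieval, ColPaliProcessor

# Checkpoint ids taken from the tests above; they assume converted (HF-format) weights.
model = ColPaliForRetrieval.from_pretrained("vidore/colpali-v1.2", torch_dtype=torch.bfloat16).eval()
processor = ColPaliProcessor.from_pretrained("google/paligemma-3b-mix-448")

images = [Image.new("RGB", (448, 448), color="white")]  # replace with real document page images
queries = ["Is attention really all you need?"]

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images).embeddings   # (n_pages, n_tokens, 128)
    query_embeddings = model(**batch_queries).embeddings  # (n_queries, n_tokens, 128)

# ColBERT-style MaxSim late-interaction scores, shape (n_queries, n_pages)
scores = processor.score(list(torch.unbind(query_embeddings)), list(torch.unbind(image_embeddings)))
print(scores)
```

The `score` call implements the MaxSim late interaction computed by the einsum in `ColPaliProcessor.score`: for every query token it takes the maximum dot product over a page's token embeddings and sums these maxima to obtain the query/page relevance.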