Enhance Clip_Score to calculate similarities between same modalities (#2875)

* Handle zero division error in binary IoU (Jaccard index) calculation

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Jan 25, 2025
1 parent 520a868 commit 5fc2e0b
Showing 3 changed files with 330 additions and 69 deletions.
230 changes: 178 additions & 52 deletions src/torchmetrics/functional/multimodal/clip_score.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Union
from typing import TYPE_CHECKING, List, Union, cast

import torch
from torch import Tensor
@@ -41,53 +41,143 @@ def _download_clip_for_clip_score() -> None:
_CLIPProcessor = None


def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) -> Literal["image", "text"]:
"""Automatically detect the modality of the input data.
Args:
input_data: Input data that can be either image tensors or text strings
Returns:
str: Either "image" or "text"
Raises:
ValueError: If the input_data is an empty list or modality cannot be determined
"""
if isinstance(input_data, Tensor):
return "image"

if isinstance(input_data, list):
if len(input_data) == 0:
raise ValueError("Empty input list")
if isinstance(input_data[0], Tensor):
return "image"
if isinstance(input_data[0], str):
return "text"

if isinstance(input_data, str):
return "text"

raise ValueError("Could not automatically determine modality for input_data")


def _process_image_data(images: Union[Tensor, List[Tensor]]) -> List[Tensor]:
"""Helper function to process image data."""
images = [images] if not isinstance(images, list) and images.ndim == 3 else list(images)
if not all(i.ndim == 3 for i in images):
raise ValueError("Expected all images to be 3d but found image that has either more or less")
return images


def _process_text_data(texts: Union[str, List[str]]) -> List[str]:
"""Helper function to process text data."""
if not isinstance(texts, list):
texts = [texts]
return texts


def _get_features(
data: List[Union[Tensor, str]],
modality: str,
device: torch.device,
model: "_CLIPModel",
processor: "_CLIPProcessor",
) -> Tensor:
"""Get features from the CLIP model for either images or text.
Args:
data: List of input data (images or text)
modality: String indicating the type of input data (must be either "image" or "text")
device: Device to run the model on
model: CLIP model instance
processor: CLIP processor instance
Returns:
Tensor of features from the CLIP model
Raises:
ValueError: If modality is not "image" or "text"
"""
if modality == "image":
# Add type checking for images
image_data = [i for i in data if isinstance(i, Tensor)]
processed = processor(images=[i.cpu() for i in image_data], return_tensors="pt", padding=True)
return model.get_image_features(processed["pixel_values"].to(device))
if modality == "text":
processed = processor(text=data, return_tensors="pt", padding=True)
max_position_embeddings = model.config.text_config.max_position_embeddings
if processed["attention_mask"].shape[-1] > max_position_embeddings:
rank_zero_warn(
f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
"If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
"longer sequences",
UserWarning,
)
processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
raise ValueError(f"invalid modality {modality}")


def _clip_score_update(
images: Union[Tensor, List[Tensor]],
text: Union[str, list[str]],
source: Union[Tensor, List[Tensor], List[str], str],
target: Union[Tensor, List[Tensor], List[str], str],
model: _CLIPModel,
processor: _CLIPProcessor,
) -> tuple[Tensor, int]:
if not isinstance(images, list):
if images.ndim == 3:
images = [images]
else: # unwrap into list
images = list(images)

if not all(i.ndim == 3 for i in images):
raise ValueError("Expected all images to be 3d but found image that has either more or less")
source_modality = _detect_modality(source)
target_modality = _detect_modality(target)

if not isinstance(text, list):
text = [text]
source_data = (
_process_image_data(cast(Union[Tensor, List[Tensor]], source))
if source_modality == "image"
else _process_text_data(cast(Union[str, List[str]], source))
)
target_data = (
_process_image_data(cast(Union[Tensor, List[Tensor]], target))
if target_modality == "image"
else _process_text_data(cast(Union[str, List[str]], target))
)

if len(text) != len(images):
if len(source_data) != len(target_data):
raise ValueError(
f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
)
device = images[0].device
processed_input = processor(text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True)

img_features = model.get_image_features(processed_input["pixel_values"].to(device))
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)

max_position_embeddings = model.config.text_config.max_position_embeddings
if processed_input["attention_mask"].shape[-1] > max_position_embeddings:
rank_zero_warn(
f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
"If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
"longer sequences",
UserWarning,
"Expected the number of source and target examples to be the same but got "
f"{len(source_data)} and {len(target_data)}"
)
processed_input["attention_mask"] = processed_input["attention_mask"][..., :max_position_embeddings]
processed_input["input_ids"] = processed_input["input_ids"][..., :max_position_embeddings]

txt_features = model.get_text_features(
processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
device = (
source_data[0].device
if source_modality == "image" and isinstance(source_data[0], Tensor)
else target_data[0].device
if target_modality == "image" and isinstance(target_data[0], Tensor)
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
model = model.to(device)

# cosine similarity between feature vectors
score = 100 * (img_features * txt_features).sum(axis=-1)
return score, len(text)
source_features = _get_features(
cast(List[Union[Tensor, str]], source_data), source_modality, device, model, processor
)
target_features = _get_features(
cast(List[Union[Tensor, str]], target_data), target_modality, device, model, processor
)
source_features = source_features / source_features.norm(p=2, dim=-1, keepdim=True)
target_features = target_features / target_features.norm(p=2, dim=-1, keepdim=True)

# Calculate cosine similarity
score = 100 * (source_features * target_features).sum(axis=-1)
score = score.cpu() if source_modality == "text" and target_modality == "text" else score
return score, len(source_data)
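
# Illustrative sketch (not part of this file): the score above is 100x the cosine similarity
# of the two embeddings, so for any single source/target pair it is equivalent to:
#   score_0 = 100 * torch.nn.functional.cosine_similarity(
#       source_features[0:1], target_features[0:1], dim=-1
#   )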


def _get_clip_model_and_processor(
@@ -113,20 +113,20 @@ def _get_clip_model_and_processor(


def clip_score(
images: Union[Tensor, List[Tensor]],
text: Union[str, list[str]],
source: Union[Tensor, List[Tensor], List[str], str],
target: Union[Tensor, List[Tensor], List[str], str],
model_name_or_path: Literal[
"openai/clip-vit-base-patch16",
"openai/clip-vit-base-patch32",
"openai/clip-vit-large-patch14-336",
"openai/clip-vit-large-patch14",
] = "openai/clip-vit-large-patch14",
) -> Tensor:
r"""Calculate `CLIP Score`_ which is a text-to-image similarity metric.
r"""Calculates `CLIP Score`_ which is a text-to-image similarity metric.
CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for
an image and the actual content of the image. It has been found to be highly correlated with human judgement. The
metric is defined as:
an image and the actual content of the image, as well as the similarity between texts or images. It has been found
to be highly correlated with human judgement. The metric is defined as:
.. math::
\text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)
@@ -135,15 +225,33 @@ def clip_score(
textual CLIP embedding :math:`E_C` for a caption :math:`C`. The score is bound between 0 and 100 and the closer
to 100 the better.
.. caution::
Metric is not scriptable
Additionally, the CLIP Score can be calculated between inputs of the same modality:
.. math::
\text{CLIPScore(I_1, I_2)} = max(100 * cos(E_{I_1}, E_{I_2}), 0)
where :math:`E_{I_1}` and :math:`E_{I_2}` are the visual embeddings for images :math:`I_1` and :math:`I_2`.
.. math::
\text{CLIPScore(T_1, T_2)} = max(100 * cos(E_{T_1}, E_{T_2}), 0)
where :math:`E_{T_1}` and :math:`E_{T_2}` are the textual embeddings for texts :math:`T_1` and :math:`T_2`.
.. note:: Metric is not scriptable
Args:
images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
text: Either a single caption or a list of captions
model_name_or_path: string indicating the version of the CLIP model to use. Available models are
`"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
and `"openai/clip-vit-large-patch14"`,
source: Source input. This can be:
- Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
- Text: Either a single caption or a list of captions.
target: Target input. This can be:
- Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
- Text: Either a single caption or a list of captions.
model_name_or_path: String indicating the version of the CLIP model to use. Available models are:
- `"openai/clip-vit-base-patch16"`
- `"openai/clip-vit-base-patch32"`
- `"openai/clip-vit-large-patch14-336"`
- `"openai/clip-vit-large-patch14"`
Raises:
ModuleNotFoundError:
@@ -155,13 +263,31 @@
Example:
>>> from torchmetrics.functional.multimodal import clip_score
>>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
>>> image = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
>>> score = clip_score(image, "a photo of a cat", "openai/clip-vit-base-patch16")
>>> score.detach()
tensor(24.4255)
Example:
>>> from torchmetrics.functional.multimodal import clip_score
>>> image1 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
>>> image2 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(43))
>>> score = clip_score(image1, image2, "openai/clip-vit-base-patch16")
>>> score.detach()
tensor(99.4859)
Example:
>>> from torchmetrics.functional.multimodal import clip_score
>>> score = clip_score(
... "28-year-old chef found dead in San Francisco mall",
... "A 28-year-old chef who recently moved to San Francisco was found dead.",
... "openai/clip-vit-base-patch16"
... )
>>> score.detach()
tensor(91.3950)
"""
model, processor = _get_clip_model_and_processor(model_name_or_path)
device = images.device if isinstance(images, Tensor) else images[0].device
score, _ = _clip_score_update(images, text, model.to(device), processor)
score, _ = _clip_score_update(source, target, model, processor)
score = score.mean(0)
return torch.max(score, torch.zeros_like(score))
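
A minimal usage sketch of the source/target interface added in this commit, assuming a torchmetrics build that includes this change and the `transformers` package installed; exact scores depend on the inputs and checkpoint:

import torch
from torchmetrics.functional.multimodal import clip_score

# image-to-text (original behaviour)
image = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
print(clip_score(image, "a photo of a cat", "openai/clip-vit-base-patch16"))

# image-to-image (enabled by this commit)
other = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(43))
print(clip_score(image, other, "openai/clip-vit-base-patch16"))

# text-to-text (enabled by this commit)
print(clip_score("a photo of a cat", "a picture of a cat", "openai/clip-vit-base-patch16"))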