Refactoring of ImageProcessorFast #35069

Merged

Commits (59 total; the diff below shows changes from 2 commits)

2f00f0c
add init and base image processing functions
yonigozlan Dec 3, 2024
cfadb72
add add_fast_image_processor to transformers-cli
yonigozlan Dec 3, 2024
2cd73cb
add working fast image processor clip
yonigozlan Dec 3, 2024
932bd68
add fast image processor to doc, working tests
yonigozlan Dec 4, 2024
23d79ce
remove "to be implemented" SigLip
yonigozlan Dec 4, 2024
3f2d8a6
fix unprotected import
yonigozlan Dec 4, 2024
6a9d332
fix unprotected vision import
yonigozlan Dec 4, 2024
a1e2663
update ViTImageProcessorFast
yonigozlan Dec 4, 2024
fa74e7e
increase threshold slow fast equivalence
yonigozlan Dec 4, 2024
9dbd765
add fast img blip
yonigozlan Dec 4, 2024
d39ff52
add fast class in tests with cli
yonigozlan Dec 4, 2024
f609730
improve cli
yonigozlan Dec 5, 2024
8f7774d
add fast image processor convnext
yonigozlan Dec 6, 2024
809e1f0
add LlavaPatchingMixin and fast image processor for llava_next and ll…
yonigozlan Dec 7, 2024
f6e6cc2
add device kwarg to ImagesKwargs for fast processing on cuda
yonigozlan Dec 9, 2024
e1ce148
cleanup
yonigozlan Dec 9, 2024
a24d89c
fix unprotected import
yonigozlan Dec 9, 2024
522e200
group images by sizes and add batch processing
yonigozlan Dec 11, 2024
deefc5a
Add batch equivalence tests, skip when center_crop is used
yonigozlan Dec 11, 2024
6a2478e
cleanup
yonigozlan Dec 11, 2024
7d76305
update init and cli
yonigozlan Dec 11, 2024
142ed25
fix-copies
yonigozlan Dec 11, 2024
75bf56f
refactor convnext, cleanup base
yonigozlan Dec 16, 2024
de1fa18
fix
yonigozlan Dec 16, 2024
2ffc41d
remove patching mixins, add piped torchvision transforms for ViT
yonigozlan Dec 17, 2024
b524406
fix unbatched processing
yonigozlan Dec 17, 2024
9c2e2a4
fix f strings
yonigozlan Dec 17, 2024
8c773e0
protect imports
yonigozlan Dec 17, 2024
90fceba
change llava onevision to class transforms (test)
yonigozlan Dec 18, 2024
e878bdd
fix convnext
yonigozlan Dec 18, 2024
57acb7e
improve formatting (following Pavel review)
yonigozlan Jan 6, 2025
2a25104
fix handling device arg
yonigozlan Jan 6, 2025
4784fc8
improve cli
yonigozlan Jan 6, 2025
3ccd291
fix
yonigozlan Jan 6, 2025
053cdcb
fix inits
yonigozlan Jan 16, 2025
1b45e6e
Merge remote-tracking branch 'upstream/main' into improve-fast-image-…
yonigozlan Jan 21, 2025
9246945
Add distinction between preprocess and _preprocess, and support for a…
yonigozlan Jan 21, 2025
6ccd230
uniformize qwen2_vl fast
yonigozlan Jan 22, 2025
c4b8389
fix docstrings
yonigozlan Jan 22, 2025
e5c1e01
add add fast image processor llava
yonigozlan Jan 22, 2025
aef2fb4
remove min_pixels max_pixels from accepted size
yonigozlan Jan 22, 2025
7078a14
nit
yonigozlan Jan 22, 2025
aa94873
nit
yonigozlan Jan 22, 2025
13a125b
refactor fast image processors docstrings
yonigozlan Jan 28, 2025
8adb893
Merge remote-tracking branch 'upstream/main' into improve-fast-image-…
yonigozlan Jan 28, 2025
67d65f2
cleanup and remove fast class transforms
yonigozlan Jan 28, 2025
d225448
update add fast image processor transformers cli
yonigozlan Jan 28, 2025
80c6824
cleanup docstring
yonigozlan Jan 28, 2025
b96adfa
Merge remote-tracking branch 'upstream/main' into improve-fast-image-…
yonigozlan Jan 30, 2025
dbaacd1
uniformize pixtral fast and make _process_image explicit
yonigozlan Jan 30, 2025
b660e9d
Merge remote-tracking branch 'upstream/main' into improve-fast-image-…
yonigozlan Jan 30, 2025
b43ede1
fix prepare image structure llava next/onevision
yonigozlan Jan 30, 2025
3b05cbd
Use typed kwargs instead of explicit args
yonigozlan Feb 4, 2025
95db4a9
nit fix import Unpack
yonigozlan Feb 4, 2025
d9e1fcd
Merge branch 'main' into improve-fast-image-processor-base
yonigozlan Feb 4, 2025
6bd7a1b
clearly separate pops and gets in base preprocess. Use explicit typed…
yonigozlan Feb 4, 2025
565e482
Merge branch 'main' into improve-fast-image-processor-base
yonigozlan Feb 4, 2025
f85c06f
make qwen2_vl preprocess arguments hashable
yonigozlan Feb 4, 2025
1a7b0c4
Merge branch 'improve-fast-image-processor-base' of https://github.co…
yonigozlan Feb 4, 2025
171 changes: 75 additions & 96 deletions src/transformers/image_processing_utils_fast.py
@@ -15,7 +15,7 @@

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypedDict, Union

import numpy as np

@@ -44,6 +44,7 @@
validate_fast_preprocess_arguments,
validate_kwargs,
)
from .processing_utils import Unpack
from .utils import (
TensorType,
add_start_docstrings,
@@ -126,6 +127,28 @@ def divide_to_patches(
return patches


class DefaultFastImageProcessorInitKwargs(TypedDict, total=False):
do_resize: Optional[bool]
size: Optional[Dict[str, int]]
default_to_square: Optional[bool]
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]]
do_center_crop: Optional[bool]
crop_size: Optional[Dict[str, int]]
do_rescale: Optional[bool]
rescale_factor: Optional[Union[int, float]]
do_normalize: Optional[bool]
image_mean: Optional[Union[float, List[float]]]
image_std: Optional[Union[float, List[float]]]
do_convert_rgb: Optional[bool]


class DefaultFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorInitKwargs):
return_tensors: Optional[Union[str, TensorType]]
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
device: Optional["torch.device"]


BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""

Args:
@@ -228,59 +251,35 @@ class BaseImageProcessorFast(BaseImageProcessor):
do_resize = None
do_center_crop = None
do_rescale = None
rescale_factor = 1 / 255
do_normalize = None
do_convert_rgb = None
model_input_names = ["pixel_values"]
valid_extra_kwargs = []
valid_init_kwargs = DefaultFastImageProcessorInitKwargs
valid_preprocess_kwargs = DefaultFastImageProcessorPreprocessKwargs

def __init__(
self,
do_resize: bool = None,
size: Dict[str, int] = None,
default_to_square: bool = None,
resample: Union["PILImageResampling", "F.InterpolationMode"] = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = None,
image_mean: Union[float, List[float]] = None,
image_std: Union[float, List[float]] = None,
do_convert_rgb: bool = None,
**kwargs,
**kwargs: Unpack[valid_init_kwargs],
) -> None:
size = size if size is not None else self.size
default_to_square = (
default_to_square
if default_to_square is not None
else self.default_to_square
if self.default_to_square is not None
else True
super().__init__(**kwargs)
size = kwargs.pop("size", self.size)
self.default_to_square = (
self.default_to_square if self.default_to_square is not None else True
) # compatibility with slow processors
self.size = (
get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
if size is not None
else None
)
size = get_size_dict(size, default_to_square=default_to_square) if size is not None else None
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
for key in self.valid_extra_kwargs:
crop_size = kwargs.pop("crop_size", self.crop_size)
self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
for key in self.valid_init_kwargs.__annotations__.keys():
kwarg = kwargs.pop(key, None)
if kwarg is not None:
setattr(self, key, kwarg)
else:
setattr(self, key, getattr(self, key, None))
if kwargs:
logger.warning_once(f"Found kwargs that are not in valid_extra_kwargs: {kwargs.keys()}")

super().__init__(**kwargs)
self.do_resize = do_resize if do_resize is not None else self.do_resize
self.size = size if size is not None else self.size
self.resample = resample if resample is not None else self.resample
self.do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
self.crop_size = crop_size if crop_size is not None else self.crop_size
self.do_rescale = do_rescale if do_rescale is not None else self.do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize if do_normalize is not None else self.do_normalize
self.image_mean = image_mean if image_mean is not None else self.image_mean
self.image_std = image_std if image_std is not None else self.image_std
self.do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

def resize(
self,
@@ -557,81 +556,61 @@ def _prepare_process_arguments(
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[int] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
**kwargs,
**kwargs: Unpack[valid_preprocess_kwargs],
) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs)

do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
default_to_square = kwargs.pop(
"default_to_square", self.default_to_square if self.default_to_square is not None else True
validate_kwargs(
captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_preprocess_kwargs.__annotations__.keys()
)
for kwarg in self.valid_preprocess_kwargs.__annotations__.keys():
kwargs.setdefault(kwarg, getattr(self, kwarg, None))

size = kwargs.pop("size", self.size)
size = (
get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
if size is not None
else None
)
size = get_size_dict(size=size, default_to_square=default_to_square) if size is not None else None
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = kwargs.pop("crop_size", self.crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
kwargs_dict = {
kwarg: kwargs.pop(kwarg) if kwargs.get(kwarg) is not None else getattr(self, kwarg, None)
for kwarg in self.valid_extra_kwargs
}
data_format = kwargs.pop("data_format", None)
data_format = data_format if data_format is not None else ChannelDimension.FIRST

images = self._prepare_input_images(
Collaborator:
Here you will have an issue: you are popping from the kwargs too many times, so you pop and then you use self, and the values will be different. This should be simplified: just update self with the kwargs, then do the rest always with self, something like that.
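
A minimal sketch of one way to read this suggestion (illustrative only, not the merged code): resolve every accepted kwarg exactly once against the instance defaults, then read only the resolved values, so pop/get ordering can never change a result. A local dict is used here instead of writing onto self, to avoid mutating the processor's defaults between calls.

def preprocess(self, images, **kwargs):
    # Resolve each accepted kwarg exactly once: an explicit caller value wins,
    # otherwise fall back to the attribute set at __init__ time.
    resolved = {}
    for key in self.valid_preprocess_kwargs.__annotations__:
        value = kwargs.pop(key, None)
        resolved[key] = value if value is not None else getattr(self, key, None)
    # From here on, read only from `resolved`; mixing kwargs.pop() and
    # kwargs.get() on the same key makes the result depend on call order.
    ...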

images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
do_convert_rgb=kwargs.pop("do_convert_rgb", self.do_convert_rgb),
input_data_format=kwargs.pop("input_data_format", None),
device=kwargs.pop("device", None),
)

image_mean, image_std, size, crop_size, interpolation = self._prepare_process_arguments(
do_resize=do_resize,
do_resize=kwargs.get("do_resize", self.do_resize),
size=size,
resample=resample,
resample=kwargs.pop("resample", self.resample),
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
data_format=data_format,
do_rescale=kwargs.get("do_rescale", self.do_rescale),
rescale_factor=kwargs.get("rescale_factor", self.rescale_factor),
do_normalize=kwargs.get("do_normalize", self.do_normalize),
image_mean=kwargs.pop("image_mean", self.image_mean),
image_std=kwargs.pop("image_std", self.image_std),
return_tensors=kwargs.get("return_tensors", None),
data_format=kwargs.pop("data_format", ChannelDimension.FIRST),
device=images[0].device,
)

return self._preprocess(
images=images,
do_resize=do_resize,
do_resize=kwargs.pop("do_resize", self.do_resize),
size=size,
interpolation=interpolation,
do_center_crop=do_center_crop,
do_center_crop=kwargs.pop("do_center_crop", self.do_center_crop),
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
do_rescale=kwargs.pop("do_rescale", self.do_rescale),
rescale_factor=kwargs.pop("rescale_factor", self.rescale_factor),
do_normalize=kwargs.pop("do_normalize", self.do_normalize),
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
**kwargs_dict,
return_tensors=kwargs.pop("return_tensors", None),
**kwargs,
)

def _preprocess(
55 changes: 18 additions & 37 deletions src/transformers/models/convnext/image_processing_convnext_fast.py
@@ -21,6 +21,8 @@
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
group_images_by_shape,
reorder_images,
)
@@ -29,8 +31,10 @@
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
@@ -50,6 +54,14 @@
from torchvision.transforms import functional as F


class ConvNextFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
crop_pct: Optional[float]


class ConvNextFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
crop_pct: Optional[float]


@add_start_docstrings(
r"Constructs a fast ConvNeXT image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
@@ -69,42 +81,11 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
do_rescale = True
do_normalize = True
crop_pct = 224 / 256
valid_extra_kwargs = ["crop_pct"]
valid_init_kwargs = ConvNextFastImageProcessorInitKwargs
valid_preprocess_kwargs = ConvNextFastImageProcessorPreprocessKwargs

def __init__(
self,
do_resize: bool = None,
size: Dict[str, int] = None,
default_to_square: bool = None,
resample: Union[PILImageResampling, "F.InterpolationMode"] = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = None,
image_mean: Union[float, List[float]] = None,
image_std: Union[float, List[float]] = None,
do_convert_rgb: bool = None,
# Additional arguments
crop_pct=None,
**kwargs,
):
super().__init__(
do_resize=do_resize,
size=size,
default_to_square=default_to_square,
resample=resample,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_convert_rgb=do_convert_rgb,
crop_pct=crop_pct,
**kwargs,
)
def __init__(self, **kwargs: Unpack[valid_init_kwargs]):
Collaborator:
Suggested change
def __init__(self, **kwargs: Unpack[valid_init_kwargs]):
def __init__(self, **kwargs: Unpack[ConvNextFastImageProcessorInitKwargs]):

This is better and explicit.
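
A note on the design choice (an inference, not from the PR thread): `Unpack[valid_init_kwargs]` happens to evaluate at runtime because the name is in the class namespace, but static type checkers are unlikely to follow a class-attribute indirection when resolving the annotation. Naming the TypedDict directly is what lets checkers and IDEs surface the accepted keys, on top of being clearer to readers.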

super().__init__(**kwargs)

@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
@@ -114,8 +95,8 @@ def __init__(
overridden by `crop_pct` in the `preprocess` method.
""",
)
def preprocess(self, *args, **kwargs) -> BatchFeature:
return super().preprocess(*args, **kwargs)
def preprocess(self, images: ImageInput, **kwargs: Unpack[valid_preprocess_kwargs]) -> BatchFeature:
Collaborator:
Same comment: let's be explicit if possible.
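
Applied to this signature, the explicit form would presumably read (hypothetical rendering of the same suggestion, using the `ConvNextFastImageProcessorPreprocessKwargs` TypedDict defined earlier in this file):

def preprocess(self, images: ImageInput, **kwargs: Unpack[ConvNextFastImageProcessorPreprocessKwargs]) -> BatchFeature:
    return super().preprocess(images, **kwargs)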

return super().preprocess(images, **kwargs)

def resize(
self,