[NOMRG] TransformsV2 TODOs #7082

Closed · wants to merge 2 commits
Changes from 1 commit
1 change: 1 addition & 0 deletions torchvision/prototype/datapoints/__init__.py
@@ -1,5 +1,6 @@
 from ._bounding_box import BoundingBox, BoundingBoxFormat
-from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
+from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT  # TODO: these may not need to be public?
Collaborator:

Most likely not. We only introduced them for internal convenience. Whenever you see something annotated with *JIT, you can basically look up the eager counterpart to see what is actually supported.
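For illustration, a minimal sketch of how such an eager/JIT alias pair relates, using the FillTypeJIT definition visible in _datapoint.py below; the FillType definition here is an assumption, not a quote from the source:

```python
from typing import List, Sequence, Union

# Eager alias: the full set of values users may pass (assumed definition).
FillType = Union[int, float, Sequence[int], Sequence[float], None]
# JIT alias: narrowed to types torch.jit.script can handle in annotations.
FillTypeJIT = Union[int, float, List[float], None]
```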

from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel
from ._mask import Mask
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/datapoints/_bounding_box.py
@@ -16,8 +16,8 @@ class BoundingBoxFormat(StrEnum):


 class BoundingBox(Datapoint):
-    format: BoundingBoxFormat
-    spatial_size: Tuple[int, int]
+    format: BoundingBoxFormat  # TODO: do not use a builtin?
+    spatial_size: Tuple[int, int]  # TODO: This is the size of the image, not the box
pmeier marked this conversation as resolved.

    @classmethod
    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox:
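To make the spatial_size TODO concrete, a hypothetical construction; the values and the exact constructor signature are assumptions based on the annotations above:

```python
import torch
from torchvision.prototype import datapoints

# spatial_size is the (height, width) of the image the box lives in,
# not the size of the box itself.
box = datapoints.BoundingBox(
    [10, 20, 50, 80],  # one box in XYXY coordinates within the image
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(480, 640),  # image height and width
)
```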
9 changes: 9 additions & 0 deletions torchvision/prototype/datapoints/_datapoint.py
@@ -15,6 +15,8 @@
FillTypeJIT = Union[int, float, List[float], None]


+# TODO: provide a few examples of when the Datapoint type is preserved vs when it's not
+# test_prototype_datapoints.py is a good starting point
 class Datapoint(torch.Tensor):
     __F: Optional[ModuleType] = None
@@ -89,6 +91,7 @@ def __torch_function__(

        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())
+           # TODO: maybe we can exit the CM here?

        wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
        # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
@@ -98,8 +101,14 @@
        # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
        # be wrapped into a `datapoints.Image`.
        if wrapper and isinstance(args[0], cls):
+           # TODO: figure out whether
+           #   arbitrary_tensor.to(some_img)
+           # should be an Image or a Tensor
            return wrapper(cls, args[0], output)  # type: ignore[no-any-return]

+       # Does that mean that DisableTorchFunctionSubclass is ignored for in-place (`.add_()`-style) functions?
+       # TODO: figure out with torch core whether this is a bug or not

        # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
        # will retain the input type. Thus, we need to unwrap here.
        if isinstance(output, cls):
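A sketch of the two cases flagged above, assuming the current behavior described in the comments (not part of the diff):

```python
import torch
from torchvision.prototype import datapoints

img = datapoints.Image(torch.rand(3, 8, 8))
plain = torch.rand(3, 8, 8)

# `.to()` is a wrapping exception, but args[0] here is a plain tensor, so the
# `isinstance(args[0], cls)` guard keeps the output a Tensor. The TODO above
# asks whether it should become an Image instead.
assert type(plain.to(img)) is torch.Tensor

# In-place ops retain the input type internally; the final unwrapping branch
# returns them as plain tensors.
assert type(img.add_(1)) is torch.Tensor
```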
2 changes: 2 additions & 0 deletions torchvision/prototype/datapoints/_image.py
@@ -58,6 +58,8 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:


 class Image(Datapoint):
+    # For now, this is somewhat redundant with number of channels.
+    # TODO: decide whether we want to keep it?
Comment on lines +61 to +62

Collaborator:
For context, this was designed with the use case of other color spaces with the same number of channels in mind, e.g. #4029. Since we can always add this later on, I see no harm in removing it now. That being said, I would keep the ColorSpace enum for […], which is a generalized replacement for

def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:

and […].
Collaborator:
While going through our source again, I noticed this:

if isinstance(inpt, torch.Tensor) and (
    torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
    if old_color_space is None:
        raise RuntimeError(
            "In order to convert the color space of simple tensors, "
            "the `old_color_space=...` parameter needs to be passed."
        )
    return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)

As long as we can accurately identify the color space from the number of channels, we should either remove the old_color_space parameter completely or at least remove this error and set the value ourselves with

def _from_tensor_shape(shape: List[int]) -> ColorSpace:
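For illustration, a hedged sketch of what channel-count inference along the lines of _from_tensor_shape could look like; the mapping below is an assumption, not the actual implementation:

```python
from typing import List

def infer_color_space_from_shape(shape: List[int]) -> str:
    # Hypothetical helper: read the channel dimension of a (*, C, H, W)
    # tensor shape and map it to a color space name.
    num_channels = shape[-3] if len(shape) >= 3 else 1
    return {1: "GRAY", 2: "GRAY_ALPHA", 3: "RGB", 4: "RGBA"}.get(num_channels, "OTHER")

infer_color_space_from_shape([3, 224, 224])  # -> "RGB"
```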

    color_space: ColorSpace

    @classmethod
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/_transform.py
@@ -35,6 +35,8 @@ def forward(self, *inputs: Any) -> Any:

        params = self._get_params(flat_inputs)

+       # TODO: right now, any tensor or datapoint passed to forward() will be transformed;
+       # the rest is bypassed
pmeier marked this conversation as resolved.
        flat_outputs = [
            self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
        ]
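A self-contained sketch of the forward() dispatch the TODO describes; check_type and self._transform are simplified here to isinstance and a stand-in function:

```python
import torch

def transform(inpt):  # stand-in for self._transform(inpt, params)
    return inpt.flip(-1)

flat_inputs = [torch.rand(3, 4, 4), "a label", 42]
transformed_types = (torch.Tensor,)

flat_outputs = [
    transform(inpt) if isinstance(inpt, transformed_types) else inpt
    for inpt in flat_inputs
]
# The tensor is transformed; the str and the int are passed through untouched.
```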
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -244,6 +244,8 @@ def resize(
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
    ):
+       # TODO: Figure out whether this cond could just be:
+       #   if torch.jit.is_scripting() or is_simple_tensor(inpt):
Comment on lines +247 to +248

Collaborator:
This works out 🚀 Gonna send a PR soon.

Collaborator:
See #7084

        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
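For reference, a sketch of what the is_simple_tensor check mentioned in the TODO amounts to, assuming it means "a plain tensor that is not a Datapoint subclass":

```python
import torch
from torchvision.prototype import datapoints

def is_simple_tensor(inpt) -> bool:
    # True for plain tensors, False for Datapoint subclasses like Image.
    return isinstance(inpt, torch.Tensor) and not isinstance(
        inpt, datapoints._datapoint.Datapoint
    )
```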