diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py index 92f345e20bd..fa36302e633 100644 --- a/torchvision/prototype/datapoints/__init__.py +++ b/torchvision/prototype/datapoints/__init__.py @@ -1,5 +1,6 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT +from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT # TODO: these may not need to be public? from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._label import Label, OneHotLabel from ._mask import Mask diff --git a/torchvision/prototype/datapoints/_bounding_box.py b/torchvision/prototype/datapoints/_bounding_box.py index 398770cbf6a..cc5cfe88fb4 100644 --- a/torchvision/prototype/datapoints/_bounding_box.py +++ b/torchvision/prototype/datapoints/_bounding_box.py @@ -16,8 +16,8 @@ class BoundingBoxFormat(StrEnum): class BoundingBox(Datapoint): - format: BoundingBoxFormat - spatial_size: Tuple[int, int] + format: BoundingBoxFormat # TODO: do not use a builtin? + spatial_size: Tuple[int, int] # TODO: This is the size of the image, not the box. Should we make the name more obvious? 
@classmethod def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox: diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py index 659d4e958cc..b5009a9c354 100644 --- a/torchvision/prototype/datapoints/_datapoint.py +++ b/torchvision/prototype/datapoints/_datapoint.py @@ -15,6 +15,8 @@ FillTypeJIT = Union[int, float, List[float], None] +# TODO: provide a few examples of when the Datapoint type is preserved vs when it's not +# test_prototype_datapoints.py is a good starting point class Datapoint(torch.Tensor): __F: Optional[ModuleType] = None @@ -89,6 +91,7 @@ def __torch_function__( with DisableTorchFunctionSubclass(): output = func(*args, **kwargs or dict()) + # TODO: maybe we can exit the CM here? wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be @@ -98,8 +101,14 @@ # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would # be wrapped into a `datapoints.Image`. if wrapper and isinstance(args[0], cls): + # TODO: figure out whether + # arbitrary_tensor.to(some_img) + # should be an Image or a Tensor return wrapper(cls, args[0], output) # type: ignore[no-any-return] + # Does that mean that DisableTorchFunctionSubclass is ignored for `.inplace_()` functions? + # TODO: figure out with torch core whether this is a bug or not + # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`, # will retain the input type. Thus, we need to unwrap here. 
if isinstance(output, cls): diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index fc20691100f..971e606ea43 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -58,6 +58,8 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace: class Image(Datapoint): + # For now, this is somewhat redundant with number of channels. + # TODO: decide whether we want to keep it? color_space: ColorSpace @classmethod diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 43224cabd38..ff00b023bdc 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -35,6 +35,9 @@ def forward(self, *inputs: Any) -> Any: params = self._get_params(flat_inputs) + # TODO: right now, any tensor or datapoint passed to forward() will be transformed. + # The rest is bypassed. + # What if there are tensor parameters that users don't want to be transformed? Is that a plausible scenario? 
flat_outputs = [ self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs ] diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ba417a0ce84..75dea0f3f60 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -244,6 +244,8 @@ def resize( if isinstance(inpt, torch.Tensor) and ( torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) ): + # TODO: Figure out whether this cond could just be: + # if torch.jit.is_scripting() or is_simple_tensor(inpt): return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)