diff --git a/detectron2/engine/defaults.py b/detectron2/engine/defaults.py index c649bf8ff7..0b699666c2 100644 --- a/detectron2/engine/defaults.py +++ b/detectron2/engine/defaults.py @@ -16,6 +16,7 @@ import weakref from collections import OrderedDict from typing import Optional +import numpy as np import torch from fvcore.nn.precise_bn import get_bn_modules from omegaconf import OmegaConf @@ -294,30 +295,49 @@ def __init__(self, cfg): self.input_format = cfg.INPUT.FORMAT assert self.input_format in ["RGB", "BGR"], self.input_format - def __call__(self, original_image): + def __call__(self, original_images): """ Args: - original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). - + original_images (np.ndarray or List[np.ndarray]): + an image of shape (H, W, C) or (B, H, W, C) (in BGR order). Returns: - predictions (dict): + predictions (dict or List[dict]): the output of the model for one image only. See :doc:`/tutorials/models` for details about the format. """ with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 - # Apply pre-processing to image. - if self.input_format == "RGB": - # whether the model expects BGR inputs or RGB - original_image = original_image[:, :, ::-1] - height, width = original_image.shape[:2] - image = self.aug.get_transform(original_image).apply_image(original_image) - image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) - image.to(self.cfg.MODEL.DEVICE) - - inputs = {"image": image, "height": height, "width": width} - - predictions = self.model([inputs])[0] - return predictions + if isinstance(original_images, np.array): + original_image = original_images + # Apply pre-processing to image. + if self.input_format == "RGB": + # whether the model expects BGR inputs or RGB + original_image = original_image[:, :, ::-1] + height, width = original_image.shape[:2] + image = self.aug.get_transform(original_image).apply_image(original_image) + image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) + image.to(self.cfg.MODEL.DEVICE) + + inputs = {"image": image, "height": height, "width": width} + + predictions = self.model([inputs])[0] + return predictions + elif isinstance(original_images, list): + batch_inputs = [] + for original_image in original_images: + if self.input_format == "RGB": + # whether the model expects BGR inputs or RGB + original_image = original_image[:, :, ::-1] + height, width = original_image.shape[:2] + image = self.aug.get_transform(original_image).apply_image(original_image) + image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) + image.to(self.cfg.MODEL.DEVICE) + + inputs = {"image": image, "height": height, "width": width} + batch_inputs.append(inputs) + predictions = self.model(batch_inputs) + return predictions + else: + return None class DefaultTrainer(TrainerBase):