diff --git a/torch_uncertainty/datasets/classification/cub.py b/torch_uncertainty/datasets/classification/cub.py index 69606a49..0924b40e 100644 --- a/torch_uncertainty/datasets/classification/cub.py +++ b/torch_uncertainty/datasets/classification/cub.py @@ -38,7 +38,8 @@ def __init__( returns them instead of the images. Defaults to False. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. - Defaults to + Defaults to False. + Reference: Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S. Caltech-UCSD Birds 200. @@ -56,10 +57,9 @@ def __init__( super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform) training_idx = self._load_train_idx() + self.attributes, self.certainties = self._load_attributes() if load_attributes: - self.samples = zip( - self._load_attributes(), [sam[1] for sam in self.samples], strict=False - ) + self.samples = zip(self.attributes, [sam[1] for sam in self.samples], strict=False) self.attribute_names = self._load_attribute_names() self.loader = torch.nn.Identity() @@ -69,31 +69,56 @@ def __init__( self.classnames = self._load_classnames() def _load_classnames(self) -> list[str]: + """Load the classnames of the dataset. + + Returns: + list[str]: the list containing the names of the 200 classes. + """ with Path(self.folder_root / "CUB_200_2011" / "classes.txt").open("r") as f: return [ line.split(" ")[1].split(".")[1].replace("\n", "").replace("_", " ") for line in f ] def _load_train_idx(self) -> Tensor: + """Load the index of the training data to make the split. + + Returns: + Tensor: whether the images belong to the training or test split. + """ with (self.folder_root / "CUB_200_2011" / "train_test_split.txt").open("r") as f: return torch.as_tensor([int(line.split(" ")[1]) for line in f]) - def _load_attributes(self) -> Tensor: - attributes = [] + def _load_attributes(self) -> tuple[Tensor, Tensor]: + """Load the attributes associated to each image. + + Returns: + tuple[Tensor, Tensor]: The presence of the 312 attributes along with their certainty. + """ + attributes, certainty = [], [] with (self.folder_root / "CUB_200_2011" / "attributes" / "image_attribute_labels.txt").open( "r" ) as f: - attributes = [ - 0.5 + 2 * (int(line.split(" ")[2]) - 0.5) * (int(line.split(" ")[3]) - 1) * 1 / 6 - for line in f - ] - return rearrange(torch.as_tensor(attributes), "(n c) -> n c", c=312) + attributes = [int(line.split(" ")[2]) for line in f] + certainty = [(int(line.split(" ")[3]) - 1) / 3 for line in f] + return rearrange(torch.as_tensor(attributes), "(n c) -> n c", c=312), rearrange( + torch.as_tensor(certainty), "(n c) -> n c", c=312 + ) def _load_attribute_names(self) -> list[str]: + """Load the names of the attributes. + + Returns: + list[str]: The list of the names of the 312 attributes. + """ with (self.folder_root / "attributes.txt").open("r") as f: return [line.split(" ")[1].replace("\n", "").replace("_", " ") for line in f] def _check_integrity(self) -> bool: + """Check the integrity of the dataset. + + Returns: + bool: True when the md5 of the archive corresponds. + """ fpath = self.folder_root / self.filename return check_integrity( fpath, @@ -101,6 +126,7 @@ def _check_integrity(self) -> bool: ) def _download(self): + """Download the dataset from caltec.edu.""" if self._check_integrity(): logging.info("Files already downloaded and verified") return