Commit ff55d0c
🔨 Split CUB attributes and certainty and add doc
o-laurent committed Jan 15, 2025
1 parent 71cf8df commit ff55d0c
Showing 1 changed file with 37 additions and 11 deletions.
torch_uncertainty/datasets/classification/cub.py (48 changes: 37 additions & 11 deletions)
@@ -38,7 +38,8 @@ def __init__(
                 returns them instead of the images. Defaults to False.
             download (bool, optional): If True, downloads the dataset from the internet and puts it
                 in root directory. If dataset is already downloaded, it is not downloaded again.
-                Defaults to
+                Defaults to False.
+
         Reference:
             Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S. Caltech-UCSD
             Birds 200.
Expand All @@ -56,10 +57,9 @@ def __init__(
super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform)

training_idx = self._load_train_idx()
self.attributes, self.certainties = self._load_attributes()
if load_attributes:
self.samples = zip(
self._load_attributes(), [sam[1] for sam in self.samples], strict=False
)
self.samples = zip(self.attributes, [sam[1] for sam in self.samples], strict=False)
self.attribute_names = self._load_attribute_names()
self.loader = torch.nn.Identity()

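The constructor change above is easiest to see from the caller's side. Below is a minimal usage sketch, assuming the class defined in cub.py is exposed as CUB and that any constructor arguments not visible in this diff (e.g. a train/test flag) can be left at their defaults:

from torch_uncertainty.datasets.classification.cub import CUB

# Hypothetical arguments for illustration only; check cub.py for the full signature.
dataset = CUB(root="data", load_attributes=True, download=True)

# After this commit, presence and certainty live in two separate tensors
# instead of a single blended score per attribute.
print(dataset.attributes.shape)   # expected: (num_images, 312), attribute presence
print(dataset.certainties.shape)  # expected: (num_images, 312), certainty scaled to [0, 1]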
@@ -69,38 +69,64 @@ def __init__(
         self.classnames = self._load_classnames()
 
     def _load_classnames(self) -> list[str]:
+        """Load the classnames of the dataset.
+        Returns:
+            list[str]: the list containing the names of the 200 classes.
+        """
         with Path(self.folder_root / "CUB_200_2011" / "classes.txt").open("r") as f:
             return [
                 line.split(" ")[1].split(".")[1].replace("\n", "").replace("_", " ") for line in f
             ]
 
     def _load_train_idx(self) -> Tensor:
+        """Load the index of the training data to make the split.
+        Returns:
+            Tensor: whether the images belong to the training or test split.
+        """
         with (self.folder_root / "CUB_200_2011" / "train_test_split.txt").open("r") as f:
             return torch.as_tensor([int(line.split(" ")[1]) for line in f])
 
-    def _load_attributes(self) -> Tensor:
-        attributes = []
+    def _load_attributes(self) -> tuple[Tensor, Tensor]:
+        """Load the attributes associated to each image.
+        Returns:
+            tuple[Tensor, Tensor]: The presence of the 312 attributes along with their certainty.
+        """
+        attributes, certainty = [], []
         with (self.folder_root / "CUB_200_2011" / "attributes" / "image_attribute_labels.txt").open(
             "r"
         ) as f:
-            attributes = [
-                0.5 + 2 * (int(line.split(" ")[2]) - 0.5) * (int(line.split(" ")[3]) - 1) * 1 / 6
-                for line in f
-            ]
-        return rearrange(torch.as_tensor(attributes), "(n c) -> n c", c=312)
+            attributes = [int(line.split(" ")[2]) for line in f]
+            certainty = [(int(line.split(" ")[3]) - 1) / 3 for line in f]
+        return rearrange(torch.as_tensor(attributes), "(n c) -> n c", c=312), rearrange(
+            torch.as_tensor(certainty), "(n c) -> n c", c=312
+        )
 
     def _load_attribute_names(self) -> list[str]:
+        """Load the names of the attributes.
+        Returns:
+            list[str]: The list of the names of the 312 attributes.
+        """
         with (self.folder_root / "attributes.txt").open("r") as f:
             return [line.split(" ")[1].replace("\n", "").replace("_", " ") for line in f]
 
     def _check_integrity(self) -> bool:
+        """Check the integrity of the dataset.
+        Returns:
+            bool: True when the md5 of the archive corresponds.
+        """
         fpath = self.folder_root / self.filename
         return check_integrity(
             fpath,
             self.tgz_md5,
         )
 
     def _download(self):
+        """Download the dataset from caltec.edu."""
         if self._check_integrity():
             logging.info("Files already downloaded and verified")
             return
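For reference, the removed _load_attributes folded presence and certainty into one score per attribute. The mapping between that old value and the new split follows directly from the two formulas in the diff above; a small self-contained check (plain arithmetic, not part of the library):

# CUB-200-2011 attribute annotations: a in {0, 1} marks presence,
# c in {1, 2, 3, 4} is the annotator's certainty rating.
def old_combined_score(a: int, c: int) -> float:
    """Single blended score used before this commit."""
    return 0.5 + 2 * (a - 0.5) * (c - 1) * 1 / 6

def recombine(a: int, certainty: float) -> float:
    """Recover the old score from the new pair, where certainty = (c - 1) / 3."""
    return 0.5 + (2 * a - 1) * certainty / 2

for a in (0, 1):
    for c in (1, 2, 3, 4):
        assert abs(old_combined_score(a, c) - recombine(a, (c - 1) / 3)) < 1e-9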
