Skip to content

Commit

Permalink
Add check if path is vsi (#1612)
Browse files Browse the repository at this point in the history
* Add check if path is vsi

* Add url to reference for apache vsi syntax

* Add missing check to if

* Copy rasterio SCHEMES definition into torchgeo

* Check all schemes, not only last

* Simplify method path_is_vsi

* Add tests

* Remove print

* Update test names

* Add missing comma in list

* Update torchgeo/datasets/utils.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Update torchgeo/datasets/utils.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Use pytest tmp_path for test

* Warn if some of input paths are invalid

* Update docstring for mocked class

* Handle tests failing due to UserWarning

* Remove unnecessary filterwarning

* Test CustomGeoDataset instead of MockRasterDataset

* Merge two similar tests

* str instead of as_posix

Wait with pathlib syntax

Co-authored-by: Adam J. Stewart <[email protected]>

---------

Co-authored-by: Adrian Tofting <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
3 people authored and nilsleh committed Nov 6, 2023
1 parent bb3d77a commit 5c4c716
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
21 changes: 20 additions & 1 deletion tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pickle
from collections.abc import Iterable
from pathlib import Path
from typing import Union
from typing import Optional, Union

import pytest
import torch
Expand Down Expand Up @@ -33,11 +33,13 @@ def __init__(
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
paths: Optional[Union[str, Iterable[str]]] = None,
) -> None:
super().__init__()
self.index.insert(0, tuple(bounds))
self._crs = crs
self.res = res
self.paths = paths or []

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
hits = self.index.intersection(tuple(query), objects=True)
Expand Down Expand Up @@ -152,6 +154,23 @@ def test_and_nongeo(self, dataset: GeoDataset) -> None:
):
dataset & ds2 # type: ignore[operator]

def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None:
    """Paths that yield no files leave ``files`` empty and trigger a warning.

    The dataset is given the temporary directory itself plus a file name
    inside it that is never created. Accessing ``files`` is expected to
    raise a ``UserWarning`` matching "Path was ignored." and to return an
    empty collection.
    """
    missing_file = tmp_path / "non_existing_file.tif"
    candidate_paths = [str(tmp_path), str(missing_file)]
    with pytest.warns(UserWarning, match="Path was ignored."):
        dataset = CustomGeoDataset(paths=candidate_paths)
        assert len(dataset.files) == 0

def test_files_property_for_virtual_files(self) -> None:
    """Every virtual-file-system style path must be kept by ``files``.

    None of these paths exist on the local disk; the test only checks that
    the number of entries in ``files`` equals the number of paths supplied.
    """
    # Tests only a subset of schemes and combinations.
    virtual_paths = [
        "file://directory/file.tif",
        "zip://archive.zip!folder/file.tif",
        "az://azure_bucket/prefix/file.tif",
        "/vsiaz/azure_bucket/prefix/file.tif",
        "zip+az://azure_bucket/prefix/archive.zip!folder_in_archive/file.tif",
        "/vsizip//vsiaz/azure_bucket/prefix/archive.zip/folder_in_archive/file.tif",
    ]
    dataset = CustomGeoDataset(paths=virtual_paths)
    assert len(dataset.files) == len(virtual_paths)


class TestRasterDataset:
@pytest.fixture(params=zip([["R", "G", "B"], None], [True, False]))
Expand Down
17 changes: 15 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import re
import sys
import warnings
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast

Expand All @@ -29,7 +30,13 @@
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader as pil_loader

from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples
from .utils import (
BoundingBox,
concat_samples,
disambiguate_timestamp,
merge_samples,
path_is_vsi,
)


class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
Expand Down Expand Up @@ -298,8 +305,14 @@ def files(self) -> set[str]:
if os.path.isdir(path):
pathname = os.path.join(path, "**", self.filename_glob)
files |= set(glob.iglob(pathname, recursive=True))
else:
elif os.path.isfile(path) or path_is_vsi(path):
files.add(path)
else:
warnings.warn(
f"Could not find any relevant files for provided path '{path}'. "
f"Path was ignored.",
UserWarning,
)

return files

Expand Down
24 changes: 24 additions & 0 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,3 +737,27 @@ def percentile_normalization(
(img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1
)
return img_normalized


def path_is_vsi(path: str) -> bool:
    """Checks if the given path is pointing to a Virtual File System.

    .. note::
       Does not check if the path exists, or if it is a dir or file.

    VSI can for instance be Cloud Storage Blobs or zip-archives.
    They will start with a prefix indicating this.
    For examples of these, see references for the two accepted syntaxes.

    * https://gdal.org/user/virtual_file_systems.html
    * https://rasterio.readthedocs.io/en/latest/topics/datasets.html

    Args:
        path: string representing a directory or file

    Returns:
        True if path is on a virtual file system, else False

    .. versionadded:: 0.6
    """
    # Two syntaxes are accepted: a URL-like scheme ("scheme://...", possibly
    # chained such as "zip+az://...") and the "/vsi..." prefix form.
    # This is a purely textual check; nothing is opened or resolved.
    return "://" in path or path.startswith("/vsi")

0 comments on commit 5c4c716

Please sign in to comment.