support new directory structure created by radiant-mlhub>0.5
SpontaneousDuck committed Feb 9, 2023
1 parent d897e33 commit 2286ec1
Showing 14 changed files with 9 additions and 33 deletions.
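
For orientation, here are the two on-disk layouts involved, reconstructed from the paths visible in the diffs below (only top-level entries are shown; the contents of the source and labels folders are unchanged by this commit).

Old layout, radiant-mlhub>=0.2.1,<0.5 (two tarballs downloaded into root and extracted in place):

    root/
    ├── nasa_marine_debris_source.tar.gz
    ├── nasa_marine_debris_labels.tar.gz
    ├── nasa_marine_debris_source/
    └── nasa_marine_debris_labels/

New layout, radiant-mlhub>0.5 (a single already-extracted dataset directory nested under root):

    root/
    └── nasa_marine_debris/
        ├── nasa_marine_debris_source/
        └── nasa_marine_debris_labels/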
environment.yml (2 changes: 1 addition & 1 deletion)

@@ -38,7 +38,7 @@ dependencies:
   - pytorch-lightning>=1.5.1
   - git+https://github.com/pytorch/pytorch_sphinx_theme
   - pyupgrade>=2.4
-  - radiant-mlhub>=0.2.1,<0.5
+  - radiant-mlhub>0.5
   - rtree>=1
   - scikit-image>=0.18
   - scikit-learn>=0.22
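
The version bump above is the whole environment change. As a quick sanity check outside conda, the installed client can be validated at runtime with a snippet like this (importlib.metadata is stdlib since Python 3.8, packaging is a common third-party helper; neither import is part of this commit):

    from importlib.metadata import version

    from packaging.version import Version

    # Mirrors the "radiant-mlhub>0.5" constraint from environment.yml.
    assert Version(version("radiant-mlhub")) > Version("0.5")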
Binary file not shown.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file not shown.
Binary file not shown.
tests/datasets/test_nasa_marine_debris.py (18 changes: 5 additions & 13 deletions)

@@ -17,9 +17,9 @@

 class Dataset:
     def download(self, output_dir: str, **kwargs: str) -> None:
-        glob_path = os.path.join("tests", "data", "nasa_marine_debris", "*.tar.gz")
-        for tarball in glob.iglob(glob_path):
-            shutil.copy(tarball, output_dir)
+        ds_folder = os.path.join("tests", "data", "nasa_marine_debris", "nasa_marine_debris")
+        output_dir = os.path.join(output_dir, "nasa_marine_debris")
+        shutil.copytree(ds_folder, output_dir, dirs_exist_ok=True)
 
 
 def fetch(dataset_id: str, **kwargs: str) -> Dataset:
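
A minimal standalone sketch of what the new mock does, assuming the pre-extracted fixture tree tests/data/nasa_marine_debris/nasa_marine_debris exists in the repository (dirs_exist_ok requires Python >= 3.8):

    import os
    import shutil
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join("tests", "data", "nasa_marine_debris", "nasa_marine_debris")
        dst = os.path.join(tmp, "nasa_marine_debris")
        # Stage the fixture as a nested directory, mirroring what
        # radiant-mlhub>0.5 leaves on disk after a real download.
        shutil.copytree(src, dst, dirs_exist_ok=True)
        print(sorted(os.listdir(tmp)))  # expected: ['nasa_marine_debris']

Because the client now delivers an already-extracted tree, the old glob-and-copy of per-asset tarballs no longer mirrors what a real download produces, which is why the mock switched to copytree.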
@@ -29,9 +29,9 @@ def fetch(dataset_id: str, **kwargs: str) -> Dataset:
 class TestNASAMarineDebris:
     @pytest.fixture()
     def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris:
-        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
+        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.5.0")
         monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
-        md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"]
+        md5s = ["29dc40158bb6a7c53daa6b815d3821c7"]
         monkeypatch.setattr(NASAMarineDebris, "md5s", md5s)
         root = str(tmp_path)
         transforms = nn.Identity()
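
The fixture now patches in a single checksum, matching the consolidated filenames list in the dataset class further down. If the test archive is ever regenerated, a replacement hash can be computed with a helper along these lines (the helper name and example path are illustrative, not part of the commit):

    import hashlib

    def md5_of(path: str, chunk_size: int = 1 << 20) -> str:
        """Return the MD5 hex digest of a file, read in 1 MiB chunks."""
        digest = hashlib.md5()
        with open(path, "rb") as f:
            for block in iter(lambda: f.read(chunk_size), b""):
                digest.update(block)
        return digest.hexdigest()

    # e.g. md5_of("tests/data/nasa_marine_debris/nasa_marine_debris.tar.gz")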
@@ -53,14 +53,6 @@ def test_already_downloaded
     ) -> None:
         NASAMarineDebris(root=str(tmp_path), download=True)
 
-    def test_already_downloaded_not_extracted(
-        self, dataset: NASAMarineDebris, tmp_path: Path
-    ) -> None:
-        shutil.rmtree(dataset.root)
-        os.makedirs(str(tmp_path), exist_ok=True)
-        Dataset().download(output_dir=str(tmp_path))
-        NASAMarineDebris(root=str(tmp_path), download=False)
-
     def test_not_downloaded(self, tmp_path: Path) -> None:
         err = "Dataset not found in `root` directory and `download=False`, "
         "either specify a different `root` directory or use `download=True` "
torchgeo/datasets/nasa_marine_debris.py (22 changes: 3 additions & 19 deletions)

@@ -52,9 +52,9 @@ class NASAMarineDebris(NonGeoDataset):
"""

dataset_id = "nasa_marine_debris"
directories = ["nasa_marine_debris_source", "nasa_marine_debris_labels"]
filenames = ["nasa_marine_debris_source.tar.gz", "nasa_marine_debris_labels.tar.gz"]
md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"]
directories = ["nasa_marine_debris/nasa_marine_debris_source", "nasa_marine_debris/nasa_marine_debris_labels"]
filenames = ["nasa_marine_debris.tar.gz"]
md5s = ["29dc40158bb6a7c53daa6b815d3821c7"]
class_label = "marine_debris"

def __init__(
@@ -187,19 +187,6 @@ def _verify(self) -> None:
         if all(exists):
             return
 
-        # Check if zip file already exists (if so then extract)
-        exists = []
-        for filename in self.filenames:
-            filepath = os.path.join(self.root, filename)
-            if os.path.exists(filepath):
-                exists.append(True)
-                extract_archive(filepath)
-            else:
-                exists.append(False)
-
-        if all(exists):
-            return
-
         # Check if the user requested to download the dataset
         if not self.download:
             raise RuntimeError(
@@ -211,9 +198,6 @@ def _verify(self) -> None:
         # TODO: need a checksum check in here post downloading
         # Download and extract the dataset
         download_radiant_mlhub_dataset(self.dataset_id, self.root, self.api_key)
-        for filename in self.filenames:
-            filepath = os.path.join(self.root, filename)
-            extract_archive(filepath)
 
     def plot(
         self,
