Skip to content

Commit

Permalink
Datasets: add default 'root' argument (microsoft#802)
Browse files Browse the repository at this point in the history
* Datasets: add default 'root' argument

* Allow SpaceNet image to be optional

* Remove modifications to SpaceNet base class
  • Loading branch information
adamjstewart authored Oct 1, 2022
1 parent 90604e3 commit 91f1f43
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class RasterDataset(GeoDataset):

def __init__(
self,
root: str,
root: str = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
Expand Down Expand Up @@ -712,7 +712,7 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m

def __init__(
self,
root: str,
root: str = "data",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
loader: Optional[Callable[[str], Any]] = pil_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class InriaAerialImageLabeling(NonGeoDataset):

def __init__(
self,
root: str,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
checksum: bool = False,
Expand Down
12 changes: 6 additions & 6 deletions torchgeo/datasets/spacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ class SpaceNet1(SpaceNet):

def __init__(
self,
root: str,
root: str = "data",
image: str = "rgb",
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
Expand Down Expand Up @@ -523,7 +523,7 @@ class SpaceNet2(SpaceNet):

def __init__(
self,
root: str,
root: str = "data",
image: str = "PS-RGB",
collections: List[str] = [],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
Expand Down Expand Up @@ -650,7 +650,7 @@ class SpaceNet3(SpaceNet):

def __init__(
self,
root: str,
root: str = "data",
image: str = "PS-RGB",
speed_mask: Optional[bool] = False,
collections: List[str] = [],
Expand Down Expand Up @@ -907,7 +907,7 @@ class SpaceNet4(SpaceNet):

def __init__(
self,
root: str,
root: str = "data",
image: str = "PS-RGBNIR",
angles: List[str] = [],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
Expand Down Expand Up @@ -1082,7 +1082,7 @@ class SpaceNet5(SpaceNet3):

def __init__(
self,
root: str,
root: str = "data",
image: str = "PS-RGB",
speed_mask: Optional[bool] = False,
collections: List[str] = [],
Expand Down Expand Up @@ -1179,7 +1179,7 @@ class SpaceNet7(SpaceNet):

def __init__(
self,
root: str,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
Expand Down

0 comments on commit 91f1f43

Please sign in to comment.