Skip to content

Commit

Permalink
Allow multilabels in VectorDataset (microsoft#862)
Browse files Browse the repository at this point in the history
* add label_name param to VectorDataset

* small fix by black

* add tests

* reuse CustomVectorDataset

* versionadded in docstring
  • Loading branch information
pmandiola authored Oct 26, 2022
1 parent 22ed152 commit 4f2bb95
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 4 deletions.
6 changes: 3 additions & 3 deletions tests/data/vector/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"features": [
{
"type": "Feature",
"properties": {},
"properties": {"label_id": 1},
"geometry": {
"type": "Polygon",
"coordinates": [
Expand All @@ -35,7 +35,7 @@
},
{
"type": "Feature",
"properties": {},
"properties": {"label_id": 2},
"geometry": {
"type": "Polygon",
"coordinates": [
Expand All @@ -45,7 +45,7 @@
},
{
"type": "Feature",
"properties": {},
"properties": {"label_id": 3},
"geometry": {
"type": "Polygon",
"coordinates": [
Expand Down
2 changes: 1 addition & 1 deletion tests/data/vector/vector.geojson
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]]}}]}
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {"label_id": 1}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"label_id": 2}, "geometry": {"type": "Polygon", "coordinates": [[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]]}}, {"type": "Feature", "properties": {"label_id": 3}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]]}}]}
22 changes: 22 additions & 0 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,33 @@ def dataset(self) -> CustomVectorDataset:
transforms = nn.Identity()
return CustomVectorDataset(root, res=0.1, transforms=transforms)

@pytest.fixture(scope="class")
def multilabel(self) -> CustomVectorDataset:
root = os.path.join("tests", "data", "vector")
transforms = nn.Identity()
return CustomVectorDataset(
root, res=0.1, transforms=transforms, label_name="label_id"
)

def test_getitem(self, dataset: CustomVectorDataset) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
assert torch.equal(
x["mask"].unique(), # type: ignore[no-untyped-call]
torch.tensor([0, 1], dtype=torch.uint8),
)

def test_getitem_multilabel(self, multilabel: CustomVectorDataset) -> None:
x = multilabel[multilabel.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
assert torch.equal(
x["mask"].unique(), # type: ignore[no-untyped-call]
torch.tensor([0, 1, 2, 3], dtype=torch.uint8),
)

def test_empty_shapes(self, dataset: CustomVectorDataset) -> None:
query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, 0)
Expand Down
9 changes: 9 additions & 0 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def __init__(
crs: Optional[CRS] = None,
res: float = 0.0001,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
label_name: Optional[str] = None,
) -> None:
"""Initialize a new Dataset instance.
Expand All @@ -545,14 +546,20 @@ def __init__(
res: resolution of the dataset in units of CRS
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
label_name: name of the dataset property that has the label to be
rasterized into the mask
Raises:
FileNotFoundError: if no files are found in ``root``
.. versionadded:: 0.4
The *label_name* parameter.
"""
super().__init__(transforms)

self.root = root
self.res = res
self.label_name = label_name

# Populate the dataset index
i = 0
Expand Down Expand Up @@ -621,6 +628,8 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
shape = fiona.transform.transform_geom(
src.crs, self.crs.to_dict(), feature["geometry"]
)
if self.label_name:
shape = (shape, feature["properties"][self.label_name])
shapes.append(shape)

# Rasterize geometries
Expand Down

0 comments on commit 4f2bb95

Please sign in to comment.