Skip to content

Commit

Permalink
Adding plotting to ChesapeakeCVPR dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Oct 6, 2022
1 parent 030d586 commit 962e13f
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tests/datasets/test_chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,13 @@ def test_multiple_hits_query(self, dataset: ChesapeakeCVPR) -> None:
IndexError, match="query: .* spans multiple tiles which is not valid"
):
ds[dataset.bounds]

def test_plot(self, dataset: ChesapeakeCVPR) -> None:
x = dataset[dataset.bounds].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2)
dataset.plot(x)
plt.close()
154 changes: 154 additions & 0 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch
from matplotlib.colors import ListedColormap
from rasterio.crs import CRS
from torch import Tensor

from .geo import GeoDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive
Expand Down Expand Up @@ -438,6 +439,46 @@ class ChesapeakeCVPR(GeoDataset):
crs = CRS.from_epsg(3857)
res = 1

lc_cmap = {
0: (0, 0, 0, 0),
1: (0, 197, 255, 255),
2: (38, 115, 0, 255),
3: (163, 255, 115, 255),
4: (255, 170, 0, 255),
5: (156, 156, 156, 255),
6: (0, 0, 0, 255),
15: (0, 0, 0, 0),
}

nlcd_cmap = {
0: (0, 0, 0, 0),
11: (70, 107, 159, 255),
12: (209, 222, 248, 255),
21: (222, 197, 197, 255),
22: (217, 146, 130, 255),
23: (235, 0, 0, 255),
24: (171, 0, 0, 255),
31: (179, 172, 159, 255),
41: (104, 171, 95, 255),
42: (28, 95, 44, 255),
43: (181, 197, 143, 255),
52: (204, 184, 121, 255),
71: (223, 223, 194, 255),
81: (220, 217, 57, 255),
82: (171, 108, 40, 255),
90: (184, 217, 235, 255),
95: (108, 159, 184, 255),
}

prior_color_matrix = np.array(
[
[0.0, 0.77254902, 1.0, 1.0],
[0.14901961, 0.45098039, 0.0, 1.0],
[0.63921569, 1.0, 0.45098039, 1.0],
[0.61176471, 0.61176471, 0.61176471, 1.0],
]
)

valid_layers = [
"naip-new",
"naip-old",
Expand Down Expand Up @@ -540,6 +581,34 @@ def __init__(

super().__init__(transforms)

lc_colors = []
for i in range(min(self.lc_cmap.keys()), max(self.lc_cmap.keys()) + 1):
if i in self.lc_cmap:
lc_colors.append(
(
self.lc_cmap[i][0] / 255.0,
self.lc_cmap[i][1] / 255.0,
self.lc_cmap[i][2] / 255.0,
)
)
else:
lc_colors.append((0, 0, 0))
self._lc_cmap = ListedColormap(lc_colors)

nlcd_colors = []
for i in range(min(self.nlcd_cmap.keys()), max(self.nlcd_cmap.keys()) + 1):
if i in self.nlcd_cmap:
nlcd_colors.append(
(
self.nlcd_cmap[i][0] / 255.0,
self.nlcd_cmap[i][1] / 255.0,
self.nlcd_cmap[i][2] / 255.0,
)
)
else:
nlcd_colors.append((0, 0, 0))
self._nlcd_cmap = ListedColormap(nlcd_colors)

# Add all tiles into the index in epsg:3857 based on the included geojson
mint: float = 0
maxt: float = sys.maxsize
Expand Down Expand Up @@ -694,3 +763,88 @@ def _extract(self) -> None:
"""Extract the dataset."""
for subdataset in self.subdatasets:
extract_archive(os.path.join(self.root, self.filenames[subdataset]))

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.4
"""
image = np.rollaxis(sample["image"].numpy(), 0, 3)
mask = np.rollaxis(sample["mask"].numpy(), 0, 3)

num_panels = len(self.layers)
showing_predictions = "prediction" in sample
if showing_predictions:
predictions = sample["prediction"].numpy()
num_panels += 1

fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))

i = 0
for layer in self.layers:
if layer == "naip-new" or layer == "naip-old":
img = image[:, :, :3] / 255
image = image[:, :, 4:]
axs[i].axis("off")
axs[i].imshow(img)
elif layer == "landsat-leaf-on" or layer == "landsat-leaf-off":
img = image[:, :, [3, 2, 1]] / 3000
image = image[:, :, 9:]
axs[i].axis("off")
axs[i].imshow(img)
elif layer == "nlcd":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(
img, vmin=0, vmax=95, cmap=self._nlcd_cmap, interpolation="none"
)
axs[i].axis("off")
elif layer == "lc":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(
img, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
)
axs[i].axis("off")
elif layer == "buildings":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(img, vmin=0, vmax=1, cmap="gray", interpolation="none")
axs[i].axis("off")
elif layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
img = (mask[:, :, :4] @ self.prior_color_matrix) / 255
mask = mask[:, :, 4:]
axs[i].imshow(img)
axs[i].axis("off")

if show_titles:
if layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
axs[i].set_title("prior")
else:
axs[i].set_title(layer)
i += 1

if showing_predictions:
axs[i].imshow(
predictions, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
)
axs[i].axis("off")
if show_titles:
axs[i].set_title("Predictions")

if suptitle is not None:
plt.suptitle(suptitle)
return fig

0 comments on commit 962e13f

Please sign in to comment.