Skip to content

Commit

Permalink
Add more patching visualizations
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Oct 23, 2024
1 parent 49d9458 commit c491bfe
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 8 deletions.
8 changes: 7 additions & 1 deletion darts-segmentation/src/darts_segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def predict_in_patches(
overlap: int,
batch_size: int,
device=torch.device,
return_weights: bool = False,
) -> torch.Tensor:
"""Predict on a tensor.
Expand All @@ -108,6 +109,7 @@ def predict_in_patches(
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Tensor will be sliced into patches and these again will be infered in batches.
device (torch.device): The device to use for the prediction.
return_weights (bool, optional): Whether to return the weights. Can be used for debugging. Defaults to False.
Returns:
The predicted tensor.
Expand Down Expand Up @@ -173,4 +175,8 @@ def predict_in_patches(
# Remove the 1px border and the padding
prediction = prediction[:, 1:-1, 1:-1]
logger.debug(f"Predicting took {time.time() - start_time:.2f}s")
return prediction

if return_weights:
return prediction, weights
else:
return prediction
121 changes: 114 additions & 7 deletions notebooks/patch-inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from darts_segmentation.utils import patch_coords\n",
"import lovely_tensors as lt\n",
"import matplotlib.patches as mpl_patches\n",
"import matplotlib.pyplot as plt"
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from darts_segmentation.utils import patch_coords, predict_in_patches"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -31,7 +32,14 @@
"overlap = 3\n",
"\n",
"# Create an example tile (already as torch tensor)\n",
"tensor_tiles = torch.rand((3, 1, h, w))"
"tensor_tiles = torch.rand((3, 1, h, w)) * 0.2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Patching"
]
},
{
Expand All @@ -42,7 +50,7 @@
"source": [
"# Visualize the patching\n",
"fig, ax = plt.subplots(1, 1, figsize=(20, 20))\n",
"ax.imshow(tensor_tiles[0, 0])\n",
"ax.imshow(tensor_tiles[0, 0], vmin=0, vmax=1, cmap=\"gray\")\n",
"colors = [\"red\", \"orange\", \"grey\", \"brown\", \"yellow\", \"purple\", \"teal\"]\n",
"for i, (y, x, patch_idx_y, patch_idx_x) in enumerate(patch_coords(h, w, patch_size, overlap)):\n",
" c = colors[i % len(colors)]\n",
Expand All @@ -53,12 +61,111 @@
" ax.text(x, y, f\"{i}: {patch_idx_x}-{patch_idx_y} ({x}-{y})\", bbox={\"facecolor\": \"white\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weights of overlap"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"# Example parameters\n",
"h, w = 8000, 8000\n",
"patch_size = 1024\n",
"overlap = 128\n",
"\n",
"# Create an example tile (already as torch tensor)\n",
"tensor_tiles = torch.rand((3, 1, h, w)) * 0.2\n",
"\n",
"\n",
"def mock_model(x: torch.Tensor) -> torch.Tensor: # noqa: D103\n",
" return x * 3\n",
"\n",
"\n",
"res, weights = predict_in_patches(\n",
" mock_model, tensor_tiles, patch_size, overlap, batch_size=1, device=\"cpu\", return_weights=True\n",
")\n",
"expected = torch.sigmoid(tensor_tiles * 3).squeeze(1)\n",
"\n",
"diff = torch.abs(res - expected)\n",
"\n",
"print(f\"{'expected': <20}{lt.lovely(expected)}\")\n",
"print(f\"{'res': <20}{lt.lovely(res)}\")\n",
"print(f\"{'diff': <20}{lt.lovely(diff)}\")\n",
"print(f\"{'weights': <20}{lt.lovely(weights)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(1, 3, figsize=(20, 10))\n",
"axs[0].imshow(res[0], vmin=0, vmax=1, cmap=\"gray\")\n",
"axs[0].set_title(\"Result\")\n",
"axs[1].imshow(expected[0], vmin=0, vmax=1, cmap=\"gray\")\n",
"axs[1].set_title(\"Input\")\n",
"im = axs[2].imshow(diff[0], cmap=\"gray\")\n",
"axs[2].set_title(\"Difference\")\n",
"plt.colorbar(im)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(diff[0], cmap=\"viridis\", vmin=0, vmax=1e-8)\n",
"plt.colorbar()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(diff[0], cmap=\"viridis\", vmin=0, vmax=1e-8)\n",
"plt.colorbar()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a soft margin for the patches\n",
"margin_ramp = torch.cat(\n",
" [\n",
" torch.linspace(0, 1, overlap),\n",
" torch.ones(patch_size - 2 * overlap),\n",
" torch.linspace(1, 0, overlap),\n",
" ]\n",
")\n",
"soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)\n",
"plt.imshow(soft_margin[0], cmap=\"gray\")\n",
"plt.title(\"Soft margin\")\n",
"plt.colorbar()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(weights[0], cmap=\"hot\")\n",
"# add colorbar\n",
"plt.colorbar()"
]
}
],
"metadata": {
Expand Down

0 comments on commit c491bfe

Please sign in to comment.