Skip to content

Commit

Permalink
Merge pull request #33 from cbegeman/add-planar-viz
Browse files Browse the repository at this point in the history
Add planar horizontal viz


This PR ports the `plot_horiz_field` function from `compass/ocean/tests/isomip_plus/viz/__init__.py` into a shared viz folder. There are relatively minor changes related to this being used outside `plot_horiz_series` and making the figure size match the domain aspect ratio. To demonstrate its functionality, it is added to the `initial_state` step of the `baroclinic_channel` test case.
  • Loading branch information
cbegeman authored Apr 14, 2023
2 parents 0b545ec + d7d3788 commit 843f7f1
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 1 deletion.
13 changes: 12 additions & 1 deletion docs/developers_guide/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,18 @@ ocean/api
compare_timers
```

### validate
#### viz

```{eval-rst}
.. currentmodule:: polaris.viz
.. autosummary::
:toctree: generated/
plot_horiz_field
```

### yaml

```{eval-rst}
.. currentmodule:: polaris.yaml
Expand Down
40 changes: 40 additions & 0 deletions docs/developers_guide/framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,46 @@ from an MPAS mesh file. Optionally, you can provide the name of an MPAS field
on cells in the mesh file that gives different weight to different cells
(`weight_field`) in the partitioning process.

(dev-visualization)=

## Visualization

Visualization is an optional, but desirable aspect of test cases. Often,
visualization is an optional step of a test case but can also be included
as part of other steps such as `initial_state` or `analysis`.

While developers can write their own visualization scripts associated with
individual test cases, the following shared visualization routines are
provided in `polaris.viz`:

{py:func}`polaris.viz.plot_horiz_field()` produces a patches-style
visualization of x-y fields across a single vertical level at a single time
step. The image file (png) is saved to the directory from which
{py:func}`polaris.viz.plot_horiz_field()` is called. The function
automatically detects whether the field specified by its variable name is
a cell-centered variable or an edge-variable and generates the patches, the
polygons characterized by the field values, accordingly.

```{image} images/baroclinic_channel_cell_patches.png
:align: center
:width: 250 px
```

```{image} images/baroclinic_channel_edge_patches.png
:align: center
:width: 250 px
```

An example function call that uses the default vertical level (top) is:

```python
plot_horiz_field(config, ds, ds_mesh, 'normalVelocity',
'final_normalVelocity.png',
t_index=t_index,
vmin=-max_velocity, vmax=max_velocity,
cmap='cmo.balance', show_patch_edges=True)
```

(dev-validation)=

## Validation
Expand Down
4 changes: 4 additions & 0 deletions polaris/ocean/tests/baroclinic_channel/default/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from polaris.ocean.tests.baroclinic_channel import BaroclinicChannelTestCase
from polaris.ocean.tests.baroclinic_channel.forward import Forward
from polaris.ocean.tests.baroclinic_channel.viz import Viz
from polaris.validate import compare_variables


Expand Down Expand Up @@ -29,6 +30,9 @@ def __init__(self, test_group, resolution):
Forward(test_case=self, ntasks=4, min_tasks=4, openmp_threads=1,
resolution=resolution, run_time_steps=3))

self.add_step(
Viz(test_case=self))

def validate(self):
"""
Compare ``temperature``, ``salinity``, ``layerThickness`` and
Expand Down
10 changes: 10 additions & 0 deletions polaris/ocean/tests/baroclinic_channel/initial_state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import cmocean # noqa: F401
import numpy as np
import xarray as xr
from mpas_tools.io import write_netcdf
Expand All @@ -6,6 +7,7 @@

from polaris.ocean.vertical import init_vertical_coord
from polaris.step import Step
from polaris.viz import plot_horiz_field


class InitialState(Step):
Expand Down Expand Up @@ -79,6 +81,8 @@ def run(self):

init_vertical_coord(config, ds)

dsMesh['maxLevelCell'] = ds.maxLevelCell

xMin = xCell.min().values
xMax = xCell.max().values
yMin = yCell.min().values
Expand Down Expand Up @@ -140,3 +144,9 @@ def run(self):
ds['fVertex'] = coriolis_parameter * xr.ones_like(ds.xVertex)

write_netcdf(ds, 'initial_state.nc')

plot_horiz_field(ds, dsMesh, 'temperature',
'initial_temperature.png')
plot_horiz_field(ds, dsMesh, 'normalVelocity',
'initial_normalVelocity.png', cmap='cmo.balance',
show_patch_edges=True)
50 changes: 50 additions & 0 deletions polaris/ocean/tests/baroclinic_channel/viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import cmocean # noqa: F401
import numpy as np
import xarray as xr

from polaris.step import Step
from polaris.viz import plot_horiz_field


class Viz(Step):
"""
A step for plotting the results of a series of RPE runs in the baroclinic
channel test group
Attributes
----------
nus : list
A list of viscosities
"""
def __init__(self, test_case):
"""
Create the step
Parameters
----------
test_case : polaris.TestCase
The test case this step belongs to
"""
super().__init__(test_case=test_case, name='viz')
self.add_input_file(
filename='initial_state.nc',
target='../initial_state/initial_state.nc')
self.add_input_file(
filename='output.nc',
target='../forward/output.nc')

def run(self):
"""
Run this step of the test case
"""
ds_mesh = xr.load_dataset('initial_state.nc')
ds = xr.load_dataset('output.nc')
t_index = ds.sizes['Time'] - 1
plot_horiz_field(ds, ds_mesh, 'temperature',
'final_temperature.png', t_index=t_index)
max_velocity = np.max(np.abs(ds.normalVelocity.values))
plot_horiz_field(ds, ds_mesh, 'normalVelocity',
'final_normalVelocity.png',
t_index=t_index,
vmin=-max_velocity, vmax=max_velocity,
cmap='cmo.balance', show_patch_edges=True)
227 changes: 227 additions & 0 deletions polaris/viz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import importlib.resources as imp_res # noqa: F401
import os

import cmocean # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.collections import PatchCollection
from matplotlib.colors import LogNorm
from matplotlib.patches import Polygon


def plot_horiz_field(ds, ds_mesh, field_name, out_file_name,
title=None, t_index=None, z_index=None,
vmin=None, vmax=None, show_patch_edges=False,
cmap=None, cmap_set_under=None, cmap_scale='linear'):

"""
Plot a horizontal field from a planar domain using x,y coordinates at a
single time and depth slice.
Parameters
----------
ds : xarray.Dataset
A data set containing fieldName
ds_mesh : xarray.Dataset
A data set containing mesh variables
field_name: str
The name of the variable to plot, which must be present in ds
out_file_name: str
The path to which the plot image should be written
title: str, optional
The title of the plot
vmin, vmax : float, optional
The minimum and maximum values for the colorbar
show_patch_edges : boolean, optional
If true, patches will be plotted with visible edges
t_index, z_index: int, optional
The indices of 'Time' and 'nVertLevels' axes to select for plotting
cmap : Colormap or str, optional
A color map to plot
cmap_set_under : str or None, optional
A color for low out-of-range values
cmap_scale : {'log', 'linear'}, optional
Whether the colormap is logarithmic or linear
"""
style_filename = str(
imp_res.files('polaris.viz') / 'polaris.mplstyle')
plt.style.use(style_filename)

try:
os.makedirs(os.path.dirname(out_file_name))
except OSError:
pass

if title is None:
title = field_name

if 'maxLevelCell' not in ds_mesh:
raise ValueError(
'maxLevelCell must be added to ds_mesh before plotting.')
if field_name not in ds:
raise ValueError(
f'{field_name} must be present in ds before plotting.')

field = ds[field_name]

if 'Time' in field.dims and t_index is None:
t_index = 0
if t_index is not None:
field = field.isel(Time=t_index)
if 'nVertLevels' in field.dims and z_index is None:
z_index = 0
if z_index is not None:
field = field.isel(nVertLevels=z_index)

if 'nCells' in field.dims:
ocean_mask = ds_mesh.maxLevelCell - 1 >= 0
ocean_patches, ocean_mask = _compute_cell_patches(ds_mesh, ocean_mask)
elif 'nEdges' in field.dims:
ocean_mask = np.ones_like(field, dtype='bool')
ocean_mask = _remove_boundary_edges_from_mask(ds_mesh, ocean_mask)
ocean_patches, ocean_mask = _compute_edge_patches(ds_mesh, ocean_mask)
ocean_patches.set_array(field[ocean_mask])
if cmap is not None:
ocean_patches.set_cmap(cmap)
if cmap_set_under is not None:
current_cmap = ocean_patches.get_cmap()
current_cmap.set_under(cmap_set_under)

if show_patch_edges:
ocean_patches.set_edgecolor('black')
else:
ocean_patches.set_edgecolor('face')
ocean_patches.set_clim(vmin=vmin, vmax=vmax)

if cmap_scale == 'log':
ocean_patches.set_norm(LogNorm(vmin=max(1e-10, vmin),
vmax=vmax, clip=False))

width = ds_mesh.xCell.max() - ds_mesh.xCell.min()
length = ds_mesh.yCell.max() - ds_mesh.yCell.min()
aspect_ratio = width.values / length.values
fig_width = 4
legend_width = fig_width / 5
figsize = (fig_width + legend_width, fig_width / aspect_ratio)

plt.figure(figsize=figsize)
ax = plt.subplot(111)
ax.add_collection(ocean_patches)
ax.set_xlabel('x (km)')
ax.set_ylabel('y (km)')
ax.set_aspect('equal')
ax.autoscale(tight=True)
plt.colorbar(ocean_patches, extend='both', shrink=0.7)
plt.title(title)
plt.tight_layout(pad=0.5)
plt.savefig(out_file_name)
plt.close()


def _remove_boundary_edges_from_mask(ds, mask):
area_cell = ds.areaCell.values
mean_area_cell = np.mean(area_cell)
cells_on_edge = ds.cellsOnEdge.values - 1
vertices_on_edge = ds.verticesOnEdge.values - 1
x_cell = ds.xCell.values
y_cell = ds.yCell.values
boundary_vertex = ds.boundaryVertex.values
x_vertex = ds.xVertex.values
y_vertex = ds.yVertex.values
for edge_index in range(ds.sizes['nEdges']):
if not mask[edge_index]:
continue
cell_indices = cells_on_edge[edge_index]
vertex_indices = vertices_on_edge[edge_index, :]
if any(boundary_vertex[vertex_indices]):
mask[edge_index] = 0
continue
vertices = np.zeros((4, 2))
vertices[0, 0] = x_vertex[vertex_indices[0]]
vertices[0, 1] = y_vertex[vertex_indices[0]]
vertices[1, 0] = x_cell[cell_indices[0]]
vertices[1, 1] = y_cell[cell_indices[0]]
vertices[2, 0] = x_vertex[vertex_indices[1]]
vertices[2, 1] = y_vertex[vertex_indices[1]]
vertices[3, 0] = x_cell[cell_indices[1]]
vertices[3, 1] = y_cell[cell_indices[1]]

# Remove edges that span the periodic boundaries
dx = max(vertices[:, 0]) - min(vertices[:, 0])
dy = max(vertices[:, 1]) - min(vertices[:, 1])
if dx * dy / 10 > mean_area_cell:
mask[edge_index] = 0

return mask


def _compute_cell_patches(ds, mask):
patches = []
num_vertices_on_cell = ds.nEdgesOnCell.values
vertices_on_cell = ds.verticesOnCell.values - 1
x_vertex = ds.xVertex.values
y_vertex = ds.yVertex.values
area_cell = ds.areaCell.values
for cell_index in range(ds.sizes['nCells']):
if not mask[cell_index]:
continue
num_vertices = num_vertices_on_cell[cell_index]
vertex_indices = vertices_on_cell[cell_index, :num_vertices]
vertices = np.zeros((num_vertices, 2))
vertices[:, 0] = 1e-3 * x_vertex[vertex_indices]
vertices[:, 1] = 1e-3 * y_vertex[vertex_indices]

# Remove cells that span the periodic boundaries
dx = max(x_vertex[vertex_indices]) - min(x_vertex[vertex_indices])
dy = max(y_vertex[vertex_indices]) - min(y_vertex[vertex_indices])
if dx * dy / 10 > area_cell[cell_index]:
mask[cell_index] = False
else:
polygon = Polygon(vertices, True)
patches.append(polygon)

p = PatchCollection(patches, alpha=1.)

return p, mask


def _compute_edge_patches(ds, mask):
patches = []
cells_on_edge = ds.cellsOnEdge.values - 1
vertices_on_edge = ds.verticesOnEdge.values - 1
x_cell = ds.xCell.values
y_cell = ds.yCell.values
x_vertex = ds.xVertex.values
y_vertex = ds.yVertex.values
for edge_index in range(ds.sizes['nEdges']):
if not mask[edge_index]:
continue
cell_indices = cells_on_edge[edge_index]
vertex_indices = vertices_on_edge[edge_index, :]
vertices = np.zeros((4, 2))
vertices[0, 0] = 1e-3 * x_vertex[vertex_indices[0]]
vertices[0, 1] = 1e-3 * y_vertex[vertex_indices[0]]
vertices[1, 0] = 1e-3 * x_cell[cell_indices[0]]
vertices[1, 1] = 1e-3 * y_cell[cell_indices[0]]
vertices[2, 0] = 1e-3 * x_vertex[vertex_indices[1]]
vertices[2, 1] = 1e-3 * y_vertex[vertex_indices[1]]
vertices[3, 0] = 1e-3 * x_cell[cell_indices[1]]
vertices[3, 1] = 1e-3 * y_cell[cell_indices[1]]

polygon = Polygon(vertices, True)
patches.append(polygon)

p = PatchCollection(patches, alpha=1.)

return p, mask
Loading

0 comments on commit 843f7f1

Please sign in to comment.