Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xarray support in imshow #2166

Merged
merged 25 commits into from
Mar 23, 2020
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
10a7b8e
xarray and imshow
emmanuelle Feb 3, 2020
f449c04
add xarray support in imshow
emmanuelle Feb 6, 2020
4954663
aspect ratio: pixels are not square by default for xarrays, plus time…
emmanuelle Feb 10, 2020
b2ee809
imshow tutorial: add xarray example
emmanuelle Feb 10, 2020
7563750
Merge branch 'master' into imshow-xarray
emmanuelle Feb 10, 2020
8282e0a
update tutorials + use long_name
emmanuelle Feb 10, 2020
568d1c1
change for CI
emmanuelle Feb 10, 2020
e9531df
Merge branch 'master' into imshow-xarray
emmanuelle Feb 11, 2020
d63a76a
comment
emmanuelle Feb 13, 2020
c107581
Merge branch 'imshow-xarray' of https://github.com/plotly/plotly.py i…
emmanuelle Feb 13, 2020
70def8e
try to regenerate cache
emmanuelle Feb 13, 2020
4e3717e
Merge branch 'master' into imshow-xarray
emmanuelle Mar 10, 2020
1621271
tmp
emmanuelle Mar 19, 2020
ae984b0
added labels to imshow
emmanuelle Mar 19, 2020
503e962
blacken
emmanuelle Mar 19, 2020
11414e0
pinning orca
emmanuelle Mar 19, 2020
e64b05a
label
emmanuelle Mar 19, 2020
33890df
removed colorbar key
emmanuelle Mar 19, 2020
e3ec430
corrected bug
emmanuelle Mar 20, 2020
d70fe14
datashader example
emmanuelle Mar 20, 2020
b051468
Update packages/python/plotly/plotly/express/_imshow.py
emmanuelle Mar 20, 2020
e8302d1
hover
emmanuelle Mar 20, 2020
67dd774
Merge branch 'imshow-xarray' of https://github.com/plotly/plotly.py i…
emmanuelle Mar 20, 2020
941779c
generalizing imshow(labels, x, y)
nicolaskruchten Mar 21, 2020
21df030
docs and tests for imshow changes
nicolaskruchten Mar 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 7 additions & 21 deletions doc/python/datashader.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jupyter:
name: python
nbconvert_exporter: python
pygments_lexer: ipython3
version: 3.6.8
version: 3.7.3
plotly:
description:
How to use datashader to rasterize large datasets, and visualize
Expand Down Expand Up @@ -98,32 +98,18 @@ fig.show()
```

```python
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
import datashader as ds
df = pd.read_parquet('https://raw.githubusercontent.com/plotly/datasets/master/2015_flights.parquet')

cvs = ds.Canvas(plot_width=100, plot_height=100)
agg = cvs.points(df, 'SCHEDULED_DEPARTURE', 'DEPARTURE_DELAY')
x = np.array(agg.coords['SCHEDULED_DEPARTURE'])
y = np.array(agg.coords['DEPARTURE_DELAY'])

# Assign nan to zero values so that the corresponding pixels are transparent
agg = np.array(agg.values, dtype=np.float)
agg[agg<1] = np.nan

fig = go.Figure(go.Heatmap(
z=np.log10(agg), x=x, y=y,
hoverongaps=False,
hovertemplate='Scheduled departure: %{x:.1f}h <br>Depature delay: %{y} <br>Log10(Count): %{z}',
colorbar=dict(title='Count (Log)', tickprefix='1.e')))
fig.update_xaxes(title_text='Scheduled departure')
fig.update_yaxes(title_text='Departure delay')
agg.values = np.log10(agg.values)
agg.attrs['long_name'] = 'Log10(count)'
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved
fig = px.imshow(agg, origin='lower')
fig.update_traces(hoverongaps=False)
fig.update_layout(coloraxis_colorbar=dict(title='Count (Log)', tickprefix='1.e'))
fig.show()

```

```python

```
28 changes: 28 additions & 0 deletions doc/python/imshow.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,34 @@ fig.update_layout(coloraxis_showscale=False)
fig.show()
```

### Display an xarray image with px.imshow

[xarrays](http://xarray.pydata.org/en/stable/) are labeled arrays (with labeled axes and coordinates). If you pass an xarray image to `px.imshow`, its axes labels and coordinates will be used for ticks. (If you don't want this behavior, just pass `img.values` which is a NumPy array if `img` is an xarray).

```python
import plotly.express as px
import xarray as xr
# Load xarray from dataset included in the xarray tutorial
# We remove 273.5 to display Celsius degrees instead of Kelvin degrees
airtemps = xr.tutorial.open_dataset('air_temperature').air.isel(lon=20) - 273.5
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved
airtemps.attrs['long_name'] = 'Temperature' # used for hover
fig = px.imshow(airtemps.T, color_continuous_scale='RdBu_r', origin='lower')
fig.show()
```

### Display an xarray image with square pixels

For xarrays, by default `px.imshow` does not constrain pixels to be square, since axes often correspond to different physical quantities (e.g. time and space), contrary to a plain camera image where pixels are square (most of the time). If you want to impose square pixels, set the parameter `aspect` to "equal" as below.

```python
import plotly.express as px
import xarray as xr
airtemps = xr.tutorial.open_dataset('air_temperature').air.isel(time=500) - 273.5
airtemps.attrs['long_name'] = 'Temperature' # used for hover
fig = px.imshow(airtemps, color_continuous_scale='RdBu_r', aspect='equal')
fig.show()
```

### Display multichannel image data with go.Image

It is also possible to use the `go.Image` trace from the low-level `graph_objects` API in order to display image data. Note that `go.Image` only accepts multichannel images. For single images, use [`go.Heatmap`](/python/heatmaps).
Expand Down
67 changes: 59 additions & 8 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from ._core import apply_default_cascade
import numpy as np

try:
import xarray

xarray_imported = True
except ImportError:
xarray_imported = False

_float_types = []

# Adapted from skimage.util.dtype
Expand Down Expand Up @@ -68,14 +75,15 @@ def imshow(
template=None,
width=None,
height=None,
aspect=None,
):
"""
Display an image, i.e. data on a 2D regular raster.

Parameters
----------

img: array-like image
img: array-like image, or xarray
The image data. Supported array shapes are

- (M, N): an image with scalar data. The data is visualized
Expand Down Expand Up @@ -122,6 +130,14 @@ def imshow(
height: number
The figure height in pixels, defaults to 600.

aspect: 'equal', 'auto', or None
- 'equal': Ensures an aspect ratio of 1 or pixels (square pixels)
- 'auto': The axes is kept fixed and the aspect ratio of pixels is
adjusted so that the data fit in the axes. In general, this will
result in non-square pixels.
- if None, 'equal' is used for numpy arrays and 'auto' for xarrays
(which have typically heterogeneous coordinates)

Returns
-------
fig : graph_objects.Figure containing the displayed image
Expand All @@ -137,11 +153,33 @@ def imshow(

In order to update and customize the returned figure, use
`go.Figure.update_traces` or `go.Figure.update_layout`.

If an xarray is passed, dimensions names and coordinates are used for
axes labels and ticks.
"""
args = locals()
apply_default_cascade(args)
img_is_xarray = False
if xarray_imported:
if isinstance(img, xarray.DataArray):
y_label, x_label = img.dims[0], img.dims[1]
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
img.coords[ax] = img.coords[ax].astype(str)
x = img.coords[x_label]
y = img.coords[y_label]
img_is_xarray = True
if aspect is None:
aspect = "auto"
z_name = img.attrs["long_name"] if "long_name" in img.attrs else "z"
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved

if not img_is_xarray:
if aspect is None:
aspect = "equal"

img = np.asanyarray(img)

# Cast bools to uint8 (also one byte)
if img.dtype == np.bool:
img = 255 * img.astype(np.uint8)
Expand All @@ -150,10 +188,10 @@ def imshow(
if img.ndim == 2:
trace = go.Heatmap(z=img, coloraxis="coloraxis1")
autorange = True if origin == "lower" else "reversed"
layout = dict(
xaxis=dict(scaleanchor="y", constrain="domain"),
yaxis=dict(autorange=autorange, constrain="domain"),
)
layout = dict(yaxis=dict(autorange=autorange))
if aspect == "equal":
layout["xaxis"] = dict(scaleanchor="y", constrain="domain")
layout["yaxis"]["constrain"] = "domain"
colorscale_validator = ColorscaleValidator("colorscale", "imshow")
if zmin is not None and zmax is None:
zmax = img.max()
Expand Down Expand Up @@ -185,12 +223,25 @@ def imshow(
)

layout_patch = dict()
for v in ["title", "height", "width"]:
if args[v]:
layout_patch[v] = args[v]
for attr_name in ["title", "height", "width"]:
if args[attr_name]:
layout_patch[attr_name] = args[attr_name]
if "title" not in layout_patch and args["template"].layout.margin.t is None:
layout_patch["margin"] = {"t": 60}
fig = go.Figure(data=trace, layout=layout)
fig.update_layout(layout_patch)
if img_is_xarray:
Copy link
Contributor

@nicolaskruchten nicolaskruchten Mar 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think we can do this for np inputs also, no, given the labels dict? We can just default x_label to x etc

if img.ndim <= 2:
hovertemplate = (
x_label
+ ": %{x} <br>"
+ y_label
+ ": %{y} <br>"
+ z_name
+ " : %{z}<extra></extra>"
)
fig.update_traces(x=x, y=y, hovertemplate=hovertemplate)
fig.update_xaxes(title_text=x_label)
fig.update_yaxes(title_text=y_label)
fig.update_layout(template=args["template"], overwrite=True)
return fig
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import plotly.express as px
import numpy as np
import pytest
import xarray as xr

img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
img_gray = np.arange(100).reshape((10, 10))
Expand Down Expand Up @@ -58,8 +59,9 @@ def test_colorscale():

def test_wrong_dimensions():
imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)]
msg = "px.imshow only accepts 2D single-channel, RGB or RGBA images."
for img in imgs:
with pytest.raises(ValueError) as err_msg:
with pytest.raises(ValueError, match=msg):
fig = px.imshow(img)


Expand Down Expand Up @@ -114,3 +116,13 @@ def test_zmin_zmax_range_color():
fig = px.imshow(img, zmax=0.8)
assert fig.layout.coloraxis.cmin == 0.0
assert fig.layout.coloraxis.cmax == 0.8


def test_imshow_xarray():
img = np.random.random((20, 30))
da = xr.DataArray(img, dims=["dim_rows", "dim_cols"])
fig = px.imshow(da)
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_cols"
assert fig.layout.yaxis.title.text == "dim_rows"
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_cols"]))
2 changes: 1 addition & 1 deletion packages/python/plotly/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ deps=
retrying==1.3.3
pytest==3.5.1
pandas==0.24.2
xarray==0.10.9
backports.tempfile==1.0
optional: --editable=file:///{toxinidir}/../plotly-geo
optional: numpy==1.16.5
Expand All @@ -71,7 +72,6 @@ deps=
optional: pyshp==1.2.10
optional: pillow==5.2.0
optional: matplotlib==2.2.3
optional: xarray==0.10.9
optional: scikit-image==0.14.4

; CORE ENVIRONMENTS
Expand Down