Skip to content

Commit

Permalink
Use x and y parameters for Image trace in imshow (for RGB or binary_s…
Browse files Browse the repository at this point in the history
…tring=True) (#2761)

* take x and y into account when using Image trace

* x and y parameters are now used for Image trace in imshow

* raise ValueError when x and y don't have numerical dtype for Image trace

* better error message

* black
  • Loading branch information
emmanuelle authored Nov 17, 2020
1 parent 9c9b98e commit 3e7967c
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 11 deletions.
65 changes: 54 additions & 11 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,23 +204,19 @@ def imshow(
args = locals()
apply_default_cascade(args)
labels = labels.copy()
img_is_xarray = False
# ----- Define x and y, set labels if img is an xarray -------------------
if xarray_imported and isinstance(img, xarray.DataArray):
if binary_string:
raise ValueError(
"It is not possible to use binary image strings for xarrays."
"Please pass your data as a numpy array instead using"
"`img.values`"
)
img_is_xarray = True
y_label, x_label = img.dims[0], img.dims[1]
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
img.coords[ax] = img.coords[ax].astype(str)
if x is None:
x = img.coords[x_label]
x = img.coords[x_label].values
if y is None:
y = img.coords[y_label]
y = img.coords[y_label].values
if aspect is None:
aspect = "auto"
if labels.get("x", None) is None:
Expand Down Expand Up @@ -330,6 +326,42 @@ def imshow(
_vectorize_zvalue(zmin, mode="min"),
_vectorize_zvalue(zmax, mode="max"),
)
x0, y0, dx, dy = (None,) * 4
error_msg_xarray = (
"Non-numerical coordinates were passed with xarray `img`, but "
"the Image trace cannot handle it. Please use `binary_string=False` "
"for 2D data or pass instead the numpy array `img.values` to `px.imshow`."
)
if x is not None:
x = np.asanyarray(x)
if np.issubdtype(x.dtype, np.number):
x0 = x[0]
dx = x[1] - x[0]
else:
error_msg = (
error_msg_xarray
if img_is_xarray
else (
"Only numerical values are accepted for the `x` parameter "
"when an Image trace is used."
)
)
raise ValueError(error_msg)
if y is not None:
y = np.asanyarray(y)
if np.issubdtype(y.dtype, np.number):
y0 = y[0]
dy = y[1] - y[0]
else:
error_msg = (
error_msg_xarray
if img_is_xarray
else (
"Only numerical values are accepted for the `y` parameter "
"when an Image trace is used."
)
)
raise ValueError(error_msg)
if binary_string:
if zmin is None and zmax is None: # no rescaling, faster
img_rescaled = img
Expand All @@ -355,13 +387,24 @@ def imshow(
compression=binary_compression_level,
ext=binary_format,
)
trace = go.Image(source=img_str)
trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy)
else:
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)
trace = go.Image(
z=img,
zmin=zmin,
zmax=zmax,
colormodel=colormodel,
x0=x0,
y0=y0,
dx=dx,
dy=dy,
)
layout = {}
if origin == "lower":
if origin == "lower" or (dy is not None and dy < 0):
layout["yaxis"] = dict(autorange=True)
if dx is not None and dx < 0:
layout["xaxis"] = dict(autorange="reversed")
else:
raise ValueError(
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from PIL import Image
from io import BytesIO
import base64
import datetime
from plotly.express.imshow_utils import rescale_intensity

img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
Expand Down Expand Up @@ -204,6 +205,37 @@ def test_imshow_labels_and_ranges():
with pytest.raises(ValueError):
fig = px.imshow([[1, 2], [3, 4], [5, 6]], x=["a"])

img = np.ones((2, 2), dtype=np.uint8)
fig = px.imshow(img, x=["a", "b"])
assert fig.data[0].x == ("a", "b")

with pytest.raises(ValueError):
img = np.ones((2, 2, 3), dtype=np.uint8)
fig = px.imshow(img, x=["a", "b"])

img = np.ones((2, 2), dtype=np.uint8)
base = datetime.datetime(2000, 1, 1)
fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)])
assert fig.data[0].x == (
datetime.datetime(2000, 1, 1, 0, 0),
datetime.datetime(2000, 1, 1, 1, 0),
)

with pytest.raises(ValueError):
img = np.ones((2, 2, 3), dtype=np.uint8)
base = datetime.datetime(2000, 1, 1)
fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)])


def test_imshow_ranges_image_trace():
fig = px.imshow(img_rgb, x=[1, 11, 21])
assert fig.data[0].dx == 10
assert fig.data[0].x0 == 1
fig = px.imshow(img_rgb, x=[21, 11, 1])
assert fig.data[0].dx == -10
assert fig.data[0].x0 == 21
assert fig.layout.xaxis.autorange == "reversed"


def test_imshow_dataframe():
df = px.data.medals_wide(indexed=False)
Expand Down

0 comments on commit 3e7967c

Please sign in to comment.