Skip to content

Commit

Permalink
ENH: Plotting for lower rank channels (#751)
Browse files Browse the repository at this point in the history
* MAINT: Make contour[f] use _parse_plot_args

* ENH: Initial implementation of 2D plotting lower rank

* ENH: Make plotting 1D work for >=2D data

* TST: 2D tests for plotting

* TST: Add test for 1D plot

* DOC: Update docstrings for plotting for lower rank channels

* ENH: make sure quick methods do not spit out identical plots

* TEST: pin coverage version to avoid alpha

* remove print statement

* LGTM fix, use a copy of parameter rather than the original

Using the original would lead to unwanted behavior on subsequent calls by modifying the default parameter
  • Loading branch information
ksunden authored and darienmorrow committed Oct 17, 2018
1 parent 40d2de2 commit 752401b
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 107 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ addons:
install:
- pip install cython
- pip install .
- pip install -U pytest pytest-mp
- pip install -U pytest pytest-mp "coverage<5"
before_script:
- "export DISPLAY=:99.0"
- "sh -e /etc/init.d/xvfb start"
Expand Down
176 changes: 77 additions & 99 deletions WrightTools/artists/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,28 @@ def _parse_limits(self, zi=None, data=None, channel_index=None, dynamic_range=Fa

def _parse_plot_args(self, *args, **kwargs):
plot_type = kwargs.pop("plot_type")
if plot_type not in ["pcolor", "pcolormesh"]:
if plot_type not in ["pcolor", "pcolormesh", "contourf", "contour"]:
raise NotImplementedError
args = list(args) # offer pop, append etc
dynamic_range = kwargs.pop("dynamic_range", False)
if isinstance(args[0], Data):
data = args.pop(0)
if plot_type in ["pcolor", "pcolormesh"]:
ndim = 2
if not data.ndim == ndim:
raise wt_exceptions.DimensionalityError(ndim, data.ndim)
# arrays
channel = kwargs.pop("channel", 0)
channel_index = wt_kit.get_index(data.channel_names, channel)
zi = data.channels[channel_index][:]
xi = data.axes[0].full
yi = data.axes[1].full
squeeze = np.array(data.channels[channel_index].shape) == 1
xa = data.axes[0]
ya = data.axes[1]
for sq, xs, ys in zip(squeeze, xa.shape, ya.shape):
if sq and (xs != 1 or ys != 1):
raise wt_exceptions.ValueError("Cannot squeeze axis to fit channel")
squeeze = tuple([0 if i else slice(None) for i in squeeze])
zi = data.channels[channel_index].points
xi = xa.full[squeeze]
yi = ya.full[squeeze]
if plot_type in ["pcolor", "pcolormesh", "contourf", "contour"]:
ndim = 2
if not zi.ndim == ndim:
raise wt_exceptions.DimensionalityError(ndim, data.ndim)
if plot_type in ["pcolor", "pcolormesh"]:
X, Y = pcolor_helper(xi, yi)
else:
Expand All @@ -133,8 +139,23 @@ def _parse_plot_args(self, *args, **kwargs):
kwargs = self._parse_limits(
data=data, channel_index=channel_index, dynamic_range=dynamic_range, **kwargs
)
# cmap
kwargs = self._parse_cmap(data=data, channel_index=channel_index, **kwargs)
if plot_type == "contourf":
if "levels" not in kwargs.keys():
kwargs["levels"] = np.linspace(kwargs["vmin"], kwargs["vmax"], 256)
elif plot_type == "contour":
if "levels" not in kwargs.keys():
if data.channels[channel_index].signed:
n = 11
else:
n = 6
kwargs["levels"] = np.linspace(kwargs.pop("vmin"), kwargs.pop("vmax"), n)[1:-1]
# colors
if "colors" not in kwargs.keys():
kwargs["colors"] = "k"
if "alpha" not in kwargs.keys():
kwargs["alpha"] = 0.5
if plot_type in ["pcolor", "pcolormesh", "contourf"]:
kwargs = self._parse_cmap(data=data, channel_index=channel_index, **kwargs)
else:
xi, yi, zi = args[:3]
if plot_type in ["pcolor", "pcolormesh"]:
Expand All @@ -147,7 +168,11 @@ def _parse_plot_args(self, *args, **kwargs):
data = None
channel_index = 0
kwargs = self._parse_limits(zi=args[2], **kwargs)
kwargs = self._parse_cmap(**kwargs)
if plot_type == "contourf":
if "levels" not in kwargs.keys():
kwargs["levels"] = np.linspace(kwargs["vmin"], kwargs["vmax"], 256)
if plot_type in ["pcolor", "pcolormesh", "contourf"]:
kwargs = self._parse_cmap(**kwargs)
# labels
self._apply_labels(
autolabel=kwargs.pop("autolabel", False),
Expand All @@ -157,7 +182,8 @@ def _parse_plot_args(self, *args, **kwargs):
channel_index=channel_index,
)
# decoration
self.set_facecolor([0.75] * 3)
if plot_type != "contour":
self.set_facecolor([0.75] * 3)
return args, kwargs

def add_sideplot(self, along, pad=0, height=0.75, ymin=0, ymax=1.1):
Expand Down Expand Up @@ -200,6 +226,11 @@ def add_sideplot(self, along, pad=0, height=0.75, ymin=0, ymax=1.1):
def contour(self, *args, **kwargs):
"""Plot contours.
If a 3D or higher Data object is passed, a lower dimensional
channel can be plotted, provided the ``squeeze`` of the channel
has ``ndim==2`` and the first two axes do not span dimensions
other than those spanned by that channel.
Parameters
----------
data : 2D WrightTools.data.Data object
Expand All @@ -224,53 +255,17 @@ def contour(self, *args, **kwargs):
-------
matplotlib.contour.QuadContourSet
"""
args = list(args) # offer pop, append etc
channel = kwargs.pop("channel", 0)
dynamic_range = kwargs.pop("dynamic_range", False)
# unpack data object, if given
if isinstance(args[0], Data):
data = args.pop(0)
if not data.ndim == 2:
raise wt_exceptions.DimensionalityError(2, data.ndim)
# arrays
channel_index = wt_kit.get_index(data.channel_names, channel)
signed = data.channels[channel_index].signed
xi = data.axes[0].full
yi = data.axes[1].full
zi = data.channels[channel_index][:]
args = [xi, yi, zi] + args
# limits
kwargs = self._parse_limits(
data=data, channel_index=channel_index, dynamic_range=dynamic_range, **kwargs
)
# levels
if "levels" not in kwargs.keys():
if signed:
n = 11
else:
n = 6
kwargs["levels"] = np.linspace(kwargs.pop("vmin"), kwargs.pop("vmax"), n)[1:-1]
# colors
if "colors" not in kwargs.keys():
kwargs["colors"] = "k"
if "alpha" not in kwargs.keys():
kwargs["alpha"] = 0.5
# labels
self._apply_labels(
autolabel=kwargs.pop("autolabel", False),
xlabel=kwargs.pop("xlabel", None),
ylabel=kwargs.pop("ylabel", None),
data=data,
channel_index=channel_index,
)
else:
kwargs = self._parse_limits(zi=args[2], dynamic_range=dynamic_range, **kwargs)
# call parent
args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="contour")
return super().contour(*args, **kwargs)

def contourf(self, *args, **kwargs):
"""Plot contours.
If a 3D or higher Data object is passed, a lower dimensional
channel can be plotted, provided the ``squeeze`` of the channel
has ``ndim==2`` and the first two axes do not span dimensions
other than those spanned by that channel.
Parameters
----------
data : 2D WrightTools.data.Data object
Expand All @@ -295,44 +290,7 @@ def contourf(self, *args, **kwargs):
-------
matplotlib.contour.QuadContourSet
"""
args = list(args) # offer pop, append etc
channel = kwargs.pop("channel", 0)
dynamic_range = kwargs.pop("dynamic_range", False)
# unpack data object, if given
if isinstance(args[0], Data):
data = args.pop(0)
if not data.ndim == 2:
raise wt_exceptions.DimensionalityError(2, data.ndim)
# arrays
channel_index = wt_kit.get_index(data.channel_names, channel)
xi = data.axes[0].full
yi = data.axes[1].full
zi = data.channels[channel_index][:]
args = [xi, yi, zi] + args
# limits
kwargs = self._parse_limits(
data=data, channel_index=channel_index, dynamic_range=dynamic_range, **kwargs
)
# cmap
kwargs = self._parse_cmap(data=data, channel_index=channel_index, **kwargs)
else:
data = None
channel_index = 0
kwargs = self._parse_limits(zi=args[2], dynamic_range=dynamic_range, **kwargs)
kwargs = self._parse_cmap(**kwargs)
# levels
if "levels" not in kwargs.keys():
vmin = kwargs.pop("vmin", args[2].min())
vmax = kwargs.pop("vmax", args[2].max())
kwargs["levels"] = np.linspace(vmin, vmax, 256)
# labels
self._apply_labels(
autolabel=kwargs.pop("autolabel", False),
xlabel=kwargs.pop("xlabel", None),
ylabel=kwargs.pop("ylabel", None),
data=data,
channel_index=channel_index,
)
args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="contourf")
# Overloading contourf in an attempt to fix aliasing problems when saving vector graphics
# see https://stackoverflow.com/questions/15822159
# also see https://stackoverflow.com/a/32911283
Expand Down Expand Up @@ -386,6 +344,11 @@ def legend(self, *args, **kwargs):
def pcolor(self, *args, **kwargs):
"""Create a pseudocolor plot of a 2-D array.
If a 3D or higher Data object is passed, a lower dimensional
channel can be plotted, provided the ``squeeze`` of the channel
has ``ndim==2`` and the first two axes do not span dimensions
other than those spanned by that channel.
Uses pcolor_helper to ensure that color boundaries are drawn
bisecting point positions, when possible.
Expand Down Expand Up @@ -419,6 +382,11 @@ def pcolor(self, *args, **kwargs):
def pcolormesh(self, *args, **kwargs):
"""Create a pseudocolor plot of a 2-D array.
If a 3D or higher Data object is passed, a lower dimensional
channel can be plotted, provided the ``squeeze`` of the channel
has ``ndim==2`` and the first two axes do not span dimensions
other than those spanned by that channel.
Uses pcolor_helper to ensure that color boundaries are drawn
bisecting point positions, when possible.
Quicker than pcolor
Expand Down Expand Up @@ -453,6 +421,11 @@ def pcolormesh(self, *args, **kwargs):
def plot(self, *args, **kwargs):
"""Plot lines and/or markers.
If a 2D or higher Data object is passed, a lower dimensional
channel can be plotted, provided the ``squeeze`` of the channel
has ``ndim==1`` and the first axis does not span dimensions
other than that spanned by the channel.
Parameters
----------
data : 1D WrightTools.data.Data object
Expand Down Expand Up @@ -480,15 +453,20 @@ def plot(self, *args, **kwargs):
"""
args = list(args) # offer pop, append etc
# unpack data object, if given
if hasattr(args[0], "id"): # TODO: replace once class comparison works...
if isinstance(args[0], Data):
data = args.pop(0)
channel = kwargs.pop("channel", 0)
if not data.ndim == 1:
raise wt_exceptions.DimensionalityError(1, data.ndim)
# arrays
channel_index = wt_kit.get_index(data.channel_names, channel)
xi = data.axes[0][:]
zi = data.channels[channel_index][:].T
squeeze = np.array(data.channels[channel_index].shape) == 1
xa = data.axes[0]
for sq, xs in zip(squeeze, xa.shape):
if sq and xs != 1:
raise wt_exceptions.ValueError("Cannot squeeze axis to fit channel")
squeeze = tuple([0 if i else slice(None) for i in squeeze])
zi = data.channels[channel_index].points
xi = xa[squeeze]
if not zi.ndim == 1:
raise wt_exceptions.DimensionalityError(1, data.ndim)
args = [xi, zi] + args
else:
data = None
Expand Down
18 changes: 13 additions & 5 deletions WrightTools/artists/_quick.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,14 @@ def quick1D(
list of strings
List of saved image files (if any).
"""
# prepare data
chopped = data.chop(axis, at=at, verbose=False)
# channel index
channel_index = wt_kit.get_index(data.channel_names, channel)
shape = data.channels[channel_index].shape
collapse = [i for i in range(len(shape)) if shape[i] == 1]
at = at.copy()
at.update({c: 0 for c in collapse})
# prepare data
chopped = data.chop(axis, at=at, verbose=False)
# prepare figure
fig = None
if len(chopped) > 10:
Expand All @@ -87,7 +91,7 @@ def quick1D(
else:
fname = data.natural_name
else:
folder_name = "mpl_1D " + wt_kit.TimeStamp().path
folder_name = "quick1D " + wt_kit.TimeStamp().path
os.mkdir(folder_name)
save_directory = folder_name
# chew through image generation
Expand Down Expand Up @@ -204,10 +208,14 @@ def quick2D(
list of strings
List of saved image files (if any).
"""
# prepare data
chopped = data.chop(xaxis, yaxis, at=at, verbose=False)
# channel index
channel_index = wt_kit.get_index(data.channel_names, channel)
shape = data.channels[channel_index].shape
collapse = [i for i in range(len(shape)) if shape[i] == 1]
at = at.copy()
at.update({c: 0 for c in collapse})
# prepare data
chopped = data.chop(xaxis, yaxis, at=at, verbose=False)
# colormap
# get colormap
if data.channels[channel_index].signed:
Expand Down
12 changes: 10 additions & 2 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def chop(self, *args, at={}, parent=None, verbose=True) -> wt_collection.Collect
args[i] = wt_kit.string2identifier(arg)

# normalize the at keys to the natural name
for k in list(at.keys()):
for k in [ak for ak in at.keys() if type(ak) == str]:
for op in operators:
if op in k:
nk = k.replace(op, operator_to_identifier[op])
Expand All @@ -369,18 +369,26 @@ def chop(self, *args, at={}, parent=None, verbose=True) -> wt_collection.Collect
# get output collection
out = wt_collection.Collection(name="chop", parent=parent)
# get output shape
kept = args + list(at.keys())
kept = args + [ak for ak in at.keys() if type(ak) == str]
kept_axes = [self._axes[self.axis_names.index(a)] for a in kept]
removed_axes = [a for a in self._axes if a not in kept_axes]
removed_shape = wt_kit.joint_shape(*removed_axes)
if removed_shape == ():
removed_shape = (1,) * self.ndim
removed_shape = list(removed_shape)
for i in at.keys():
if type(i) == int:
removed_shape[i] = 1
removed_shape = tuple(removed_shape)
# iterate
i = 0
for idx in np.ndindex(removed_shape):
idx = np.array(idx, dtype=object)
idx[np.array(removed_shape) == 1] = slice(None)
for axis, point in at.items():
if type(axis) == int:
idx[axis] = point
continue
point, units = point
destination_units = self._axes[self.axis_names.index(axis)].units
point = wt_units.converter(point, units, destination_units)
Expand Down
Loading

0 comments on commit 752401b

Please sign in to comment.