Skip to content

Commit

Permalink
[ggplot] add shape aesthetic support to point geom (#12207)
Browse files Browse the repository at this point in the history
* [ggplot] add shape aesthetic support to point geom

* add shape scale

* docs and export

* lint

* missing imports

* wip

* fixes legend categories

* fix var
  • Loading branch information
iris-garden authored Sep 23, 2022
1 parent ef14a51 commit 767bdd6
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 64 deletions.
4 changes: 3 additions & 1 deletion hail/python/hail/ggplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .scale import scale_x_continuous, scale_y_continuous, scale_x_discrete, scale_y_discrete, scale_x_genomic, \
scale_x_log10, scale_y_log10, scale_x_reverse, scale_y_reverse, scale_color_discrete, scale_color_hue, scale_color_identity,\
scale_color_manual, scale_color_continuous, scale_fill_discrete, scale_fill_hue, scale_fill_identity, scale_fill_continuous,\
scale_fill_manual
scale_fill_manual, scale_shape_manual, scale_shape_auto
from .facets import vars, facet_wrap

__all__ = [
Expand Down Expand Up @@ -50,6 +50,8 @@
"scale_fill_discrete",
"scale_fill_hue",
"scale_fill_manual",
"scale_shape_manual",
"scale_shape_auto",
"facet_wrap",
"vars"
]
28 changes: 19 additions & 9 deletions hail/python/hail/ggplot/geoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,21 @@ def get_stat(self):

def _add_aesthetics_to_trace_args(self, trace_args, df):
for aes_name, (plotly_name, default) in self.aes_to_arg.items():
value = None

if hasattr(self, aes_name) and getattr(self, aes_name) is not None:
trace_args[plotly_name] = getattr(self, aes_name)
value = getattr(self, aes_name)
elif aes_name in df.attrs:
trace_args[plotly_name] = df.attrs[aes_name]
value = df.attrs[aes_name]
elif aes_name in df.columns:
trace_args[plotly_name] = df[aes_name]
value = df[aes_name]
elif default is not None:
trace_args[plotly_name] = default
value = default

if plotly_name == "name" and trace_args.get(plotly_name, None) is not None:
trace_args[plotly_name] += f" & {value}"
elif value is not None:
trace_args[plotly_name] = value

def _update_legend_trace_args(self, trace_args, legend_cache):
if "name" in trace_args:
Expand Down Expand Up @@ -88,14 +95,17 @@ class GeomPoint(Geom):
"size": ("marker_size", None),
"tooltip": ("hovertext", None),
"color_legend": ("name", None),
"alpha": ("marker_opacity", None)
"alpha": ("marker_opacity", None),
"shape": ("marker_symbol", None),
"shape_legend": ("name", None),
}

def __init__(self, aes, color=None, size=None, alpha=None):
def __init__(self, aes, color=None, size=None, alpha=None, shape=None):
super().__init__(aes)
self.color = color
self.size = size
self.alpha = alpha
self.shape = shape

def apply_to_fig(self, parent, grouped_data, fig_so_far, precomputed, facet_row, facet_col, legend_cache):
def plot_group(df):
Expand All @@ -119,17 +129,17 @@ def get_stat(self):
return StatIdentity()


def geom_point(mapping=aes(), *, color=None, size=None, alpha=None):
def geom_point(mapping=aes(), *, color=None, size=None, alpha=None, shape=None):
"""Create a scatter plot.
Supported aesthetics: ``x``, ``y``, ``color``, ``alpha``, ``tooltip``
Supported aesthetics: ``x``, ``y``, ``color``, ``alpha``, ``tooltip``, ``shape``
Returns
-------
:class:`FigureAttribute`
The geom to be applied.
"""
return GeomPoint(mapping, color=color, size=size, alpha=alpha)
return GeomPoint(mapping, color=color, size=size, alpha=alpha, shape=shape)


class GeomLine(GeomLineBasic):
Expand Down
9 changes: 8 additions & 1 deletion hail/python/hail/ggplot/ggplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .labels import Labels
from .scale import Scale, ScaleContinuous, ScaleDiscrete, scale_x_continuous, scale_x_genomic, scale_y_continuous, \
scale_x_discrete, scale_y_discrete, scale_color_discrete, scale_color_continuous, scale_fill_discrete, \
scale_fill_continuous
scale_fill_continuous, scale_shape_auto
from .aes import Aesthetic, aes
from .facets import Faceter
from .utils import is_continuous_type, is_genomic_type, check_scale_continuity
Expand Down Expand Up @@ -91,6 +91,13 @@ def add_default_scales(self, aesthetic):
self.scales["fill"] = scale_fill_discrete()
elif aesthetic_str == "fill" and is_continuous:
self.scales["fill"] = scale_fill_continuous()
elif aesthetic_str == "shape" and not is_continuous:
self.scales["shape"] = scale_shape_auto()
elif aesthetic_str == "shape" and is_continuous:
raise ValueError(
"The 'shape' aesthetic does not support continuous "
"types. Specify values of a discrete type instead."
)
else:
if is_continuous:
self.scales[aesthetic_str] = ScaleContinuous(aesthetic_str)
Expand Down
169 changes: 133 additions & 36 deletions hail/python/hail/ggplot/scale.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import abc
from .geoms import FigureAttribute

from hail.context import get_reference
from hail import tstr

from .utils import categorical_strings_to_colors, continuous_nums_to_colors, is_continuous_type, is_discrete_type
from collections.abc import Mapping

import plotly.express as px
import plotly

from hail.context import get_reference
from hail import tstr

from .geoms import FigureAttribute
from .utils import continuous_nums_to_colors, is_continuous_type, is_discrete_type


class Scale(FigureAttribute):
def __init__(self, aesthetic_name):
Expand Down Expand Up @@ -150,6 +152,9 @@ class ScaleDiscrete(Scale):
def __init__(self, aesthetic_name):
super().__init__(aesthetic_name)

def get_values(self, categories):
return None

def transform_data(self, field_expr):
return field_expr

Expand All @@ -162,30 +167,50 @@ def is_continuous(self):
def valid_dtype(self, dtype):
return is_discrete_type(dtype)


class ScaleColorManual(ScaleDiscrete):

def __init__(self, aesthetic_name, values):
super().__init__(aesthetic_name)
self.values = values

def create_local_transformer(self, groups_of_dfs):
categorical_strings = set()
categories = set()
for group_of_dfs in groups_of_dfs:
for df in group_of_dfs:
if self.aesthetic_name in df.attrs:
categorical_strings.add(df.attrs[self.aesthetic_name])

unique_color_mapping = categorical_strings_to_colors(categorical_strings, self.values)
categories.add(df.attrs[self.aesthetic_name])

values = self.get_values(categories)

if values is None:
return super().create_local_transformer(groups_of_dfs)
elif isinstance(values, Mapping):
mapping = values
elif isinstance(values, list):
if len(categories) > len(values):
raise ValueError(
f"Not enough scale values specified. Found {len(categories)} "
f"distinct categories in {categories} and only {len(values)} "
f"scale values were provided in {values}."
)
mapping = dict(zip(categories, values))
else:
raise TypeError(
"Expected scale values to be a Mapping or list, but received a(n) "
f"{type(values)}: {values}."
)

def transform(df):
df.attrs[f"{self.aesthetic_name}_legend"] = df.attrs[self.aesthetic_name]
df.attrs[self.aesthetic_name] = unique_color_mapping[df.attrs[self.aesthetic_name]]
df.attrs[self.aesthetic_name] = mapping[df.attrs[self.aesthetic_name]]
return df

return transform


class ScaleDiscreteManual(ScaleDiscrete):
def __init__(self, aesthetic_name, values):
super().__init__(aesthetic_name)
self.values = values

def get_values(self, categories):
return self.values


class ScaleColorContinuous(ScaleContinuous):

def create_local_transformer(self, groups_of_dfs):
Expand Down Expand Up @@ -217,26 +242,71 @@ def transform(df):


class ScaleColorHue(ScaleDiscrete):
def create_local_transformer(self, groups_of_dfs):
categorical_strings = set()
for group_of_dfs in groups_of_dfs:
for df in group_of_dfs:
if self.aesthetic_name in df.attrs:
categorical_strings.add(df.attrs[self.aesthetic_name])

num_categories = len(categorical_strings)
def get_values(self, categories):
num_categories = len(categories)
step = 1.0 / num_categories
interpolation_values = [step * i for i in range(num_categories)]
hsv_scale = px.colors.get_colorscale("HSV")
colors = px.colors.sample_colorscale(hsv_scale, interpolation_values)
unique_color_mapping = dict(zip(categorical_strings, colors))

def transform(df):
df.attrs[f"{self.aesthetic_name}_legend"] = df.attrs[self.aesthetic_name]
df.attrs[self.aesthetic_name] = unique_color_mapping[df.attrs[self.aesthetic_name]]
return df

return transform
return px.colors.sample_colorscale(hsv_scale, interpolation_values)


class ScaleShapeAuto(ScaleDiscrete):
def get_values(self, categories):
return [
"circle",
"square",
"diamond",
"cross",
"x",
"triangle-up",
"triangle-down",
"triangle-left",
"triangle-right",
"triangle-ne",
"triangle-se",
"triangle-sw",
"triangle-nw",
"pentagon",
"hexagon",
"hexagon2",
"octagon",
"star",
"hexagram",
"star-triangle-up",
"star-triangle-down",
"star-square",
"star-diamond",
"diamond-tall",
"diamond-wide",
"hourglass",
"bowtie",
"circle-cross",
"circle-x",
"square-cross",
"square-x",
"diamond-cross",
"diamond-x",
"cross-thin",
"x-thin",
"asterisk",
"hash",
"y-up",
"y-down",
"y-left",
"y-right",
"line-ew",
"line-ns",
"line-ne",
"line-nw",
"arrow-up",
"arrow-down",
"arrow-left",
"arrow-right",
"arrow-bar-up",
"arrow-bar-down",
"arrow-bar-left",
"arrow-bar-right",
]


class ScaleColorContinuousIdentity(ScaleContinuous):
Expand Down Expand Up @@ -469,7 +539,7 @@ def scale_color_manual(*, values):
:class:`.FigureAttribute`
The scale to be applied.
"""
return ScaleColorManual("color", values=values)
return ScaleDiscreteManual("color", values=values)


def scale_fill_discrete():
Expand Down Expand Up @@ -531,4 +601,31 @@ def scale_fill_manual(*, values):
:class:`.FigureAttribute`
The scale to be applied.
"""
return ScaleColorManual("fill", values=values)
return ScaleDiscreteManual("fill", values=values)


def scale_shape_manual(*, values):
"""A scale that assigns shapes to discrete aesthetics. See `the plotly documentation <https://plotly.com/python-api-reference/generated/plotly.graph_objects.scatter.html#plotly.graph_objects.scatter.Marker.symbol>`__ for a list of supported shapes.
Parameters
----------
values: :class:`list` of :class:`str`
The shapes from which to choose.
Returns
-------
:class:`.FigureAttribute`
The scale to be applied.
"""
return ScaleDiscreteManual("shape", values=values)


def scale_shape_auto():
"""A scale that automatically assigns shapes to discrete aesthetics.
Returns
-------
:class:`.FigureAttribute`
The scale to be applied.
"""
return ScaleShapeAuto("shape")
17 changes: 0 additions & 17 deletions hail/python/hail/ggplot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,6 @@ def should_use_scale_for_grouping(scale):
return (scale.aesthetic_name not in excluded_from_grouping) and scale.is_discrete()


# Map strings to numbers that will index into a color scale.
def categorical_strings_to_colors(string_set, color_values):

if isinstance(color_values, list):
if len(string_set) > len(color_values):
print(f"Not enough colors specified. Found {len(string_set)} distinct values of color aesthetic and only {len(color_values)} colors were provided.")
color_dict = {}
for idx, element in enumerate(string_set):
if element not in color_dict:
color_dict[element] = color_values[idx]

else:
color_dict = color_values

return color_dict


def continuous_nums_to_colors(min_color, max_color, continuous_color_scale):
def adjust_color(input_color):
return (input_color - min_color) / max_color - min_color
Expand Down

0 comments on commit 767bdd6

Please sign in to comment.