Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support N>1 spanners #345

Merged
merged 6 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions great_tables/_spanners.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,20 @@ def tab_spanner(
```
"""

crnt_spanner_ids = [span.spanner_id for span in data._spanners]
crnt_spanner_ids = set([span.spanner_id for span in data._spanners])

if id is None:
id = label

if isinstance(columns, (str, int)):
columns = [columns]
elif columns is None:
columns = []

if isinstance(spanners, (str, int)):
spanners = [spanners]
elif spanners is None:
spanners = []

# validations ----
if level is not None and level < 0:
Expand All @@ -142,11 +146,7 @@ def tab_spanner(

# select columns ----

if columns is None:
# TODO: null_means is unimplemented
raise NotImplementedError("columns must be specified")

selected_column_names = resolve_cols_c(data=data, expr=columns, null_means="nothing")
selected_column_names = resolve_cols_c(data=data, expr=columns, null_means="nothing") or []

# select spanner ids ----
# TODO: this supports tidyselect
Expand All @@ -157,8 +157,10 @@ def tab_spanner(
else:
spanner_ids = []

# Check that we've selected something explicitly
if not len(selected_column_names) and not len(spanner_ids):
return data
# TODO: null_means is unimplemented
raise NotImplementedError("columns/spanners must be specified")

# get column names associated with selected spanners ----
_vars = [span.vars for span in data._spanners if span.spanner_id in spanner_ids]
Expand Down Expand Up @@ -187,10 +189,9 @@ def tab_spanner(
)

spanners = data._spanners.append_entry(new_span)

new_data = data._replace(_spanners=spanners)

if gather and not len(spanner_ids) and level == 0:
if gather and not len(spanner_ids) and level == 0 and column_names:
return cols_move(new_data, columns=column_names, after=column_names[0])

return new_data
Expand Down Expand Up @@ -518,7 +519,6 @@ def spanners_print_matrix(

non_empty_spans = [span for crnt_vars, span in zip(_vars, spanners) if len(crnt_vars)]
new_levels = [_lvls.index(span.spanner_level) for span in non_empty_spans]

crnt_spans = Spanners(non_empty_spans).relevel(new_levels)

if not crnt_spans:
Expand Down
84 changes: 43 additions & 41 deletions great_tables/_utils_render_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,12 @@ def create_columns_component_h(data: GTData) -> str:
)

spanner_ids, spanner_col_names = spanners_print_matrix(
spanners=data._spanners, boxhead=boxhead, include_hidden=False, ids=True
spanners=data._spanners, boxhead=boxhead, include_hidden=False, ids=False
)

level_1_index = 0
# Last is column labels
# So take second to last
level_1_index = -2

# A list of <th> elements that will go in the first level; this
# includes spanner labels and column labels for solo columns (don't
Expand Down Expand Up @@ -341,34 +343,21 @@ def create_columns_component_h(data: GTData) -> str:
table_col_headings = tags.tr(level_1_spanners, class_="gt_col_headings gt_spanner_row")

if _get_spanners_matrix_height(data=data) > 2:
# TODO: functions like seq_len don't exist
higher_spanner_rows_idx = seq_len(nrow(spanner_ids) - 2) # noqa

# Spanners are listed top to bottom, so we need to work bottom to top
# We can skip the last (column labels) and second to last (first spanner)
higher_spanner_rows_idx = range(0, len(spanner_ids) - 2)
higher_spanner_rows = TagList()

for i in higher_spanner_rows_idx:
spanner_ids_row = spanner_ids[i]
spanners_row = spanners[i]
# TODO: shouldn't use np here
spanners_vars = list(set(spanner_ids_row[~np.isnan(spanner_ids_row)].tolist())) # noqa

# Replace NA values in spanner_ids_row with an empty string
# TODO: shouldn't use np here
spanner_ids_row[np.isnan(spanner_ids_row)] = "" # noqa

spanners_rle = [(k, len(list(g))) for k, g in groupby(list(spanner_ids_row))]

sig_cells = [1] + [
i + 1
for i, (k, _) in enumerate(spanners_rle[:-1])
if k is None or k != spanners_rle[i - 1][0]
]

colspans = [
spanners_rle[j][1] if (j + 1) in sig_cells else 0
for j in range(len(spanner_ids_row))
]

for k, v in spanners_row.items():
if v is None:
spanners_row[k] = ""

spanner_ids_index = list(spanners_row.values())
spanners_rle = list(seq_groups(seq=list(spanner_ids_index)))
group_spans = [[x[1]] + [0] * (x[1] - 1) for x in spanners_rle]
colspans = list(chain(*group_spans))
level_i_spanners = []

for colspan, span_label in zip(colspans, spanners_row.values()):
Expand All @@ -386,38 +375,51 @@ def create_columns_component_h(data: GTData) -> str:
# )
spanner_style = None

if span_label:
span = tags.span(
HTML(_process_text(span_label)),
class_="gt_column_spanner",
)
else:
span = tags.span(HTML("&nbsp;"))

level_i_spanners.append(
tags.th(
TagList(
tags.span(HTML(span_label)),
tags.span(HTML("&nbsp;"), class_="gt_column_spanner_inner"),
),
class_="gt_center gt_columns_top_border gt_column_spanner_outer",
span,
class_="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer",
rowspan=1,
colspan=colspans[j],
colspan=colspan,
style=spanner_style,
scope="colgroup" if colspans[j] > 1 else "col",
scope="colgroup" if colspan > 1 else "col",
)
)

if len(stub_layout) > 0 and i == 1:
level_i_spanners = tags.th(
TagList(level_i_spanners),
rowspan=max(list(higher_spanner_rows_idx)),
colspan=len(stub_layout),
scope="colgroup" if len(stub_layout) > 1 else "col",
if len(stub_layout) > 0:
level_i_spanners.insert(
0,
tags.th(
tags.span(HTML("&nbsp")),
class_=f"gt_col_heading gt_columns_bottom_border gt_{str(stubhead_label_alignment)}",
rowspan=1,
colspan=len(stub_layout),
scope="colgroup" if len(stub_layout) > 1 else "col",
),
)

higher_spanner_rows = TagList(
higher_spanner_rows,
TagList(tags.tr(level_i_spanners, class_="gt_col_headings gt_spanner_row")),
TagList(
tags.tr(
level_i_spanners,
class_="gt_col_headings gt_spanner_row",
)
),
)

table_col_headings = TagList(
higher_spanner_rows,
table_col_headings,
)

return str(table_col_headings)


Expand Down
58 changes: 58 additions & 0 deletions tests/__snapshots__/test_utils_render_html.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,64 @@
</tbody>
'''
# ---
# name: test_multiple_spanners_pads_for_stubhead_label
'''
<tr class="gt_col_headings gt_spanner_row">
<th class="gt_col_heading gt_columns_bottom_border gt_left" rowspan="1" colspan="1" scope="col">
<span>&nbsp</span>
</th>
<th class="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="3" scope="colgroup">
<span class="gt_column_spanner">E</span>
</th>
<th class="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="4" scope="colgroup">
<span>&nbsp;</span>
</th>
</tr>
<tr class="gt_col_headings gt_spanner_row">
<th class="gt_col_heading gt_columns_bottom_border gt_left" rowspan="1" colspan="1" scope="col">
<span>&nbsp</span>
</th>
<th class="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="2" scope="colgroup">
<span>&nbsp;</span>
</th>
<th class="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="3" scope="colgroup">
<span class="gt_column_spanner">D</span>
</th>
<th class="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="2" scope="colgroup">
<span>&nbsp;</span>
</th>
</tr>
<tr class="gt_col_headings gt_spanner_row">
<th class="gt_col_heading gt_columns_bottom_border gt_left" rowspan="1" colspan="1" scope="col">
<span>&nbsp</span>
</th>
<th class="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="2" scope="colgroup">
<span class="gt_column_spanner">C</span>
</th>
<th class="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="1" scope="col">
<span class="gt_column_spanner">B</span>
</th>
<th class="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="4" scope="colgroup">
<span>&nbsp;</span>
</th>
</tr>
<tr class="gt_col_headings gt_spanner_row">
<th class="gt_col_heading gt_columns_bottom_border gt_left" rowspan="2" colspan="1" scope="col" id="Group">Group</th>
<th class="gt_center gt_columns_top_border gt_column_spanner_outer" rowspan="1" colspan="3" scope="colgroup" id="A">
<span class="gt_column_spanner">A</span>
</th>
<th class="gt_col_heading gt_columns_bottom_border gt_right" rowspan="2" colspan="1" scope="col" id="date">date</th>
<th class="gt_col_heading gt_columns_bottom_border gt_right" rowspan="2" colspan="1" scope="col" id="time">time</th>
<th class="gt_col_heading gt_columns_bottom_border gt_right" rowspan="2" colspan="1" scope="col" id="datetime">datetime</th>
<th class="gt_col_heading gt_columns_bottom_border gt_right" rowspan="2" colspan="1" scope="col" id="currency">currency</th>
</tr>
<tr class="gt_col_headings">
<th class="gt_col_heading gt_columns_bottom_border gt_right" rowspan="1" colspan="1" scope="col" id="num">num</th>
<th class="gt_col_heading gt_columns_bottom_border gt_left" rowspan="1" colspan="1" scope="col" id="char">char</th>
<th class="gt_col_heading gt_columns_bottom_border gt_left" rowspan="1" colspan="1" scope="col" id="fctr">fctr</th>
</tr>
'''
# ---
# name: test_render_groups_reordered
'''
<tbody class="gt_table_body">
Expand Down
25 changes: 24 additions & 1 deletion tests/test_spanners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import polars as pl
import polars.selectors as cs
import pytest
from great_tables import GT
from great_tables import GT, exibble
from great_tables._gt_data import Boxhead, ColInfo, ColInfoTypeEnum, SpannerInfo, Spanners
from great_tables._spanners import (
cols_hide,
Expand Down Expand Up @@ -144,6 +144,29 @@ def test_tab_spanners_overlap():
assert new_gt._spanners[1] == dst_span


def test_multiple_spanners_above_one():
from great_tables import GT, exibble

gt = (
GT(exibble, rowname_col="row", groupname_col="group")
.tab_spanner("A", ["num", "char", "fctr"])
.tab_spanner("B", ["fctr"])
.tab_spanner("C", ["num", "char"])
.tab_spanner("D", ["fctr", "date", "time"])
.tab_spanner("E", spanners=["B", "C"])
)

# Assert that the spanners have been added in the correct
# format and in the correct levels

assert len(gt._spanners) == 5
assert gt._spanners[0] == SpannerInfo("A", 0, "A", vars=["num", "char", "fctr"])
assert gt._spanners[1] == SpannerInfo("B", 1, "B", vars=["fctr"])
assert gt._spanners[2] == SpannerInfo("C", 1, "C", vars=["num", "char"])
assert gt._spanners[3] == SpannerInfo("D", 2, "D", vars=["fctr", "date", "time"])
assert gt._spanners[4] == SpannerInfo("E", 3, "E", vars=["fctr", "num", "char"])


def test_tab_spanners_with_gather():
df = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
src_gt = GT(df)
Expand Down
36 changes: 35 additions & 1 deletion tests/test_utils_render_html.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pandas as pd
import polars as pl
from great_tables import GT, exibble, html, loc, md, style
from great_tables._utils_render_html import create_body_component_h, create_source_notes_component_h
from great_tables._utils_render_html import (
create_body_component_h,
create_columns_component_h,
create_heading_component_h,
create_source_notes_component_h,
)

small_exibble = exibble[["num", "char"]].head(3)

Expand All @@ -13,6 +18,20 @@ def assert_rendered_source_notes(snapshot, gt):
assert snapshot == source_notes


def assert_rendered_heading(snapshot, gt):
built = gt._build_data("html")
heading = create_heading_component_h(built).make_string()

assert snapshot == heading


def assert_rendered_columns(snapshot, gt):
built = gt._build_data("html")
columns = create_columns_component_h(built)

assert snapshot == columns


def assert_rendered_body(snapshot, gt):
built = gt._build_data("html")
body = create_body_component_h(built)
Expand Down Expand Up @@ -157,3 +176,18 @@ def test_render_polars_list_col(snapshot):
gt = GT(pl.DataFrame({"x": [[1, 2]]}))

assert_rendered_body(snapshot, gt)


def test_multiple_spanners_pads_for_stubhead_label(snapshot):
# NOTE: see test_spanners.test_multiple_spanners_above_one
gt = (
GT(exibble, rowname_col="row", groupname_col="group")
.tab_spanner("A", ["num", "char", "fctr"])
.tab_spanner("B", ["fctr"])
.tab_spanner("C", ["num", "char"])
.tab_spanner("D", ["fctr", "date", "time"])
.tab_spanner("E", spanners=["B", "C"])
.tab_stubhead(label="Group")
)

assert_rendered_columns(snapshot, gt)
Loading