From 30f3cd3066fed1b3f0911d2ae97ca2933f53d954 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 4 Dec 2023 16:43:19 -0500 Subject: [PATCH 1/2] fix: spanners_print_matrix excludes selected cols that are in the stub --- great_tables/_spanners.py | 6 +++++- tests/test_spanners.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/great_tables/_spanners.py b/great_tables/_spanners.py index 2481a1620..911207373 100644 --- a/great_tables/_spanners.py +++ b/great_tables/_spanners.py @@ -502,7 +502,11 @@ def spanners_print_matrix( for span_ii, span in enumerate(crnt_spans): for var in span.vars: - label_matrix[span.spanner_level][var] = spanner_reprs[span_ii] + # This if clause skips spanned columns that are not in the + # boxhead vars we are planning to use (e.g. not in the visible ones + # or in the stub). + if var in label_matrix[span.spanner_level]: + label_matrix[span.spanner_level][var] = spanner_reprs[span_ii] # reverse order , so if you were to print it out, level 0 would appear on the bottom label_matrix.reverse() diff --git a/tests/test_spanners.py b/tests/test_spanners.py index c5afd60f2..4a2e3a56d 100644 --- a/tests/test_spanners.py +++ b/tests/test_spanners.py @@ -74,6 +74,17 @@ def test_spanners_print_matrix_arg_include_hidden(spanners, boxhead): ] +def test_spanners_print_matrix_exclude_stub(): + """spanners_print_matrix omits a selected column if it's in the stub.""" + info = SpannerInfo(spanner_id="a", spanner_level=0, vars=["x", "y"], built="A") + spanners = Spanners([info]) + boxh = Boxhead([ColInfo(var="x"), ColInfo(var="y", type=ColInfoTypeEnum.stub)]) + + mat, vars = spanners_print_matrix(spanners, boxh, omit_columns_row=True) + assert vars == ["x"] + assert mat == [{"x": "A"}] + + def test_empty_spanner_matrix(): mat, vars = empty_spanner_matrix(["a", "b"], omit_columns_row=False) From 48809a54a04bf52a00ea2ec182f442e0033a3277 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 4 Dec 2023 17:16:40 -0500 Subject: [PATCH 2/2] docs: use polars replace method instead --- docs/examples/index.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/index.qmd b/docs/examples/index.qmd index ee4932be6..32870d6c0 100644 --- a/docs/examples/index.qmd +++ b/docs/examples/index.qmd @@ -103,7 +103,7 @@ wide_pops = ( pl.col("country_code_2").is_in(list(region_to_country)) & pl.col("year").is_in([2000, 2010, 2020]) ) - .with_columns(pl.col("country_code_2").map_dict(region_to_country).alias("region")) + .with_columns(pl.col("country_code_2").replace(region_to_country).alias("region")) .pivot(index=["country_name", "region"], columns="year", values="population") .sort("2020", descending=True) )