Skip to content

Commit

Permalink
Ensure columns hidden with pandas .hide() work as expected (#10383)
Browse files Browse the repository at this point in the history
* check hidden_columns

* add changeset

* add test

* lint

* lint

* fix test

* Update gradio/components/dataframe.py

Co-authored-by: Abubakar Abid <[email protected]>

* lint

* lint

* lint

* claude said this will fix the test

* lint

---------

Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Abubakar Abid <[email protected]>
  • Loading branch information
3 people authored Jan 17, 2025
1 parent 2b7ba48 commit 9517043
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 17 deletions.
5 changes: 5 additions & 0 deletions .changeset/full-cups-melt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Ensure columns hidden with pandas `.hide()` work as expected
57 changes: 40 additions & 17 deletions gradio/components/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def postprocess(
) -> DataframeData:
"""
Parameters:
value: Expects data any of these formats: `pandas.DataFrame`, `pandas.Styler`, `numpy.array`, `polars.DataFrame`, `list[list]`, `list`, or a `dict` with keys 'data' (and optionally 'headers'), or `str` path to a csv, which is rendered as the spreadsheet.
value: Expects data in any of these formats: `pandas.DataFrame`, `pandas.Styler`, `numpy.array`, `polars.DataFrame`, `list[list]`, `list`, or a `dict` with keys 'data' (and optionally 'headers'), or `str` path to a csv, which is rendered as the spreadsheet.
Returns:
the uploaded spreadsheet data as an object with `headers` and `data` keys and optional `metadata` key
"""
Expand All @@ -286,10 +286,14 @@ def postprocess(
)

if value is None or self._is_empty(value):
return self.postprocess(self.empty_input)
return DataframeData(
headers=self.headers, data=[["" for _ in range(len(self.headers))]]
)
if isinstance(value, dict):
if len(value) == 0:
return DataframeData(headers=self.headers, data=[[]])
return DataframeData(
headers=self.headers, data=[["" for _ in range(len(self.headers))]]
)
return DataframeData(
headers=value.get("headers", []), data=value.get("data", [[]])
)
Expand All @@ -299,28 +303,39 @@ def postprocess(
if len(value) == 0:
return DataframeData(
headers=[str(col) for col in value.columns], # Convert to strings
data=[[]], # type: ignore
data=[["" for _ in range(len(value.columns))]],
)
return DataframeData(
headers=[str(col) for col in value.columns], # Convert to strings
data=value.to_dict(orient="split")["data"], # type: ignore
headers=[str(col) for col in value.columns],
data=value.to_dict(orient="split")["data"],
)
elif isinstance(value, Styler):
if self.interactive:
warnings.warn(
"Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead."
)
df: pd.DataFrame = value.data # type: ignore
visible_cols = [
i
for i, col in enumerate(df.columns)
if i not in getattr(value, "hidden_columns", [])
]
df = df.iloc[:, visible_cols]

if len(df) == 0:
return DataframeData(
headers=list(df.columns),
data=[[]],
metadata=self.__extract_metadata(value), # type: ignore
data=[["" for _ in range(len(df.columns))]],
metadata=self.__extract_metadata(
value, getattr(value, "hidden_columns", [])
), # type: ignore
)
return DataframeData(
headers=list(df.columns),
data=df.to_dict(orient="split")["data"], # type: ignore
metadata=self.__extract_metadata(value), # type: ignore
metadata=self.__extract_metadata(
value, getattr(value, "hidden_columns", [])
), # type: ignore
)
elif _is_polars_available() and isinstance(value, _import_polars().DataFrame):
if len(value) == 0:
Expand Down Expand Up @@ -360,22 +375,30 @@ def __get_cell_style(cell_id: str, cell_styles: list[dict]) -> str:
return styles_str

@staticmethod
def __extract_metadata(
    df: Styler, hidden_cols: list[int] | None = None
) -> dict[str, list[list]]:
    """
    Collect per-cell display values and CSS strings from a pandas Styler.

    Parameters:
        df: the Styler whose translated body is inspected.
        hidden_cols: positional indices of data columns to leave out of the result. `None` (the default) keeps every column.
    Returns:
        a dict with two parallel row-major lists: `display_value` (the cell's `display_value` entry) and `styling` (the string produced by `__get_cell_style` for the cell's id).
    """
    metadata: dict[str, list[list]] = {"display_value": [], "styling": []}
    style_data = df._compute()._translate(None, None)  # type: ignore
    cell_styles = style_data.get("cellstyle", [])
    # Build the lookup once so the per-cell membership test is O(1).
    hidden = set(hidden_cols) if hidden_cols is not None else set()
    for row in style_data["body"]:
        row_display = []
        row_styling = []
        # Counts "td" cells only, so it lines up with positional column
        # indices. NOTE(review): this assumes the translated body still
        # contains a "td" entry for hidden columns -- confirm against pandas.
        col_idx = 0
        for cell in row:
            if cell["type"] != "td":
                # Non-"td" entries (presumably index/header cells) carry no data.
                continue
            if col_idx not in hidden:
                row_display.append(cell["display_value"])
                row_styling.append(
                    Dataframe.__get_cell_style(cell["id"], cell_styles)
                )
            col_idx += 1
        metadata["display_value"].append(row_display)
        metadata["styling"].append(row_styling)
    return metadata

@staticmethod
Expand Down
29 changes: 29 additions & 0 deletions test/components/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,32 @@ def test_dataframe_postprocess_styler(self):
],
},
}

def test_dataframe_hidden_columns(self):
    """Columns hidden through `Styler.hide` must not appear in data, headers or metadata."""
    frame = pd.DataFrame(
        {"a": [1, 2, 3], "b": [4, 5, 6], "color": ["red", "blue", "green"]}
    )
    # Hide the last column on the column axis; only "a" and "b" should survive.
    styler = frame.style.hide(axis=1, subset=["color"])

    dataframe = gr.Dataframe()
    result = dataframe.postprocess(styler).model_dump()

    expected_rows = [[1, 4], [2, 5], [3, 6]]
    expected = {
        "data": expected_rows,
        "headers": ["a", "b"],
        "metadata": {
            # Display values are the stringified cell contents.
            "display_value": [[str(cell) for cell in row] for row in expected_rows],
            # No styles were applied, so every visible cell has an empty style.
            "styling": [["" for _ in row] for row in expected_rows],
        },
    }
    assert result == expected

0 comments on commit 9517043

Please sign in to comment.