Skip to content

Commit

Permalink
feat(ui): bind grid selection to query parameters
Browse files Browse the repository at this point in the history
the selected row is bound to the query parameters to make it possible to
share a direct URL to a detail view.
  • Loading branch information
luis-dk committed Oct 25, 2024
1 parent 86a04df commit 7555216
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 14 deletions.
73 changes: 64 additions & 9 deletions testgen/ui/services/form_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import testgen.common.date_service as date_service
import testgen.ui.services.authentication_service as authentication_service
import testgen.ui.services.database_service as db
from testgen.ui.navigation.router import Router

"""
Shared rendering of UI elements
Expand Down Expand Up @@ -762,14 +763,31 @@ def render_insert_form(


def render_grid_select(
df,
df: pd.DataFrame,
show_columns,
str_prompt=None,
int_height=400,
do_multi_select=False,
do_multi_select: bool | None = None,
selection_mode: typing.Literal["single", "multiple", "disabled"] = "single",
show_column_headers=None,
render_highlights=True,
bind_to_query_name: str | None = None,
bind_to_query_prop: str | None = None,
key: str = "aggrid",
):
"""
:param do_multi_select: DEPRECATED. boolean to choose between single
or multiple selection.
:param selection_mode: one of single, multiple or disabled. defaults
to single.
:param bind_to_query_name: name of the query param where to bind the
selected row.
:param bind_to_query_prop: name of the property of the selected row
which value will be set in the query param.
:param key: Streamlit cache key for the grid. required when binding
selection to query.
"""

show_prompt(str_prompt)

# Set grid formatting
Expand Down Expand Up @@ -837,12 +855,40 @@ def render_grid_select(
}
"""
)
data_changed: bool = True
rendering_counter = st.session_state.get(f"{key}_counter") or 0
previous_dataframe = st.session_state.get(f"{key}_dataframe")

if previous_dataframe is not None:
data_changed = not df.equals(previous_dataframe)

dct_col_to_header = dict(zip(show_columns, show_column_headers, strict=True)) if show_column_headers else None

gb = GridOptionsBuilder.from_dataframe(df)
selection_mode = "multiple" if do_multi_select else "single"
gb.configure_selection(selection_mode=selection_mode, use_checkbox=do_multi_select)
selection_mode_ = selection_mode
if do_multi_select is not None:
selection_mode_ = "multiple" if do_multi_select else "single"

pre_selected_rows: typing.Any = {}
if bind_to_query_name and bind_to_query_prop:
bound_value = st.query_params.get(bind_to_query_name)
bound_items_indexes = df[df[bind_to_query_prop] == bound_value].index
if len(bound_items_indexes) > 0:
# https://github.com/PablocFonseca/streamlit-aggrid/issues/207#issuecomment-1793039564
pre_selected_rows = {str(bound_items_indexes[0]): True}
else:
if data_changed and st.query_params.get(bind_to_query_name):
rendering_counter += 1
Router().set_query_params({bind_to_query_name: None})

gb.configure_selection(
selection_mode=selection_mode_,
use_checkbox=selection_mode_ == "multiple",
pre_selected_rows=pre_selected_rows,
)

if bind_to_query_prop and bind_to_query_prop.isalnum():
gb.configure_grid_options(getRowId=JsCode(f"""function(row) {{ return row.data.{bind_to_query_prop}; }}"""))

all_columns = list(df.columns)

Expand All @@ -853,8 +899,8 @@ def render_grid_select(
"field": column,
"header_name": str_header if str_header else ut_prettify_header(column),
"hide": column not in show_columns,
"headerCheckboxSelection": do_multi_select and column == show_columns[0],
"headerCheckboxSelectionFilteredOnly": do_multi_select and column == show_columns[0],
"headerCheckboxSelection": selection_mode_ == "multiple" and column == show_columns[0],
"headerCheckboxSelectionFilteredOnly": selection_mode_ == "multiple" and column == show_columns[0],
}
highlight_kwargs = {"cellStyle": cellstyle_jscode}

Expand Down Expand Up @@ -888,7 +934,8 @@ def render_grid_select(
theme="balham",
enable_enterprise_modules=False,
allow_unsafe_jscode=True,
update_mode=GridUpdateMode.SELECTION_CHANGED,
update_mode=GridUpdateMode.NO_UPDATE,
update_on=["selectionChanged"],
data_return_mode=DataReturnMode.FILTERED_AND_SORTED,
columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS,
height=int_height,
Expand All @@ -897,10 +944,18 @@ def render_grid_select(
"padding-bottom": "0px !important",
}
},
key=f"{key}_{selection_mode_}_{rendering_counter}",
reload_data=data_changed,
)

if len(grid_data["selected_rows"]):
return grid_data["selected_rows"]
st.session_state[f"{key}_counter"] = rendering_counter
st.session_state[f"{key}_dataframe"] = df

selected_rows = grid_data["selected_rows"]
if len(selected_rows) > 0:
if bind_to_query_name and bind_to_query_prop:
Router().set_query_params({bind_to_query_name: selected_rows[0][bind_to_query_prop]})
return selected_rows


def render_logo(logo_path: str = logo_file):
Expand Down
8 changes: 7 additions & 1 deletion testgen/ui/views/profiling_anomalies.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def render(self, run_id: str, issue_class: str | None = None, issue_type: str |
f"Profiling run with ID '{run_id}' does not exist. Redirecting to list of Profiling Runs ...",
"profiling-runs",
)
return

run_date, _table_group_id, table_group_name, project_code = run_parentage
run_date = date_service.get_timezoned_timestamp(st.session_state, run_date)
Expand Down Expand Up @@ -130,7 +131,12 @@ def render(self, run_id: str, issue_class: str | None = None, issue_type: str |

# Show main grid and retrieve selections
selected = fm.render_grid_select(
df_pa, lst_show_columns, int_height=400, do_multi_select=do_multi_select
df_pa,
lst_show_columns,
int_height=400,
do_multi_select=do_multi_select,
bind_to_query_name="selected",
bind_to_query_prop="id",
)

with export_button_column:
Expand Down
8 changes: 7 additions & 1 deletion testgen/ui/views/profiling_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def render(self, run_id: str, table_name: str | None = None, column_name: str |
f"Profiling run with ID '{run_id}' does not exist. Redirecting to list of Profiling Runs ...",
"profiling-runs",
)
return

run_date, table_group_id, table_group_name, project_code = run_parentage
run_date = date_service.get_timezoned_timestamp(st.session_state, run_date)
Expand Down Expand Up @@ -105,7 +106,12 @@ def render(self, run_id: str, table_name: str | None = None, column_name: str |
with st.expander("📜 **Table CREATE script with suggested datatypes**"):
st.code(generate_create_script(df), "sql")

selected_row = fm.render_grid_select(df, show_columns)
selected_row = fm.render_grid_select(
df,
show_columns,
bind_to_query_name="selected",
bind_to_query_prop="id",
)

with export_button_column:
testgen.flex_row_end()
Expand Down
2 changes: 2 additions & 0 deletions testgen/ui/views/test_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,8 @@ def show_test_defs_grid(
do_multi_select=do_multi_select,
show_column_headers=show_column_headers,
render_highlights=False,
bind_to_query_name="selected",
bind_to_query_prop="id",
)

with export_container:
Expand Down
12 changes: 9 additions & 3 deletions testgen/ui/views/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def render(self, run_id: str, status: str | None = None, test_type: str | None =
f"Test run with ID '{run_id}' does not exist. Redirecting to list of Test Runs ...",
"test-runs",
)
return

run_date, test_suite_name, project_code = run_parentage
run_date = date_service.get_timezoned_timestamp(st.session_state, run_date)
Expand Down Expand Up @@ -478,7 +479,12 @@ def show_result_detail(str_run_id, str_sel_test_status, test_type_id, sorting_co
]

selected_rows = fm.render_grid_select(
df, lst_show_columns, do_multi_select=do_multi_select, show_column_headers=lst_show_headers
df,
lst_show_columns,
do_multi_select=do_multi_select,
show_column_headers=lst_show_headers,
bind_to_query_name="selected",
bind_to_query_prop="test_result_id",
)

with export_container:
Expand Down Expand Up @@ -523,7 +529,7 @@ def show_result_detail(str_run_id, str_sel_test_status, test_type_id, sorting_co
if not selected_rows:
st.markdown(":orange[Select a record to see more information.]")
else:
selected_row = selected_rows[len(selected_rows) - 1]
selected_row = selected_rows[0]
dfh = get_test_result_history(selected_row)
show_hist_columns = ["test_date", "threshold_value", "result_measure", "result_status"]

Expand Down Expand Up @@ -582,7 +588,7 @@ def show_result_detail(str_run_id, str_sel_test_status, test_type_id, sorting_co
fm.show_subheader(selected_row["test_name_short"])
st.markdown(f"###### {selected_row['test_description']}")
st.caption(empty_if_null(selected_row["measure_uom_description"]))
fm.render_grid_select(dfh, show_hist_columns)
fm.render_grid_select(dfh, show_hist_columns, selection_mode="disabled")
with pg_col2:
ut_tab1, ut_tab2 = st.tabs(["History", "Test Definition"])
with ut_tab1:
Expand Down

0 comments on commit 7555216

Please sign in to comment.