Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Group tree dataprovider #902

Merged
merged 13 commits into from
Jan 13, 2022
79 changes: 65 additions & 14 deletions tests/unit_tests/plugin_tests/test_grouptree.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import datetime
from pathlib import Path
from typing import List, Tuple

import pandas as pd
import pytest
from _pytest.fixtures import SubRequest

from webviz_subsurface.plugins._group_tree.group_tree_data import add_nodetype_for_ens
from webviz_subsurface._providers import EnsembleSummaryProvider
from webviz_subsurface._providers.ensemble_summary_provider._provider_impl_arrow_presampled import (
ProviderImplArrowPresampled,
)
from webviz_subsurface.plugins._group_tree._ensemble_group_tree_data import add_nodetype

ADD_NODETYPE_CASES = [
# Group leaf nodes:
# NODE1 has summary data>0 and will be classified as prod and inj
# NODE2 has summary data==0 and will be classified as other
# NODE3 has no summary data and will be classified as other
# FIELD and TMPL are classified as all three types
pytest.param(
(
pd.DataFrame(
columns=["DATE", "CHILD", "KEYWORD", "PARENT"],
data=[
Expand All @@ -23,9 +30,16 @@
],
),
pd.DataFrame(
columns=["DATE", "GGPR:NODE1", "GGIR:NODE1", "GGPR:NODE2", "GGIR:NODE2"],
columns=[
"DATE",
"REAL",
"GGPR:NODE1",
"GGIR:NODE1",
"GGPR:NODE2",
"GGIR:NODE2",
],
data=[
[datetime.date(2000, 1, 1), 1, 1, 0, 0],
[datetime.date(2000, 1, 1), 0, 1, 1, 0, 0],
],
),
pd.DataFrame(
Expand All @@ -46,7 +60,6 @@
["2000-01-01", "NODE3", "GRUPTREE", "TMPL", False, False, True],
],
),
id="add-nodetype-for-group-leaf-nodes",
),
# Well leaf nodes:
# WELL1 has WSTAT==1 and will be classified as producer
Expand All @@ -55,7 +68,7 @@
# WELL4 has WSTAT==0 and will be classified as other
# TMPL_A is classified as prod and inj
# TMPL_B is classified as other
pytest.param(
(
pd.DataFrame(
columns=["DATE", "CHILD", "KEYWORD", "PARENT"],
data=[
Expand All @@ -71,14 +84,15 @@
pd.DataFrame(
columns=[
"DATE",
"REAL",
"WSTAT:WELL1",
"WSTAT:WELL2",
"WSTAT:WELL3",
"WSTAT:WELL4",
],
data=[
[datetime.date(2000, 1, 1), 1, 2, 1, 0],
[datetime.date(2000, 2, 1), 1, 2, 2, 0],
[datetime.date(2000, 1, 1), 0, 1, 2, 1, 0],
[datetime.date(2000, 2, 1), 0, 1, 2, 2, 0],
],
),
pd.DataFrame(
Expand All @@ -101,14 +115,45 @@
["2000-01-01", "WELL4", "WELSPECS", "TMPL_B", False, False, True],
],
),
id="add-nodetype-for-well-leaf-nodes",
),
]


@pytest.mark.parametrize("gruptree, smry, expected", ADD_NODETYPE_CASES)
def test_add_nodetype(gruptree, smry, expected):
"""Test functionality of the add_nodetype_for_ens function"""
@pytest.fixture(
name="testdata",
params=ADD_NODETYPE_CASES,
)
def fixture_provider(
request: SubRequest, tmp_path: Path
) -> Tuple[pd.DataFrame, EnsembleSummaryProvider, pd.DataFrame]:

input_py = request.param
storage_dir = tmp_path
gruptree_df = input_py[0]
smry_df = input_py[1]
expected_df = input_py[2]

ProviderImplArrowPresampled.write_backing_store_from_ensemble_dataframe(
storage_dir, "dummy_key", smry_df
)
new_provider = ProviderImplArrowPresampled.from_backing_store(
storage_dir, "dummy_key"
)

if not new_provider:
raise ValueError("Failed to create EnsembleSummaryProvider")

return gruptree_df, new_provider, expected_df


def test_add_nodetype(
testdata: Tuple[pd.DataFrame, EnsembleSummaryProvider, pd.DataFrame]
) -> None:
"""Test functionality for the add_nodetype function"""
gruptree_df = testdata[0]
provider = testdata[1]
expected_df = testdata[2]

columns_to_check = [
"DATE",
"CHILD",
Expand All @@ -118,6 +163,12 @@ def test_add_nodetype(gruptree, smry, expected):
"IS_INJ",
"IS_OTHER",
]
output = add_nodetype_for_ens(gruptree, smry)

pd.testing.assert_frame_equal(output[columns_to_check], expected[columns_to_check])
wells: List[str] = gruptree_df[gruptree_df["KEYWORD"] == "WELSPECS"][
"CHILD"
].unique()

output = add_nodetype(gruptree_df, provider, wells)
pd.testing.assert_frame_equal(
output[columns_to_check], expected_df[columns_to_check]
)
2 changes: 1 addition & 1 deletion webviz_subsurface/plugins/_group_tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .group_tree import GroupTree
from ._plugin import GroupTree
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import webviz_subsurface_components
from dash.dependencies import Input, Output, State

from ..group_tree_data import GroupTreeData
from ._ensemble_group_tree_data import EnsembleGroupTreeData


def controllers(
app: dash.Dash, get_uuid: Callable, grouptreedata: GroupTreeData
def plugin_callbacks(
app: dash.Dash,
get_uuid: Callable,
group_tree_data: Dict[str, EnsembleGroupTreeData],
) -> None:
@app.callback(
Output({"id": get_uuid("controls"), "element": "tree_mode"}, "options"),
Expand All @@ -22,7 +24,10 @@ def controllers(
State({"id": get_uuid("options"), "element": "realization"}, "value"),
)
def _update_ensemble_options(
ensemble: str, tree_mode_state: str, stat_option_state: str, real_state: int
ensemble_name: str,
tree_mode_state: str,
stat_option_state: str,
real_state: int,
) -> Tuple[List[Dict[str, Any]], str, str, List[Dict[str, Any]], Optional[int]]:
"""Updates the selection options when the ensemble value changes"""
tree_mode_options: List[Dict[str, Any]] = [
Expand All @@ -41,13 +46,13 @@ def _update_ensemble_options(
stat_option_value = (
stat_option_state if stat_option_state is not None else "mean"
)

if not grouptreedata.tree_is_equivalent_in_all_real(ensemble):
ensemble = group_tree_data[ensemble_name]
if not ensemble.tree_is_equivalent_in_all_real():
tree_mode_options[0]["label"] = "Ensemble mean (disabled)"
tree_mode_options[0]["disabled"] = True
tree_mode_value = "single_real"

unique_real = grouptreedata.get_ensemble_unique_real(ensemble)
unique_real = ensemble.get_unique_real()

return (
tree_mode_options,
Expand All @@ -66,12 +71,16 @@ def _update_ensemble_options(
State({"id": get_uuid("controls"), "element": "ensemble"}, "value"),
)
def _render_grouptree(
tree_mode: str, stat_option: str, real: int, prod_inj_other: list, ensemble: str
tree_mode: str,
stat_option: str,
real: int,
prod_inj_other: list,
ensemble_name: str,
) -> list:
"""This callback updates the input dataset to the Grouptree component."""
data, edge_options, node_options = grouptreedata.create_grouptree_dataset(
ensemble, tree_mode, stat_option, real, prod_inj_other
)
data, edge_options, node_options = group_tree_data[
ensemble_name
].create_grouptree_dataset(tree_mode, stat_option, real, prod_inj_other)

return [
webviz_subsurface_components.GroupTree(
Expand Down
Loading