Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sensitivity analysis var name #300

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
89 changes: 48 additions & 41 deletions autoemulate/sensitivity_analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from SALib.analyze.sobol import analyze
from SALib.sample.sobol import sample
from SALib.util import ResultDict

from autoemulate.utils import _ensure_2d

Expand Down Expand Up @@ -96,12 +99,14 @@ def _generate_problem(X):

return {
"num_vars": X.shape[1],
"names": [f"x{i+1}" for i in range(X.shape[1])],
"names": [f"X{i+1}" for i in range(X.shape[1])],
"bounds": [[X[:, i].min(), X[:, i].max()] for i in range(X.shape[1])],
}


def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
def _sobol_analysis(
model, problem=None, X=None, N=1024, conf_level=0.95
) -> Dict[str, ResultDict]:
"""
Perform Sobol sensitivity analysis on a fitted emulator.

Expand Down Expand Up @@ -148,56 +153,58 @@ def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
return results


def _sobol_results_to_df(results):
def _sobol_results_to_df(results: Dict[str, ResultDict]) -> pd.DataFrame:
"""
Convert Sobol results to a (long-format)pandas DataFrame.
Convert Sobol results to a (long-format) pandas DataFrame.

Parameters:
-----------
results : dict
The Sobol indices returned by sobol_analysis.
problem : dict, optional
The problem definition, including 'names'.

Returns:
--------
pd.DataFrame
A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'.
"""
rename_dict = {
"variable": "index",
"S1": "value",
"S1_conf": "confidence",
"ST": "value",
"ST_conf": "confidence",
"S2": "value",
"S2_conf": "confidence",
}
rows = []
for output, indices in results.items():
for index_type in ["S1", "ST", "S2"]:
values = indices.get(index_type)
conf_values = indices.get(f"{index_type}_conf")
if values is None or conf_values is None:
continue

if index_type in ["S1", "ST"]:
rows.extend(
{
"output": output,
"parameter": f"X{i+1}",
"index": index_type,
"value": value,
"confidence": conf,
}
for i, (value, conf) in enumerate(zip(values, conf_values))
)

elif index_type == "S2":
n = values.shape[0]
rows.extend(
{
"output": output,
"parameter": f"X{i+1}-X{j+1}",
"index": index_type,
"value": values[i, j],
"confidence": conf_values[i, j],
}
for i in range(n)
for j in range(i + 1, n)
if not np.isnan(values[i, j])
)

return pd.DataFrame(rows)
for output, result in results.items():
s1, st, s2 = result.to_df()
s1 = (
s1.reset_index()
.rename(columns={"index": "parameter"})
.rename(columns=rename_dict)
)
s1["index"] = "S1"
st = (
st.reset_index()
.rename(columns={"index": "parameter"})
.rename(columns=rename_dict)
)
st["index"] = "ST"
s2 = (
s2.reset_index()
.rename(columns={"index": "parameter"})
.rename(columns=rename_dict)
)
s2["index"] = "S2"

df = pd.concat([s1, st, s2])
df["output"] = output
rows.append(df[["output", "parameter", "index", "value", "confidence"]])

return pd.concat(rows)


# plotting --------------------------------------------------------------------
Expand Down Expand Up @@ -241,7 +248,7 @@ def _create_bar_plot(ax, output_data, output_name):
ax.set_title(f"Output: {output_name}")


def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
def _plot_sensitivity_analysis(results, problem, index="S1", n_cols=None, figsize=None):
"""
Plot the sensitivity analysis results.

Expand All @@ -263,7 +270,7 @@ def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
"""
with plt.style.context("fast"):
# prepare data
results = _validate_input(results, index)
results = _validate_input(results, problem, index)
unique_outputs = results["output"].unique()
n_outputs = len(unique_outputs)

Expand Down
15 changes: 9 additions & 6 deletions tests/test_sensitivity_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.datasets import make_regression

from autoemulate.emulators import RandomForest
from autoemulate.experimental_design import LatinHypercube
Expand Down Expand Up @@ -150,7 +149,11 @@ def sobol_results_1d(model_1d):

# # test conversion to DataFrame --------------------------------------------------
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_sobol_results_to_df(sobol_results_1d):
@pytest.mark.parametrize(
"expected_names",
[["c", "v0", "c", "v0", ["c", "v0"]]],
)
def test_sobol_results_to_df(sobol_results_1d, expected_names):
df = _sobol_results_to_df(sobol_results_1d)
assert isinstance(df, pd.DataFrame)
assert df.columns.tolist() == [
Expand All @@ -160,7 +163,7 @@ def test_sobol_results_to_df(sobol_results_1d):
"value",
"confidence",
]
assert ["X1", "X2", "X1-X2"] in df["parameter"].unique()
assert expected_names == df["parameter"].to_list()
assert all(isinstance(x, float) for x in df["value"])
assert all(isinstance(x, float) for x in df["confidence"])

Expand All @@ -172,12 +175,12 @@ def test_sobol_results_to_df(sobol_results_1d):
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_validate_input(sobol_results_1d):
    """An unsupported Sobol index name must be rejected with ValueError."""
    with pytest.raises(ValueError):
        # Pass index by keyword: the second positional slot is `problem`.
        _validate_input(sobol_results_1d, index="S3")


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_validate_input_valid(sobol_results_1d):
    """A supported Sobol index name yields the results as a DataFrame."""
    # Pass index by keyword: the second positional slot is `problem`.
    validated = _validate_input(sobol_results_1d, index="S1")
    assert isinstance(validated, pd.DataFrame)


Expand Down Expand Up @@ -207,5 +210,5 @@ def test_generate_problem():
X = np.array([[0, 0], [1, 1], [2, 2]])
problem = _generate_problem(X)
assert problem["num_vars"] == 2
assert problem["names"] == ["x1", "x2"]
assert problem["names"] == ["X1", "X2"]
assert problem["bounds"] == [[0, 2], [0, 2]]
Loading