Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correlation plot with nans #1365

Merged
merged 8 commits into from
May 2, 2024
25 changes: 20 additions & 5 deletions pypesto/visualize/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import warnings
from collections.abc import Iterable
from numbers import Number
Expand All @@ -23,6 +24,8 @@
from ..util import assign_clusters, delete_nan_inf
from .clust_color import assign_colors_for_list

logger = logging.getLogger(__name__)


def process_result_list(
results: Union[Result, list[Result]], colors=None, legends=None
Expand Down Expand Up @@ -303,8 +306,8 @@ def process_start_indices(
"""
Process the start_indices.

Create an array of indices if a number was provided and checks that the
indices do not exceed the max_index.
Create an array of indices if a number was provided, checks that the indices
do not exceed the max_index and removes starts with non-finite fval.

Parameters
----------
Expand All @@ -323,7 +326,7 @@ def process_start_indices(
start_indices = ALL
if isinstance(start_indices, str):
if start_indices == ALL:
return np.asarray(range(len(result.optimize_result)))
start_indices = np.asarray(range(len(result.optimize_result)))
elif start_indices == ALL_CLUSTERED:
clust_ind, clust_size = assign_clusters(
delete_nan_inf(result.optimize_result.fval)[1]
Expand All @@ -336,12 +339,12 @@ def process_start_indices(
start_indices = np.concatenate(
[np.where(clust_ind == i_clust)[0] for i_clust in clust_gr2]
)
return start_indices
start_indices = start_indices
elif start_indices == FIRST_CLUSTER:
clust_ind = assign_clusters(
delete_nan_inf(result.optimize_result.fval)[1]
)[0]
return np.where(clust_ind == 0)[0]
start_indices = np.where(clust_ind == 0)[0]
else:
raise ValueError(
f"Permissible values for start_indices are {ALL}, "
Expand All @@ -359,6 +362,18 @@ def process_start_indices(
if start_index < len(result.optimize_result)
]

# filter out the indices that are not finite
start_indices_unfiltered = len(start_indices)
start_indices = [
start_index
for start_index in start_indices
if np.isfinite(result.optimize_result[start_index].fval)
]
Comment on lines +370 to +371
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check for np.isnan here as well?

Copy link
Contributor

@stephanmg stephanmg Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.isfinite tests for Not a Number already, in case of NaN returns False.

if len(start_indices) != start_indices_unfiltered:
logger.warning(
"Some start indices were removed due to inf or nan function values."
)

return np.asarray(start_indices)


Expand Down
15 changes: 1 addition & 14 deletions pypesto/visualize/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,23 +601,10 @@ def optimization_scatter(
parameter_indices = process_parameter_indices(
parameter_indices=parameter_indices, result=result
)
# remove all start indices that encounter an inf value at the start
# resulting in optimize_result[start]["x"] being None
start_indices_finite = start_indices[
[
result.optimize_result[i_start]["x"] is not None
for i_start in start_indices
]
]
# compare start_indices with start_indices_finite and log a warning
if len(start_indices) != len(start_indices_finite):
logger.warning(
"Some start indices were removed due to inf values at the start."
)
# put all parameters into a dataframe, where columns are parameters
parameters = [
result.optimize_result[i_start]["x"][parameter_indices]
for i_start in start_indices_finite
for i_start in start_indices
]
x_labels = [
result.problem.x_names[parameter_index]
Expand Down
12 changes: 12 additions & 0 deletions test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,3 +1179,15 @@ def test_sacess_history():
)
sacess.minimize(problem)
sacess_history(sacess.histories)


@pytest.mark.parametrize(
"result_creation",
[create_optimization_result, create_optimization_result_nan_inf],
)
@close_fig
def test_parameters_correlation_matrix(result_creation):
"""Test pypesto.visualize.parameters_correlation_matrix"""
result = result_creation()

visualize.parameters_correlation_matrix(result)
Loading