Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correlation plot with nans #1365

Merged
merged 8 commits into from
May 2, 2024
13 changes: 10 additions & 3 deletions pypesto/visualize/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def process_start_indices(
start_indices = ALL
if isinstance(start_indices, str):
if start_indices == ALL:
return np.asarray(range(len(result.optimize_result)))
start_indices = np.asarray(range(len(result.optimize_result)))
elif start_indices == ALL_CLUSTERED:
clust_ind, clust_size = assign_clusters(
delete_nan_inf(result.optimize_result.fval)[1]
Expand All @@ -336,12 +336,12 @@ def process_start_indices(
start_indices = np.concatenate(
[np.where(clust_ind == i_clust)[0] for i_clust in clust_gr2]
)
return start_indices
start_indices = start_indices
elif start_indices == FIRST_CLUSTER:
clust_ind = assign_clusters(
delete_nan_inf(result.optimize_result.fval)[1]
)[0]
return np.where(clust_ind == 0)[0]
start_indices = np.where(clust_ind == 0)[0]
else:
raise ValueError(
f"Permissible values for start_indices are {ALL}, "
Expand All @@ -359,6 +359,13 @@ def process_start_indices(
if start_index < len(result.optimize_result)
]

# filter out the indices that are not finite
start_indices = [
start_index
for start_index in start_indices
if np.isfinite(result.optimize_result[start_index].fval)
]
Comment on lines +370 to +371
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check for np.isnan here as well?

Copy link
Contributor

@stephanmg stephanmg Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.isfinite tests for Not a Number already, in case of NaN returns False.


return np.asarray(start_indices)


Expand Down
12 changes: 12 additions & 0 deletions test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,3 +1179,15 @@ def test_sacess_history():
)
sacess.minimize(problem)
sacess_history(sacess.histories)


@pytest.mark.parametrize(
"result_creation",
[create_optimization_result, create_optimization_result_nan_inf],
)
@close_fig
def test_parameters_correlation_matrix(result_creation):
"""Test pypesto.visualize.parameters_correlation_matrix"""
result = result_creation()

visualize.parameters_correlation_matrix(result)
Loading