Skip to content

Commit

Permalink
Correlation plot with nans (#1365)
Browse files Browse the repository at this point in the history
* Added option to sample startpoints of a problem, from the problem directly. Also safety checks for startindices like "all" or "clustered"

* Fixed error in correlation matrix, if there were nn values. Also safety checks in general for start indices. Added a test for correlation matrix

* log filtering in process_start_indices

* update docstring

---------

Co-authored-by: Maren Philipps <[email protected]>
  • Loading branch information
PaulJonasJost and m-philipps authored May 2, 2024
1 parent 3edf785 commit b95f422
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 19 deletions.
25 changes: 20 additions & 5 deletions pypesto/visualize/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import warnings
from collections.abc import Iterable
from numbers import Number
Expand All @@ -23,6 +24,8 @@
from ..util import assign_clusters, delete_nan_inf
from .clust_color import assign_colors_for_list

logger = logging.getLogger(__name__)


def process_result_list(
results: Union[Result, list[Result]], colors=None, legends=None
Expand Down Expand Up @@ -303,8 +306,8 @@ def process_start_indices(
"""
Process the start_indices.
Create an array of indices if a number was provided and checks that the
indices do not exceed the max_index.
Create an array of indices if a number was provided, checks that the indices
do not exceed the max_index and removes starts with non-finite fval.
Parameters
----------
Expand All @@ -323,7 +326,7 @@ def process_start_indices(
start_indices = ALL
if isinstance(start_indices, str):
if start_indices == ALL:
return np.asarray(range(len(result.optimize_result)))
start_indices = np.asarray(range(len(result.optimize_result)))
elif start_indices == ALL_CLUSTERED:
clust_ind, clust_size = assign_clusters(
delete_nan_inf(result.optimize_result.fval)[1]
Expand All @@ -336,12 +339,12 @@ def process_start_indices(
start_indices = np.concatenate(
[np.where(clust_ind == i_clust)[0] for i_clust in clust_gr2]
)
return start_indices
start_indices = start_indices
elif start_indices == FIRST_CLUSTER:
clust_ind = assign_clusters(
delete_nan_inf(result.optimize_result.fval)[1]
)[0]
return np.where(clust_ind == 0)[0]
start_indices = np.where(clust_ind == 0)[0]
else:
raise ValueError(
f"Permissible values for start_indices are {ALL}, "
Expand All @@ -359,6 +362,18 @@ def process_start_indices(
if start_index < len(result.optimize_result)
]

# filter out the indices that are not finite
start_indices_unfiltered = len(start_indices)
start_indices = [
start_index
for start_index in start_indices
if np.isfinite(result.optimize_result[start_index].fval)
]
if len(start_indices) != start_indices_unfiltered:
logger.warning(
"Some start indices were removed due to inf or nan function values."
)

return np.asarray(start_indices)


Expand Down
15 changes: 1 addition & 14 deletions pypesto/visualize/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,23 +601,10 @@ def optimization_scatter(
parameter_indices = process_parameter_indices(
parameter_indices=parameter_indices, result=result
)
# remove all start indices that encounter an inf value at the start
# resulting in optimize_result[start]["x"] being None
start_indices_finite = start_indices[
[
result.optimize_result[i_start]["x"] is not None
for i_start in start_indices
]
]
# compare start_indices with start_indices_finite and log a warning
if len(start_indices) != len(start_indices_finite):
logger.warning(
"Some start indices were removed due to inf values at the start."
)
# put all parameters into a dataframe, where columns are parameters
parameters = [
result.optimize_result[i_start]["x"][parameter_indices]
for i_start in start_indices_finite
for i_start in start_indices
]
x_labels = [
result.problem.x_names[parameter_index]
Expand Down
12 changes: 12 additions & 0 deletions test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,3 +1179,15 @@ def test_sacess_history():
)
sacess.minimize(problem)
sacess_history(sacess.histories)


@pytest.mark.parametrize(
"result_creation",
[create_optimization_result, create_optimization_result_nan_inf],
)
@close_fig
def test_parameters_correlation_matrix(result_creation):
"""Test pypesto.visualize.parameters_correlation_matrix"""
result = result_creation()

visualize.parameters_correlation_matrix(result)

0 comments on commit b95f422

Please sign in to comment.