From 446002571cba290a8cc0567e8fb60be74784d348 Mon Sep 17 00:00:00 2001 From: PaulJonasJost Date: Tue, 9 Apr 2024 15:27:37 +0200 Subject: [PATCH 1/4] Added option to sample startpoints of a problem, from the problem directly. Also safety checks for startindices like "all" or "clustered" --- pypesto/visualize/misc.py | 13 ++++++++++--- test/visualize/test_visualize.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pypesto/visualize/misc.py b/pypesto/visualize/misc.py index d23f710b5..c08c3e816 100644 --- a/pypesto/visualize/misc.py +++ b/pypesto/visualize/misc.py @@ -323,7 +323,7 @@ def process_start_indices( start_indices = ALL if isinstance(start_indices, str): if start_indices == ALL: - return np.asarray(range(len(result.optimize_result))) + start_indices = np.asarray(range(len(result.optimize_result))) elif start_indices == ALL_CLUSTERED: clust_ind, clust_size = assign_clusters( delete_nan_inf(result.optimize_result.fval)[1] @@ -336,12 +336,12 @@ def process_start_indices( start_indices = np.concatenate( [np.where(clust_ind == i_clust)[0] for i_clust in clust_gr2] ) - return start_indices + start_indices = start_indices elif start_indices == FIRST_CLUSTER: clust_ind = assign_clusters( delete_nan_inf(result.optimize_result.fval)[1] )[0] - return np.where(clust_ind == 0)[0] + start_indices = np.where(clust_ind == 0)[0] else: raise ValueError( f"Permissible values for start_indices are {ALL}, " @@ -359,6 +359,13 @@ def process_start_indices( if start_index < len(result.optimize_result) ] + # filter out the indices that are not finite + start_indices = [ + start_index + for start_index in start_indices + if np.isfinite(result.optimize_result[start_index].fval) + ] + return np.asarray(start_indices) diff --git a/test/visualize/test_visualize.py b/test/visualize/test_visualize.py index 79aee2f30..e74f22443 100644 --- a/test/visualize/test_visualize.py +++ b/test/visualize/test_visualize.py @@ -1184,3 +1184,15 @@ def test_sacess_history(): ) sacess.minimize(problem) sacess_history(sacess.histories) + + +@pytest.mark.parametrize( + "result_creation", + [create_optimization_result, create_optimization_result_nan_inf], +) +@close_fig +def test_parameters_correlation_matrix(result_creation): + """Test pypesto.visualize.parameters_correlation_matrix""" + result = result_creation() + + visualize.parameters_correlation_matrix(result) From 0c7e033693fb98b1092b894780a229511d7912fa Mon Sep 17 00:00:00 2001 From: PaulJonasJost Date: Tue, 9 Apr 2024 15:27:37 +0200 Subject: [PATCH 2/4] Fixed error in correlation matrix, if there were nn values. Also safety checks in general for start indices. Added a test for correlation matrix --- pypesto/visualize/misc.py | 13 ++++++++++--- test/visualize/test_visualize.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pypesto/visualize/misc.py b/pypesto/visualize/misc.py index d23f710b5..c08c3e816 100644 --- a/pypesto/visualize/misc.py +++ b/pypesto/visualize/misc.py @@ -323,7 +323,7 @@ def process_start_indices( start_indices = ALL if isinstance(start_indices, str): if start_indices == ALL: - return np.asarray(range(len(result.optimize_result))) + start_indices = np.asarray(range(len(result.optimize_result))) elif start_indices == ALL_CLUSTERED: clust_ind, clust_size = assign_clusters( delete_nan_inf(result.optimize_result.fval)[1] @@ -336,12 +336,12 @@ def process_start_indices( start_indices = np.concatenate( [np.where(clust_ind == i_clust)[0] for i_clust in clust_gr2] ) - return start_indices + start_indices = start_indices elif start_indices == FIRST_CLUSTER: clust_ind = assign_clusters( delete_nan_inf(result.optimize_result.fval)[1] )[0] - return np.where(clust_ind == 0)[0] + start_indices = np.where(clust_ind == 0)[0] else: raise ValueError( f"Permissible values for start_indices are {ALL}, " @@ -359,6 +359,13 @@ def process_start_indices( if start_index < len(result.optimize_result) ] + # filter out the indices that are not finite + start_indices = [ + start_index + for start_index in start_indices + if np.isfinite(result.optimize_result[start_index].fval) + ] + return np.asarray(start_indices) diff --git a/test/visualize/test_visualize.py b/test/visualize/test_visualize.py index 79aee2f30..e74f22443 100644 --- a/test/visualize/test_visualize.py +++ b/test/visualize/test_visualize.py @@ -1184,3 +1184,15 @@ def test_sacess_history(): ) sacess.minimize(problem) sacess_history(sacess.histories) + + +@pytest.mark.parametrize( + "result_creation", + [create_optimization_result, create_optimization_result_nan_inf], +) +@close_fig +def test_parameters_correlation_matrix(result_creation): + """Test pypesto.visualize.parameters_correlation_matrix""" + result = result_creation() + + visualize.parameters_correlation_matrix(result) From 08045a3755f87bf76645bcdbcc9ea7de16197a49 Mon Sep 17 00:00:00 2001 From: Maren Philipps Date: Wed, 17 Apr 2024 16:22:05 +0200 Subject: [PATCH 3/4] log filtering in process_start_indices --- pypesto/visualize/misc.py | 8 ++++++++ pypesto/visualize/parameters.py | 15 +-------------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/pypesto/visualize/misc.py b/pypesto/visualize/misc.py index c08c3e816..cb1ae29c1 100644 --- a/pypesto/visualize/misc.py +++ b/pypesto/visualize/misc.py @@ -1,3 +1,4 @@ +import logging import warnings from collections.abc import Iterable from numbers import Number @@ -23,6 +24,8 @@ from ..util import assign_clusters, delete_nan_inf from .clust_color import assign_colors_for_list +logger = logging.getLogger(__name__) + def process_result_list( results: Union[Result, list[Result]], colors=None, legends=None @@ -360,11 +363,16 @@ def process_start_indices( ] # filter out the indices that are not finite + start_indices_unfiltered = len(start_indices) start_indices = [ start_index for start_index in start_indices if np.isfinite(result.optimize_result[start_index].fval) ] + if len(start_indices) != start_indices_unfiltered: + logger.warning( + "Some start indices were removed due to inf or nan function values." + ) return np.asarray(start_indices) diff --git a/pypesto/visualize/parameters.py b/pypesto/visualize/parameters.py index a482a94ae..81b10e1e2 100644 --- a/pypesto/visualize/parameters.py +++ b/pypesto/visualize/parameters.py @@ -601,23 +601,10 @@ def optimization_scatter( parameter_indices = process_parameter_indices( parameter_indices=parameter_indices, result=result ) - # remove all start indices that encounter an inf value at the start - # resulting in optimize_result[start]["x"] being None - start_indices_finite = start_indices[ - [ - result.optimize_result[i_start]["x"] is not None - for i_start in start_indices - ] - ] - # compare start_indices with start_indices_finite and log a warning - if len(start_indices) != len(start_indices_finite): - logger.warning( - "Some start indices were removed due to inf values at the start." - ) # put all parameters into a dataframe, where columns are parameters parameters = [ result.optimize_result[i_start]["x"][parameter_indices] - for i_start in start_indices_finite + for i_start in start_indices ] x_labels = [ result.problem.x_names[parameter_index] From 983ab4c315fd8e928c9c73b04e6f2f9c1cd42271 Mon Sep 17 00:00:00 2001 From: Maren Philipps Date: Wed, 17 Apr 2024 16:26:36 +0200 Subject: [PATCH 4/4] update docstring --- pypesto/visualize/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pypesto/visualize/misc.py b/pypesto/visualize/misc.py index cb1ae29c1..5d7c1b491 100644 --- a/pypesto/visualize/misc.py +++ b/pypesto/visualize/misc.py @@ -306,8 +306,8 @@ def process_start_indices( """ Process the start_indices. - Create an array of indices if a number was provided and checks that the - indices do not exceed the max_index. + Create an array of indices if a number was provided, checks that the indices + do not exceed the max_index and removes starts with non-finite fval. Parameters ----------