From feb2841438bdfa3042cdf583dded947ff4f1be5b Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jan 2022 10:29:12 +0000 Subject: [PATCH 1/8] #1906 update docstring --- pybamm/solvers/solution.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index 46056ff2f8..c98d5a6a73 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -553,12 +553,19 @@ def save_data(self, filename, variables=None, to_format="pickle", short_names=No - 'pickle' (default): creates a pickle file with the data dictionary - 'matlab': creates a .mat file, for loading in matlab - 'csv': creates a csv file (0D variables only) + - 'json': creates a json file short_names : dict, optional Dictionary of shortened names to use when saving. This may be necessary when saving to MATLAB, since no spaces or special characters are allowed in MATLAB variable names. Note that not all the variables need to be given a short name. + Returns + ------- + data : str, optional + if 'csv' or 'json' is chosen, then this string is returned, otherwise None + + """ if variables is None: # variables not explicitly provided -> save all variables that have been From 58446f58a554c2b0525a8a81b00ed18d0ef9f17d Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jan 2022 11:15:22 +0000 Subject: [PATCH 2/8] #1906 add test for json solution saver --- tests/unit/test_solvers/test_solution.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 7298de45e6..d8ada301be 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -1,6 +1,7 @@ # # Tests for the Solution class # +import json import pybamm import unittest import numpy as np @@ -267,12 +268,27 @@ def test_save(self): ): solution.save_data("test.csv", to_format="csv") # only save "c" and "2c" - solution.save_data("test.csv", ["c", "2c"], to_format="csv") + csv_str = solution.save_data("test.csv", ["c", "2c"], to_format="csv") + + # check string is the same as the file + self.assertEqual(csv_str, open('test.csv').read()) + # read csv df = pd.read_csv("test.csv") np.testing.assert_array_almost_equal(df["c"], solution.data["c"]) np.testing.assert_array_almost_equal(df["2c"], solution.data["2c"]) + # to json + json_str = solution.save_data("test.json", to_format="json") + + # check string is the same as the file + self.assertEqual(json_str, open('test.json').read()) + + # check if string has the right values + json_data = json.loads(json_str) + np.testing.assert_array_almost_equal(json_data["c"], solution.data["c"]) + np.testing.assert_array_almost_equal(json_data["d"], solution.data["d"]) + # raise error if format is unknown with self.assertRaisesRegex(ValueError, "format 'wrong_format' not recognised"): solution.save_data("test.csv", to_format="wrong_format") @@ -284,6 +300,7 @@ def test_save(self): np.testing.assert_array_equal(solution["c"].entries, solution_load["c"].entries) np.testing.assert_array_equal(solution["d"].entries, solution_load["d"].entries) + def test_solution_evals_with_inputs(self): model = pybamm.lithium_ion.SPM() geometry = model.default_geometry From 854c79450ddcbfccc2dbc80eee519a320f860c41 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jan 2022 11:30:52 +0000 Subject: [PATCH 3/8] #1906 implement json save_data, and allow function to return a string if needed --- pybamm/solvers/solution.py | 33 +++++++++++++++++++++--- tests/unit/test_solvers/test_solution.py | 20 ++++++++++---- 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index c98d5a6a73..4e86acca51 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -2,6 +2,7 @@ # Solution class # import casadi +import json import numbers import numpy as np import pickle @@ -10,6 +11,18 @@ from scipy.io import savemat +class NumpyEncoder(json.JSONEncoder): + """ + Numpy serialiser helper class that converts numpy arrays to a list + https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable + """ + + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + class Solution(object): """ Class containing the solution of, and various attributes associated with, a PyBaMM @@ -543,7 +556,7 @@ def save_data(self, filename, variables=None, to_format="pickle", short_names=No Parameters ---------- filename : str - The name of the file to save data to + The name of the file to save data to. If None, then a str is returned variables : list, optional List of variables to save. If None, saves all of the variables that have been created so far @@ -563,7 +576,7 @@ def save_data(self, filename, variables=None, to_format="pickle", short_names=No Returns ------- data : str, optional - if 'csv' or 'json' is chosen, then this string is returned, otherwise None + str if 'csv' or 'json' is chosen and filename is None, otherwise None """ @@ -595,9 +608,17 @@ def save_data(self, filename, variables=None, to_format="pickle", short_names=No data_short_names[name] = var if to_format == "pickle": + if filename is None: + raise ValueError( + "pickle format must be written to a file" + ) with open(filename, "wb") as f: pickle.dump(data_short_names, f, pickle.HIGHEST_PROTOCOL) elif to_format == "matlab": + if filename is None: + raise ValueError( + "matlab format must be written to a file" + ) # Check all the variable names only contain a-z, A-Z or _ or numbers for name in data_short_names.keys(): # Check the string only contains the following ASCII: @@ -632,7 +653,13 @@ def save_data(self, filename, variables=None, to_format="pickle", short_names=No ) ) df = pd.DataFrame(data_short_names) - df.to_csv(filename, index=False) + return df.to_csv(filename, index=False) + elif to_format == "json": + if filename is None: + return json.dumps(data_short_names, cls=NumpyEncoder) + else: + with open(filename, "w") as outfile: + json.dump(data_short_names, outfile, cls=NumpyEncoder) else: raise ValueError("format '{}' not recognised".format(to_format)) diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index d8ada301be..0bc32348b6 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -238,9 +238,13 @@ def test_save(self): # test save data with self.assertRaises(ValueError): solution.save_data("test.pickle") + # set variables first then save solution.update(["c", "d"]) + with self.assertRaisesRegex(ValueError, "pickle"): + solution.save_data(None, to_format="pickle") solution.save_data("test.pickle") + data_load = pybamm.load("test.pickle") np.testing.assert_array_equal(solution.data["c"], data_load["c"]) np.testing.assert_array_equal(solution.data["d"], data_load["d"]) @@ -251,6 +255,9 @@ def test_save(self): np.testing.assert_array_equal(solution.data["c"], data_load["c"].flatten()) np.testing.assert_array_equal(solution.data["d"], data_load["d"]) + with self.assertRaisesRegex(ValueError, "matlab"): + solution.save_data(None, to_format="matlab") + # to matlab with bad variables name fails solution.update(["c + d"]) with self.assertRaisesRegex(ValueError, "Invalid character"): @@ -268,10 +275,12 @@ def test_save(self): ): solution.save_data("test.csv", to_format="csv") # only save "c" and "2c" - csv_str = solution.save_data("test.csv", ["c", "2c"], to_format="csv") + solution.save_data("test.csv", ["c", "2c"], to_format="csv") + csv_str = solution.save_data(None, ["c", "2c"], to_format="csv") # check string is the same as the file - self.assertEqual(csv_str, open('test.csv').read()) + with open('test.csv') as f: + self.assertEqual(csv_str, f.read()) # read csv df = pd.read_csv("test.csv") @@ -279,10 +288,12 @@ def test_save(self): np.testing.assert_array_almost_equal(df["2c"], solution.data["2c"]) # to json - json_str = solution.save_data("test.json", to_format="json") + solution.save_data("test.json", to_format="json") + json_str = solution.save_data(None, to_format="json") # check string is the same as the file - self.assertEqual(json_str, open('test.json').read()) + with open('test.json') as f: + self.assertEqual(json_str, f.read()) # check if string has the right values json_data = json.loads(json_str) @@ -300,7 +311,6 @@ def test_save(self): np.testing.assert_array_equal(solution["c"].entries, solution_load["c"].entries) np.testing.assert_array_equal(solution["d"].entries, solution_load["d"].entries) - def test_solution_evals_with_inputs(self): model = pybamm.lithium_ion.SPM() geometry = model.default_geometry From 34c95f43ddcd6fcb55a9de2cd16a1ed49e410152 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jan 2022 11:34:17 +0000 Subject: [PATCH 4/8] #1906 make filename arg of save_data optional --- pybamm/solvers/solution.py | 4 ++-- tests/unit/test_solvers/test_solution.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index 4e86acca51..11ae3d81cf 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -549,13 +549,13 @@ def save(self, filename): with open(filename, "wb") as f: pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) - def save_data(self, filename, variables=None, to_format="pickle", short_names=None): + def save_data(self, filename=None, variables=None, to_format="pickle", short_names=None): """ Save solution data only (raw arrays) Parameters ---------- - filename : str + filename : str, optional The name of the file to save data to. If None, then a str is returned variables : list, optional List of variables to save. If None, saves all of the variables that have diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 0bc32348b6..071e9fcbb5 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -242,7 +242,7 @@ def test_save(self): # set variables first then save solution.update(["c", "d"]) with self.assertRaisesRegex(ValueError, "pickle"): - solution.save_data(None, to_format="pickle") + solution.save_data(to_format="pickle") solution.save_data("test.pickle") data_load = pybamm.load("test.pickle") @@ -256,7 +256,7 @@ def test_save(self): np.testing.assert_array_equal(solution.data["d"], data_load["d"]) with self.assertRaisesRegex(ValueError, "matlab"): - solution.save_data(None, to_format="matlab") + solution.save_data(to_format="matlab") # to matlab with bad variables name fails solution.update(["c + d"]) @@ -276,7 +276,7 @@ def test_save(self): solution.save_data("test.csv", to_format="csv") # only save "c" and "2c" solution.save_data("test.csv", ["c", "2c"], to_format="csv") - csv_str = solution.save_data(None, ["c", "2c"], to_format="csv") + csv_str = solution.save_data(variables=["c", "2c"], to_format="csv") # check string is the same as the file with open('test.csv') as f: @@ -289,7 +289,7 @@ def test_save(self): # to json solution.save_data("test.json", to_format="json") - json_str = solution.save_data(None, to_format="json") + json_str = solution.save_data(to_format="json") # check string is the same as the file with open('test.json') as f: From 3ad9e548dd90cab5cf5c7d89a4527bb21fb5129a Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jan 2022 11:37:28 +0000 Subject: [PATCH 5/8] #1906 add to changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 56b62db25a..7f66ed40b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Features - Added an option to force install compatible versions of jax and jaxlib if already installed using CLI ([#1881](https://github.com/pybamm-team/PyBaMM/pull/1881)) +- Allow pybamm.Solution.save_data() to return a string if filename is None, and added json to_format option ([#1909](https://github.com/pybamm-team/PyBaMM/pull/1909) ## Bug fixes From e4858350364a3b2d364830e9474ca46f3c8bcc09 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jan 2022 11:41:28 +0000 Subject: [PATCH 6/8] #1906 flake8 --- pybamm/solvers/solution.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index 11ae3d81cf..bcf95efc8d 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -549,7 +549,10 @@ def save(self, filename): with open(filename, "wb") as f: pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) - def save_data(self, filename=None, variables=None, to_format="pickle", short_names=None): + def save_data( + self, filename=None, variables=None, + to_format="pickle", short_names=None + ): """ Save solution data only (raw arrays) From a705a0352e80bfcb8cfc87b7d551265276c39029 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jan 2022 13:42:16 +0000 Subject: [PATCH 7/8] #1906 fix for windows --- tests/unit/test_solvers/test_solution.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 071e9fcbb5..244c70bd35 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -280,7 +280,10 @@ def test_save(self): # check string is the same as the file with open('test.csv') as f: - self.assertEqual(csv_str, f.read()) + # need to strip \r chars for windows + self.assertEqual( + csv_str.replace('\r', ''), f.read() + ) # read csv df = pd.read_csv("test.csv") @@ -293,7 +296,10 @@ def test_save(self): # check string is the same as the file with open('test.json') as f: - self.assertEqual(json_str, f.read()) + # need to strip \r chars for windows + self.assertEqual( + json_str.replace('\r', ''), f.read() + ) # check if string has the right values json_data = json.loads(json_str) From fb7f727458de78fdca70b1ec418c7756970998db Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 21 Jan 2022 13:47:20 +0000 Subject: [PATCH 8/8] #1906 coverage for json encoder class --- pybamm/solvers/solution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index bcf95efc8d..9ec9a552b1 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -20,7 +20,8 @@ class NumpyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): return obj.tolist() - return json.JSONEncoder.default(self, obj) + # won't be called since we only need to convert numpy arrays + return json.JSONEncoder.default(self, obj) # pragma: no cover class Solution(object):