diff --git a/pybop/_dataset.py b/pybop/_dataset.py
index 120fcb61..35fc9c61 100644
--- a/pybop/_dataset.py
+++ b/pybop/_dataset.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Optional, Union
 
 import numpy as np
 from pybamm import solvers
@@ -15,11 +15,13 @@ class Dataset:
     ----------
     data_dictionary : dict or instance of pybamm.solvers.solution.Solution
         The experimental data to store within the dataset.
+    domain : str, optional
+        The domain of the dataset. Defaults to "Time [s]".
     """
 
-    def __init__(self, data_dictionary):
+    def __init__(self, data_dictionary, domain: Optional[str] = None):
         """
-        Initialize a Dataset instance with data and a set of names.
+        Initialise a Dataset instance with data and a set of names.
         """
 
         if isinstance(data_dictionary, solvers.solution.Solution):
@@ -27,7 +29,7 @@ def __init__(self, data_dictionary):
         if not isinstance(data_dictionary, dict):
             raise TypeError("The input to pybop.Dataset must be a dictionary.")
         self.data = data_dictionary
-        self.names = self.data.keys()
+        self.domain = domain or "Time [s]"
 
     def __repr__(self):
         """
@@ -38,7 +40,7 @@ def __repr__(self):
         str
             A string that includes the type and contents of the dataset.
         """
-        return f"Dataset: {type(self.data)} \n Contains: {self.names}"
+        return f"Dataset: {type(self.data)} \n Contains: {self.data.keys()}"
 
     def __setitem__(self, key, value):
         """
@@ -85,7 +87,7 @@ def check(self, domain: str = None, signal: Union[str, list[str]] = None) -> boo
         Parameters
         ----------
         domain : str, optional
-            The domain of the dataset. Defaults to "Time [s]".
+            If not None, updates the domain of the dataset.
         signal : str or List[str], optional
             The signal(s) to check. Defaults to ["Voltage [V]"].
 
@@ -99,11 +101,11 @@ def check(self, domain: str = None, signal: Union[str, list[str]] = None) -> boo
         ValueError
             If the time series and the data series are not consistent.
         """
-        self.domain = domain or "Time [s]"
+        self.domain = domain or self.domain
         signals = [signal] if isinstance(signal, str) else (signal or ["Voltage [V]"])
 
         # Check that the dataset contains domain and chosen signals
-        missing_attributes = set([self.domain, *signals]) - set(self.names)
+        missing_attributes = set([self.domain, *signals]) - set(self.data.keys())
         if missing_attributes:
             raise ValueError(
                 f"Expected {', '.join(missing_attributes)} in list of dataset"
@@ -143,3 +145,13 @@ def _check_data_consistency(
                 raise ValueError(
                     f"{self.domain} data and {s} data must be the same length."
                 )
+
+    def get_subset(self, index: Union[list, np.ndarray]):
+        """
+        Reduce the dataset to a subset defined by the list of indices.
+        """
+        data = {}
+        for key in self.data.keys():
+            data[key] = self[key][index]
+
+        return Dataset(data, domain=self.domain)
diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py
index bb881429..76cce412 100644
--- a/tests/unit/test_dataset.py
+++ b/tests/unit/test_dataset.py
@@ -42,7 +42,6 @@ def test_dataset(self):
 
         # Test conversion of pybamm solution into dictionary
         assert dataset.data == pybop.Dataset(solution).data
-        assert dataset.names == pybop.Dataset(solution).names
 
         # Test set and get item
         test_current = solution["Current [A]"].data + np.ones_like(
@@ -56,6 +55,10 @@ def test_dataset(self):
         # Test conversion of single signal to list
         assert dataset.check()
 
+        # Test get subset
+        dataset = dataset.get_subset(list(range(5)))
+        assert len(dataset[dataset.domain]) == 5
+
         # Form frequency dataset
         data_dictionary = {
             "Frequency [Hz]": np.linspace(-10, 0, 10),