Skip to content

Commit

Permalink
Add Dataset.get_subset (#624)
Browse files Browse the repository at this point in the history
* Add dataset.get_subset

* Update domain setting
  • Loading branch information
NicolaCourtier authored Jan 14, 2025
1 parent a9c4839 commit b550e3e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
28 changes: 20 additions & 8 deletions pybop/_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Optional, Union

import numpy as np
from pybamm import solvers
Expand All @@ -15,19 +15,21 @@ class Dataset:
----------
data_dictionary : dict or instance of pybamm.solvers.solution.Solution
The experimental data to store within the dataset.
domain : str, optional
The domain of the dataset. Defaults to "Time [s]".
"""

def __init__(self, data_dictionary):
def __init__(self, data_dictionary, domain: Optional[str] = None):
"""
Initialize a Dataset instance with data and a set of names.
Initialise a Dataset instance with data and a set of names.
"""

if isinstance(data_dictionary, solvers.solution.Solution):
data_dictionary = data_dictionary.get_data_dict()
if not isinstance(data_dictionary, dict):
raise TypeError("The input to pybop.Dataset must be a dictionary.")
self.data = data_dictionary
self.names = self.data.keys()
self.domain = domain or "Time [s]"

def __repr__(self):
"""
Expand All @@ -38,7 +40,7 @@ def __repr__(self):
str
A string that includes the type and contents of the dataset.
"""
return f"Dataset: {type(self.data)} \n Contains: {self.names}"
return f"Dataset: {type(self.data)} \n Contains: {self.data.keys()}"

def __setitem__(self, key, value):
"""
Expand Down Expand Up @@ -85,7 +87,7 @@ def check(self, domain: str = None, signal: Union[str, list[str]] = None) -> boo
Parameters
----------
domain : str, optional
The domain of the dataset. Defaults to "Time [s]".
If not None, updates the domain of the dataset.
signal : str or List[str], optional
The signal(s) to check. Defaults to ["Voltage [V]"].
Expand All @@ -99,11 +101,11 @@ def check(self, domain: str = None, signal: Union[str, list[str]] = None) -> boo
ValueError
If the time series and the data series are not consistent.
"""
self.domain = domain or "Time [s]"
self.domain = domain or self.domain
signals = [signal] if isinstance(signal, str) else (signal or ["Voltage [V]"])

# Check that the dataset contains domain and chosen signals
missing_attributes = set([self.domain, *signals]) - set(self.names)
missing_attributes = set([self.domain, *signals]) - set(self.data.keys())
if missing_attributes:
raise ValueError(
f"Expected {', '.join(missing_attributes)} in list of dataset"
Expand Down Expand Up @@ -143,3 +145,13 @@ def _check_data_consistency(
raise ValueError(
f"{self.domain} data and {s} data must be the same length."
)

def get_subset(self, index: Union[list, np.ndarray]):
"""
Reduce the dataset to a subset defined by the list of indices.
"""
data = {}
for key in self.data.keys():
data[key] = self[key][index]

return Dataset(data, domain=self.domain)
5 changes: 4 additions & 1 deletion tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def test_dataset(self):

# Test conversion of pybamm solution into dictionary
assert dataset.data == pybop.Dataset(solution).data
assert dataset.names == pybop.Dataset(solution).names

# Test set and get item
test_current = solution["Current [A]"].data + np.ones_like(
Expand All @@ -56,6 +55,10 @@ def test_dataset(self):
# Test conversion of single signal to list
assert dataset.check()

# Test get subset
dataset = dataset.get_subset(list(range(5)))
assert len(dataset[dataset.domain]) == 5

# Form frequency dataset
data_dictionary = {
"Frequency [Hz]": np.linspace(-10, 0, 10),
Expand Down

0 comments on commit b550e3e

Please sign in to comment.