Skip to content

Commit 3ec1704

Browse files
Minor updates to plotting, dataset and interpolant (#915)
* Align dict name with pybamm.Solution * Update axis labels * Allow len(dataset) * Update CHANGELOG.md * Add kind option to interpolant * Update CHANGELOG.md
1 parent 5ecf55c commit 3ec1704

9 files changed

Lines changed: 24 additions & 12 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
## Bug Fixes
88

9+
- [#915](https://github.com/pybop-team/PyBOP/pull/915) - Fixes axis labels for non-standard domain names, adds `Dataset` length property and adds `kind` property to `Interpolant`.
910
- [#911](https://github.com/pybop-team/PyBOP/pull/911) - Fixes the passing of the cost log to the Voronoi surface plot.
1011
- [#905](https://github.com/pybop-team/PyBOP/pull/905) - Remove restriction on numpy.
1112

examples/scripts/battery_parameterisation/gitt_fitting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
# Determine the indices corresponding to each pulse in the dataset
3737
nonzero_index = np.concatenate(
38-
([-1], np.flatnonzero(dataset["Current [A]"]), [len(dataset["Current [A]"]) + 1])
38+
([-1], np.flatnonzero(dataset["Current [A]"]), [len(dataset) + 1])
3939
)
4040
pulse_starts = np.extract(
4141
nonzero_index[1:] - nonzero_index[:-1] != 1, # check if there is a gap

pybop/models/lithium_ion/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class Interpolant:
2929
Output values corresponding to x.
3030
name : str, optional
3131
Name for the interpolant when used in PyBaMM.
32+
kind: str, optional
33+
Which kind of interpolator to use. Can be "linear" (default) or "cubic".
3234
bounds_error : bool, optional
3335
If True, raise error when interpolating outside bounds.
3436
fill_value : str or float, optional
@@ -42,13 +44,15 @@ def __init__(
4244
x: np.ndarray,
4345
y: np.ndarray,
4446
name: str | None = None,
47+
kind: str | None = None,
4548
bounds_error: bool = False,
4649
fill_value: str | float = "extrapolate",
4750
axis: int = 0,
4851
):
4952
self.x = np.asarray(x)
5053
self.y = np.asarray(y)
5154
self.name = name
55+
self.kind = kind or "linear"
5256
self._interp_func = self._create_interpolant(bounds_error, fill_value, axis)
5357

5458
def _create_interpolant(
@@ -58,6 +62,7 @@ def _create_interpolant(
5862
return interpolate.interp1d(
5963
self.x,
6064
self.y,
65+
kind=self.kind,
6166
bounds_error=bounds_error,
6267
fill_value=fill_value,
6368
axis=axis,
@@ -82,7 +87,9 @@ def __call__(self, x: float | np.ndarray):
8287
return self._interp_func(x)
8388
except Exception:
8489
# Fall back to PyBaMM interpolant for symbolic evaluation
85-
return PybammInterpolant(self.x, self.y, x, name=self.name)
90+
return PybammInterpolant(
91+
self.x, self.y, x, name=self.name, interpolator=self.kind
92+
)
8693

8794

8895
class InverseOCV:

pybop/plot/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def dataset(dataset, signal=None, trace_names=None, show=True, **layout_kwargs):
3434
# Compile ydata and labels or legend
3535
y = [dataset[s] for s in signal]
3636
if len(signal) == 1:
37-
yaxis_title = signal[0]
37+
yaxis_title = StandardPlot.remove_brackets(signal[0])
3838
if trace_names is None:
3939
trace_names = ["Data"]
4040
else:

pybop/plot/problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def problem(
7373
plot_dict = StandardPlot(
7474
layout_options=dict(
7575
title="Scatter Plot",
76-
xaxis_title="Time / s",
76+
xaxis_title=StandardPlot.remove_brackets(domain),
7777
yaxis_title=StandardPlot.remove_brackets(var),
7878
)
7979
)

pybop/processing/dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def __getitem__(self, key):
6666

6767
return self.data[key]
6868

69+
def __len__(self) -> int:
70+
"""Return the length of the data, based on the length of the domain data."""
71+
return len(self.data[self.domain])
72+
6973
def check(self, domain: str = None, signal: str | list[str] = None) -> bool:
7074
"""
7175
Check the consistency of a PyBOP Dataset against the expected format.

pybop/simulators/solution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Solution:
2121
"""
2222

2323
def __init__(self, inputs: Inputs = None):
24-
self._dict = {}
24+
self._variables = {}
2525
self.all_inputs = [inputs] if inputs is not None else []
2626

2727
def set_solution_variable(
@@ -30,9 +30,9 @@ def set_solution_variable(
3030
data: np.ndarray,
3131
sensitivities: dict[str, np.ndarray] | None = None,
3232
):
33-
self._dict[variable_name] = SolutionVariable(
33+
self._variables[variable_name] = SolutionVariable(
3434
data=data, sensitivities=sensitivities
3535
)
3636

3737
def __getitem__(self, key):
38-
return self._dict[key]
38+
return self._variables[key]

tests/unit/test_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_dataset(self):
6363

6464
# Test get subset
6565
dataset = dataset.get_subset(list(range(5)))
66-
assert len(dataset[dataset.domain]) == 5
66+
assert len(dataset) == 5
6767

6868
# Form frequency dataset
6969
data_dictionary = {
@@ -259,7 +259,7 @@ def dataset(self):
259259
def test_current_data_processing(self, dataset):
260260
# Test generation of a current consistent with the charge throughput data
261261
consistent_dataset = pybop.generate_consistent_current(dataset, tolerance=1e-2)
262-
assert len(consistent_dataset["Time [s]"]) >= len(dataset["Time [s]"])
262+
assert len(consistent_dataset) >= len(dataset)
263263

264264
for var in ["Time [s]", "Current [A]", "Discharge capacity [A.h]"]:
265265
assert consistent_dataset[var][0] == dataset[var][0]
@@ -275,7 +275,7 @@ def test_current_data_processing(self, dataset):
275275

276276
# Test downsampling of constant current sections
277277
downsampled_dataset = pybop.downsample_constant_current(dataset, tolerance=1e-4)
278-
assert len(downsampled_dataset["Time [s]"]) < len(dataset["Time [s]"])
278+
assert len(downsampled_dataset) < len(dataset)
279279

280280
for var in ["Time [s]", "Current [A]", "Discharge capacity [A.h]"]:
281281
assert downsampled_dataset[var][0] == dataset[var][0]
@@ -302,7 +302,7 @@ def test_current_data_processing(self, dataset):
302302
}
303303
)
304304
ds_dataset = pybop.downsample_constant_current(dataset_wo_ct, tolerance=1e-4)
305-
assert len(ds_dataset["Time [s]"]) < len(dataset["Time [s]"])
305+
assert len(ds_dataset) < len(dataset)
306306

307307
for var in ["Time [s]", "Current [A]"]:
308308
assert ds_dataset[var][0] == dataset_wo_ct[var][0]

tests/unit/test_problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_fitting_problem(self, simulator, dataset):
106106
assert_array_equal(target_data, dataset["Voltage [V]"])
107107

108108
# Test set target
109-
dataset["Voltage [V]"] += np.random.normal(0, 0.05, len(dataset["Voltage [V]"]))
109+
dataset["Voltage [V]"] += np.random.normal(0, 0.05, len(dataset))
110110
cost.set_target("Voltage [V]", dataset)
111111
problem = pybop.Problem(simulator, cost)
112112

0 commit comments

Comments
 (0)