diff --git a/autokoopman/autokoopman.py b/autokoopman/autokoopman.py
index 9474f69..f8e6f7e 100644
--- a/autokoopman/autokoopman.py
+++ b/autokoopman/autokoopman.py
@@ -377,7 +377,12 @@ def _sanitize_training_data(
     # convert the data to autokoopman trajectories
     if isinstance(training_data, TrajectoriesData):
         if not isinstance(training_data, UniformTimeTrajectoriesData):
+            print(f"resampling trajectories as they need to be uniform time (sampling period {sampling_period})")
             training_data = training_data.interp_uniform_time(sampling_period)
+        else:
+            if not np.isclose(training_data.sampling_period, sampling_period):
+                print(f"resampling trajectories because the sampling periods differ (original {training_data.sampling_period}, new {sampling_period})")
+                training_data = training_data.interp_uniform_time(sampling_period)
     else:
         # figure out how to add inputs
         training_iter = (
diff --git a/autokoopman/core/tuner.py b/autokoopman/core/tuner.py
index 5f00384..3034c18 100644
--- a/autokoopman/core/tuner.py
+++ b/autokoopman/core/tuner.py
@@ -178,13 +178,18 @@ def generate_predictions(
     preds = {}
     # get the predictions
     for k, v in holdout_data._trajs.items():
+        # ugh, this is a hack--sampling period is meaningful if the system is discrete
+        # TODO: make the behavior of sampling period less problematic
+        _kwargs = {}
+        if hasattr(v, "sampling_period"):
+            _kwargs["sampling_period"] = v.sampling_period
         sivp_interp = trained_model.model.solve_ivp(
             v.states[0],
             (np.min(v.times), np.max(v.times)),
             inputs=v.inputs,
             teval=v.times,
+            **_kwargs
         )
-        # sivp_interp = sivp_interp.interp1d(v.times)
         preds[k] = sivp_interp
     return TrajectoriesData(preds)
 