Skip to content

Commit

Permalink
Merge pull request #306 from vpratz/fix-adapter-serialization
Browse files Browse the repository at this point in the history
Fix: make transforms AsSet and AsTimeSeries serializable
  • Loading branch information
vpratz authored Feb 11, 2025
2 parents 9071be4 + 0a4921b commit 5da3759
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
7 changes: 7 additions & 0 deletions bayesflow/adapters/transforms/as_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.squeeze(data, axis=2)

return data

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "AsSet":
return cls()

def get_config(self) -> dict:
return {}
7 changes: 7 additions & 0 deletions bayesflow/adapters/transforms/as_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.squeeze(data, axis=2)

return data

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "AsTimeSeries":
return cls()

def get_config(self) -> dict:
return {}
6 changes: 5 additions & 1 deletion bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ def compute_metrics(
else:
# not pre-configured, resample
x1 = x
x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)
if not self.built:
xz_shape = keras.ops.shape(x1)
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
self.build(xz_shape, conditions_shape)
x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1])

if self.use_optimal_transport:
x1, x0, conditions = optimal_transport(
Expand Down

0 comments on commit 5da3759

Please sign in to comment.