diff --git a/nam/models/_base.py b/nam/models/_base.py
index e77e797a..40ce84af 100644
--- a/nam/models/_base.py
+++ b/nam/models/_base.py
@@ -54,6 +54,18 @@ def _metadata_loudness_x(cls) -> torch.Tensor:
             )
         )
 
+    @property
+    def device(self) -> Optional[torch.device]:
+        """
+        Convenience property: the device on which the model's parameters live.
+        """
+        # We can do this because the models are tiny and I don't expect a NAM to be
+        # on multiple devices.
+        try:
+            return next(self.parameters()).device
+        except StopIteration:
+            return None
+
     @property
     def sample_rate(self) -> Optional[float]:
         return self._sample_rate.item() if self._has_sample_rate else None
@@ -81,7 +93,7 @@ def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
 
         :param gain: Multiplies input signal
         """
-        x = self._metadata_loudness_x()
+        x = self._metadata_loudness_x().to(self.device)
         y = self._at_nominal_settings(gain * x)
         loudness = torch.sqrt(torch.mean(torch.square(y)))
         if db:
diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py
index e7015c20..b48e41e2 100644
--- a/nam/models/wavenet.py
+++ b/nam/models/wavenet.py
@@ -318,7 +318,7 @@ def export_weights(self) -> np.ndarray:
         weights = torch.cat([layer.export_weights() for layer in self._layers])
         if self._head is not None:
             weights = torch.cat([weights, self._head.export_weights()])
-        weights = torch.cat([weights, torch.Tensor([self._head_scale])])
+        weights = torch.cat([weights.cpu(), torch.Tensor([self._head_scale])])
        return weights.detach().cpu().numpy()
 
     def import_weights(self, weights: torch.Tensor):
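
Reviewer note, not part of the patch: a minimal standalone sketch of the pattern
these hunks add, using a hypothetical toy module in place of a real NAM model.
The `device` property reports where the parameters live (None for a
parameterless module, hence the StopIteration guard), and the input is built on
that device before the forward pass, mirroring the `_metadata_loudness` fix.

    from typing import Optional

    import torch
    import torch.nn as nn


    class _ToyModel(nn.Module):  # hypothetical stand-in for a NAM model
        def __init__(self):
            super().__init__()
            self._linear = nn.Linear(4, 1)

        @property
        def device(self) -> Optional[torch.device]:
            # next() raises StopIteration when the module has no parameters,
            # so a parameterless model reports None instead of crashing.
            try:
                return next(self.parameters()).device
            except StopIteration:
                return None

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self._linear(x)


    model = _ToyModel()
    if torch.cuda.is_available():
        model.cuda()
    # Moving the input to the model's device mirrors the _base.py hunk;
    # Tensor.to(None) is a no-op, so CPU-only/parameterless models still work.
    x = torch.randn(8, 4).to(model.device)
    y = model(x)
    assert y.device == x.device

The wavenet.py hunk is the complementary fix on export: `weights` may live on
the GPU while `torch.Tensor([self._head_scale])` is created on the CPU, and
`torch.cat` refuses to concatenate tensors on different devices, so moving
`weights` to the CPU first keeps `export_weights()` working for GPU-trained
models.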