You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
In [1]: from photosynthesis_metrics import MultiScaleGMSDLoss
In [2]: import torch
In [3]: loss = MultiScaleGMSDLoss(chromatic=True)
In [4]: x = torch.rand(1, 3, 256, 256)
In [5]: y = torch.rand(1, 3, 256, 256)
In [6]: loss(x, y)
/home/ostyakov/.local/lib/python3.6/site-packages/photosynthesis_metrics/utils.py:58: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
assert (torch.tensor(scale_weights).dim() == 1), \
Out[6]: tensor(0.1725)
In [7]: loss(x.cuda(), y.cuda())
/home/ostyakov/.local/lib/python3.6/site-packages/photosynthesis_metrics/utils.py:58: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
assert (torch.tensor(scale_weights).dim() == 1), \
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-7-d464c6bd32e6> in <module>()
----> 1 loss(x.cuda(), y.cuda())
/home/ostyakov/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
/home/ostyakov/.local/lib/python3.6/site-packages/photosynthesis_metrics/gmsd.py in forward(self, prediction, target)
208 prediction, target = _adjust_dimensions(input_tensors=(prediction, target))
209
--> 210 return self.compute_metric(prediction, target)
211
212 def compute_metric(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
/home/ostyakov/.local/lib/python3.6/site-packages/photosynthesis_metrics/gmsd.py in compute_metric(self, prediction, target)
248 # Convert to YIQ color space https://en.wikipedia.org/wiki/YIQ
249 iq_weights = torch.tensor([[0.5959, -0.2746, -0.3213], [0.2115, -0.5227, 0.3112]]).t()
--> 250 prediction_iq = torch.matmul(prediction.permute(0, 2, 3, 1), iq_weights).permute(0, 3, 1, 2)
251 target_iq = torch.matmul(target.permute(0, 2, 3, 1), iq_weights).permute(0, 3, 1, 2)
252
RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mat2' in call to _th_mm
The text was updated successfully, but these errors were encountered: