Skip to content

Commit

Permalink
Improve warning message when catching PyTorch MPS convolution bug (#522)
Browse files Browse the repository at this point in the history
Better warning message
  • Loading branch information
sdatkinson authored Dec 19, 2024
1 parent 4fb6408 commit 52640d6
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions nam/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]:
)


def _get_torch_version() -> str:
return _torch.__version__


class BaseNet(_Base):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -217,17 +221,25 @@ def _forward_mps_safe(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor:
return self._forward(x, **kwargs)
except NotImplementedError as e:
if "Output channels > 65536 not supported at the MPS device." in str(e):
print(
"===WARNING===\n"
"NAM encountered a bug in PyTorch's MPS backend and will "
"switch to a fallback.\n"
f"Your version of PyTorch is {_torch.__version__}.\n"
"Please report this in an Issue at:\n"
"https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose"
"\n"
"so that NAM's dependencies can avoid buggy versions of "
"PyTorch and the associated performance hit."
msg = (
"Warning: NAM encountered a bug in PyTorch's MPS backend and "
"will switch to a fallback."
)
known_bad_versions = {"2.5.0", "2.5.1"}
torch_version = _get_torch_version()
if torch_version not in known_bad_versions:
msg += (
"\n"
f"Your version of PyTorch is {torch_version}, which "
"wasn't known to have this problem.\n"
"Please open an Issue at:\n"
"https://github.com/sdatkinson/neural-amp-modeler/issues/507"
"\n"
f"and report your PyTorch version ({torch_version}) "
"so that we can keep track of versions of PyTorch that "
"might be avoided."
)
print(msg)
self._mps_65536_fallback = True
return self._forward_mps_safe(x, **kwargs)
else:
Expand Down

0 comments on commit 52640d6

Please sign in to comment.