Register get_cpu_capability for jit
Context: In torchvision we ensure that functional ops are torchscriptable.
`torch.backends.cpu.get_cpu_capability()`, recently exposed in pytorch#100164, is failing in torchvision CI:
```
RuntimeError:
Python builtin <built-in function _get_cpu_capability> is currently not supported in Torchscript:
  File "/usr/local/lib/python3.10/dist-packages/torch/backends/cpu/__init__.py", line 17
    - "AVX512"
    """
    return torch._C._get_cpu_capability()
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
Ref: pytorch/vision#7557

In this PR, `torch._C._get_cpu_capability()` is explicitly registered for JIT and tested.
vfdev-5 committed May 8, 2023
1 parent d9d98b4 commit 19d77cb
Showing 3 changed files with 10 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/test_torch.py
@@ -7445,8 +7445,13 @@ def test_parallel_info(self):
torch.__config__.parallel_info()

def test_get_cpu_capability(self):
# This method is primarily exposed for torchvision's resize
torch.backends.cpu.get_cpu_capability()

# We have to ensure that the method is torchscriptable, as torchvision's resize
# should be torchscriptable
torch.jit.script(torch.backends.cpu.get_cpu_capability)

@slowTest
def test_slow_test(self):
# Just a smoketest to make sure our slowTest decorator works.
4 changes: 4 additions & 0 deletions torch/csrc/jit/runtime/register_special_ops.cpp
@@ -457,6 +457,10 @@ RegisterOperators reg({
"aten::set_grad_enabled(bool val) -> ()",
[](Stack& stack) { torch::GradMode::set_enabled(pop(stack).toBool()); },
aliasAnalysisConservative()),
Operator(
"aten::_get_cpu_capability() -> str",
[](Stack& stack) { push(stack, at::get_cpu_capability()); },
aliasAnalysisConservative()),
});
} // namespace
} // namespace jit
1 change: 1 addition & 0 deletions torch/jit/_builtins.py
@@ -95,6 +95,7 @@
(torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
(torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
(torch._C._get_tracing_state, "aten::_get_tracing_state"),
(torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
(warnings.warn, "aten::warn"),
(torch._VF.stft, "aten::stft"), # type: ignore[attr-defined]
(torch._VF.istft, "aten::istft"), # type: ignore[attr-defined]