-
Notifications
You must be signed in to change notification settings - Fork 529
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix #3430. This PR sets up the basic support for the array API, and make an example function (`compute_smooth_weight`) to support the array API. I believe NumPy and JAX have supported it (or through `array-api-compat`), so we don't need to write things twice for NumPy and JAX (although we can write them using the ChatGPT, it's still better to maintain only one thing). There are some challeging to use it in the TorchScript, so I give it up. Supporting more function can be implemented in the following PRs. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced testing for `compute_smooth_weight` function using `array_api_strict` for enhanced array operations. - **Chores** - Updated dependencies to include `'array-api-compat'` and `'array-api-strict>=2'` for improved compatibility and testing capabilities. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
7 changed files
with
102 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
"""Utilities for the array API.""" | ||
|
||
|
||
def support_array_api(version: str) -> callable: | ||
"""Mark a function as supporting the specific version of the array API. | ||
Parameters | ||
---------- | ||
version : str | ||
The version of the array API | ||
Returns | ||
------- | ||
callable | ||
The decorated function | ||
Examples | ||
-------- | ||
>>> @support_array_api(version="2022.12") | ||
... def f(x): | ||
... pass | ||
""" | ||
|
||
def set_version(func: callable) -> callable: | ||
func.array_api_version = version | ||
return func | ||
|
||
return set_version |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
"""Test array API compatibility to be completely sure their usage of the array API is portable.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import sys | ||
import unittest | ||
|
||
if sys.version_info >= (3, 9): | ||
import array_api_strict as xp | ||
else: | ||
raise unittest.SkipTest("array_api_strict doesn't support Python<=3.8") | ||
|
||
from deepmd.dpmodel.utils.env_mat import ( | ||
compute_smooth_weight, | ||
) | ||
|
||
from .utils import ( | ||
ArrayAPITest, | ||
) | ||
|
||
|
||
class TestEnvMat(unittest.TestCase, ArrayAPITest): | ||
def test_compute_smooth_weight(self): | ||
self.set_array_api_version(compute_smooth_weight) | ||
d = xp.arange(10, dtype=xp.float64) | ||
w = compute_smooth_weight( | ||
d, | ||
4.0, | ||
6.0, | ||
) | ||
self.assert_namespace_equal(w, d) | ||
self.assert_device_equal(w, d) | ||
self.assert_dtype_equal(w, d) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import array_api_compat | ||
from array_api_strict import ( | ||
set_array_api_strict_flags, | ||
) | ||
|
||
|
||
class ArrayAPITest: | ||
"""Utils for array API tests.""" | ||
|
||
def set_array_api_version(self, func): | ||
"""Set the array API version for a function.""" | ||
set_array_api_strict_flags(api_version=func.array_api_version) | ||
|
||
def assert_namespace_equal(self, a, b): | ||
"""Assert two array has the same namespace.""" | ||
self.assertEqual( | ||
array_api_compat.array_namespace(a), array_api_compat.array_namespace(b) | ||
) | ||
|
||
def assert_dtype_equal(self, a, b): | ||
"""Assert two array has the same dtype.""" | ||
self.assertEqual(a.dtype, b.dtype) | ||
|
||
def assert_device_equal(self, a, b): | ||
"""Assert two array has the same device.""" | ||
self.assertEqual(array_api_compat.device(a), array_api_compat.device(b)) |