Skip to content

Commit 2bf78f1

Browse files
committed
Pass current machine's architecture to rapids-dependency-file-generator.
1 parent 340e9a5 commit 2bf78f1

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

rapids_build_backend/impls.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) 2024, NVIDIA CORPORATION.
22

33
import os
4+
import platform
45
import re
56
import shutil
67
import subprocess
@@ -51,6 +52,18 @@ def _get_backend(build_backend):
5152
)
5253

5354

55+
@lru_cache
56+
def _get_arch():
57+
"""Get the arch of the current machine.
58+
59+
Returns
60+
-------
61+
str
62+
The arch (e.g. "x86_64" or "aarch64")
63+
"""
64+
return platform.machine()
65+
66+
5467
@lru_cache
5568
def _get_cuda_version():
5669
"""Get the CUDA suffix based on nvcc.
@@ -190,6 +203,7 @@ def _edit_pyproject(config):
190203
matrix = _parse_matrix(config.matrix_entry) or dict(file_config.matrix)
191204
if not config.disable_cuda:
192205
matrix["cuda"] = [f"{cuda_version_major}.{cuda_version_minor}"]
206+
matrix["arch"] = [_get_arch()]
193207
rapids_dependency_file_generator.make_dependency_files(
194208
parsed_config=parsed_config,
195209
file_keys=[file_key],

tests/test_impls.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) 2024, NVIDIA CORPORATION.
22

33
import os.path
4+
import platform
45
from contextlib import contextmanager
56
from textwrap import dedent
67
from unittest.mock import Mock, patch
@@ -10,6 +11,7 @@
1011
from rapids_build_backend.impls import (
1112
_check_setup_py,
1213
_edit_pyproject,
14+
_get_arch,
1315
_get_cuda_suffix,
1416
_remove_rapidsai_from_config,
1517
_write_git_commits,
@@ -88,8 +90,9 @@ def test_write_git_commits(
8890
"cuda_version",
8991
"cuda_suffix",
9092
"cuda_python_requirement",
91-
"matrix",
93+
"arch",
9294
"arch_requirement",
95+
"matrix",
9396
],
9497
[
9598
(
@@ -100,8 +103,9 @@ def test_write_git_commits(
100103
("11", "5"),
101104
"-cu11",
102105
"cuda-python>=11.5,<11.6.dev0",
103-
"",
106+
"x86_64",
104107
"some-x86-package",
108+
"",
105109
),
106110
(
107111
".",
@@ -111,8 +115,9 @@ def test_write_git_commits(
111115
("11", "5"),
112116
"-cu11",
113117
"cuda-python>=11.5,<11.6.dev0",
114-
"arch=aarch64",
118+
"aarch64",
115119
"some-arm-package",
120+
"",
116121
),
117122
(
118123
"python",
@@ -122,8 +127,9 @@ def test_write_git_commits(
122127
("12", "1"),
123128
"-cu12",
124129
"cuda-python>=12.1,<12.2.dev0",
125-
"",
130+
"x86_64",
126131
"some-x86-package",
132+
"",
127133
),
128134
(
129135
".",
@@ -133,8 +139,11 @@ def test_write_git_commits(
133139
("11", "5"),
134140
"-cu11",
135141
None,
142+
None, # Test the arch detection logic
143+
"some-x86-package"
144+
if platform.machine() == "x86_64"
145+
else "some-arm-package",
136146
"",
137-
None,
138147
),
139148
(
140149
".",
@@ -144,8 +153,11 @@ def test_write_git_commits(
144153
None, # Ensure _get_cuda_version() isn't called and unpacked
145154
"",
146155
"cuda-python",
156+
None, # Test the arch detection logic
157+
"some-x86-package"
158+
if platform.machine() == "x86_64"
159+
else "some-arm-package",
147160
"",
148-
"some-x86-package",
149161
),
150162
],
151163
)
@@ -158,8 +170,9 @@ def test_edit_pyproject(
158170
cuda_version,
159171
cuda_suffix,
160172
cuda_python_requirement,
161-
matrix,
173+
arch,
162174
arch_requirement,
175+
matrix,
163176
):
164177
original_contents = dedent(
165178
"""\
@@ -265,6 +278,9 @@ def test_edit_pyproject(
265278
)
266279

267280
with patch(
281+
"rapids_build_backend.impls._get_arch",
282+
Mock(return_value=arch) if arch is not None else _get_arch,
283+
), patch(
268284
"rapids_build_backend.impls._get_cuda_version",
269285
Mock(return_value=cuda_version),
270286
), patch(

0 commit comments

Comments
 (0)