1
1
# Copyright (c) 2024, NVIDIA CORPORATION.
2
2
3
3
import os .path
4
+ import platform
4
5
from contextlib import contextmanager
5
6
from textwrap import dedent
6
7
from unittest .mock import Mock , patch
10
11
from rapids_build_backend .impls import (
11
12
_check_setup_py ,
12
13
_edit_pyproject ,
14
+ _get_arch ,
13
15
_get_cuda_suffix ,
14
16
_remove_rapidsai_from_config ,
15
17
_write_git_commits ,
@@ -88,8 +90,9 @@ def test_write_git_commits(
88
90
"cuda_version" ,
89
91
"cuda_suffix" ,
90
92
"cuda_python_requirement" ,
91
- "matrix " ,
93
+ "arch " ,
92
94
"arch_requirement" ,
95
+ "matrix" ,
93
96
],
94
97
[
95
98
(
@@ -100,8 +103,9 @@ def test_write_git_commits(
100
103
("11" , "5" ),
101
104
"-cu11" ,
102
105
"cuda-python>=11.5,<11.6.dev0" ,
103
- "" ,
106
+ "x86_64 " ,
104
107
"some-x86-package" ,
108
+ "" ,
105
109
),
106
110
(
107
111
"." ,
@@ -111,8 +115,9 @@ def test_write_git_commits(
111
115
("11" , "5" ),
112
116
"-cu11" ,
113
117
"cuda-python>=11.5,<11.6.dev0" ,
114
- "arch= aarch64" ,
118
+ "aarch64" ,
115
119
"some-arm-package" ,
120
+ "" ,
116
121
),
117
122
(
118
123
"python" ,
@@ -122,8 +127,9 @@ def test_write_git_commits(
122
127
("12" , "1" ),
123
128
"-cu12" ,
124
129
"cuda-python>=12.1,<12.2.dev0" ,
125
- "" ,
130
+ "x86_64 " ,
126
131
"some-x86-package" ,
132
+ "" ,
127
133
),
128
134
(
129
135
"." ,
@@ -133,8 +139,11 @@ def test_write_git_commits(
133
139
("11" , "5" ),
134
140
"-cu11" ,
135
141
None ,
142
+ None , # Test the arch detection logic
143
+ "some-x86-package"
144
+ if platform .machine () == "x86_64"
145
+ else "some-arm-package" ,
136
146
"" ,
137
- None ,
138
147
),
139
148
(
140
149
"." ,
@@ -144,8 +153,11 @@ def test_write_git_commits(
144
153
None , # Ensure _get_cuda_version() isn't called and unpacked
145
154
"" ,
146
155
"cuda-python" ,
156
+ None , # Test the arch detection logic
157
+ "some-x86-package"
158
+ if platform .machine () == "x86_64"
159
+ else "some-arm-package" ,
147
160
"" ,
148
- "some-x86-package" ,
149
161
),
150
162
],
151
163
)
@@ -158,8 +170,9 @@ def test_edit_pyproject(
158
170
cuda_version ,
159
171
cuda_suffix ,
160
172
cuda_python_requirement ,
161
- matrix ,
173
+ arch ,
162
174
arch_requirement ,
175
+ matrix ,
163
176
):
164
177
original_contents = dedent (
165
178
"""\
@@ -265,6 +278,9 @@ def test_edit_pyproject(
265
278
)
266
279
267
280
with patch (
281
+ "rapids_build_backend.impls._get_arch" ,
282
+ Mock (return_value = arch ) if arch is not None else _get_arch ,
283
+ ), patch (
268
284
"rapids_build_backend.impls._get_cuda_version" ,
269
285
Mock (return_value = cuda_version ),
270
286
), patch (
0 commit comments