Skip to content

Commit 63e1a23

Browse files
Use new DFG API
1 parent afc92fb commit 63e1a23

File tree

3 files changed

+39
-33
lines changed

3 files changed

+39
-33
lines changed

rapids_build_backend/impls.py

+37-31
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,21 @@
88
from functools import lru_cache
99
from importlib import import_module
1010

11+
import rapids_dependency_file_generator
1112
import tomli_w
12-
import yaml
13-
from rapids_dependency_file_generator.cli import generate_matrix
14-
from rapids_dependency_file_generator.constants import default_pyproject_dir
15-
from rapids_dependency_file_generator.rapids_dependency_file_generator import (
16-
get_requested_output_types,
17-
make_dependency_files,
18-
)
1913

2014
from . import utils
2115
from .config import Config
2216

2317

18+
def _parse_matrix(matrix):
19+
if not matrix:
20+
return None
21+
return {
22+
key: [value] for key, value in (item.split("=") for item in matrix.split(";"))
23+
}
24+
25+
2426
@lru_cache
2527
def _get_backend(build_backend):
2628
"""Get the wrapped build backend specified in pyproject.toml."""
@@ -164,33 +166,37 @@ def _edit_pyproject(config):
164166

165167
cuda_version = _get_cuda_version(config.require_cuda)
166168

167-
with open(config.dependencies_file) as f:
168-
parsed_config = yaml.load(f, Loader=yaml.FullLoader)
169-
files = {}
170-
for file_key, file_config in parsed_config["files"].items():
171-
if "pyproject" not in get_requested_output_types(file_config["output"]):
172-
continue
173-
pyproject_dir = os.path.join(
174-
os.path.dirname(config.dependencies_file),
175-
file_config.get("pyproject_dir", default_pyproject_dir),
176-
)
177-
if not os.path.exists(pyproject_dir):
178-
continue
179-
if not os.path.samefile(pyproject_dir, "."):
180-
continue
181-
file_config["output"] = ["pyproject"]
182-
if config.matrix:
183-
file_config["matrix"] = generate_matrix(config.matrix)
184-
if cuda_version is not None:
185-
file_config.setdefault("matrix", {})["cuda"] = [
186-
f"{cuda_version[0]}.{cuda_version[1]}"
187-
]
188-
files[file_key] = file_config
189-
parsed_config["files"] = files
169+
parsed_config = rapids_dependency_file_generator.load_config_from_file(
170+
config.dependencies_file
171+
)
190172

191173
try:
192174
shutil.copyfile(pyproject_file, bkp_pyproject_file)
193-
make_dependency_files(parsed_config, config.dependencies_file, False)
175+
for file_key, file_config in parsed_config.files.items():
176+
if (
177+
rapids_dependency_file_generator.Output.PYPROJECT
178+
not in file_config.output
179+
):
180+
continue
181+
pyproject_dir = os.path.join(
182+
os.path.dirname(config.dependencies_file),
183+
file_config.pyproject_dir,
184+
)
185+
if not os.path.exists(pyproject_dir):
186+
continue
187+
if not os.path.samefile(pyproject_dir, "."):
188+
continue
189+
matrix = _parse_matrix(config.matrix) or dict(file_config.matrix)
190+
if cuda_version is not None:
191+
matrix["cuda"] = [f"{cuda_version[0]}.{cuda_version[1]}"]
192+
rapids_dependency_file_generator.make_dependency_files(
193+
parsed_config,
194+
[file_key],
195+
{rapids_dependency_file_generator.Output.PYPROJECT},
196+
matrix,
197+
[],
198+
False,
199+
)
194200
pyproject = utils._get_pyproject()
195201
project_data = pyproject["project"]
196202
project_data["name"] += _get_cuda_suffix(config.require_cuda)

tests/test_impls.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_edit_pyproject(
136136
- arch
137137
pyproject_dir: {pyproject_dir}
138138
matrix:
139-
cuda: ["11.5", "12.1"]
139+
cuda: ["11.5"]
140140
arch: ["x86_64"]
141141
extras:
142142
table: project

tests/test_packages.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,4 @@ def test_simple_scikit_build_core(tmp_path, env, nvcc_version):
114114
if nvcc_version == "11":
115115
assert extras == {"jit": {"ptxcompiler-cu11"}}
116116
else:
117-
assert extras == {}
117+
assert extras == {"jit": set()}

0 commit comments

Comments
 (0)