Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support adding custom env vars #504

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/html-py-ever/tests/run_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ def main():
count_py, parse_py, select_py = python(filename, "html.parser")
assert count_rs == count_py
print(f"{filename} {count_rs} {parse_rs:6f}s")
print(f"Parse py {parse_py:6f}s {parse_py/parse_rs:6.3f}x")
print(f"Select py {select_py:6f}s {select_py/select_rs:6.3f}x")
print(f"Parse py {parse_py:6f}s {parse_py / parse_rs:6.3f}x")
print(f"Select py {select_py:6f}s {select_py / select_rs:6.3f}x")

if lxml is not None:
count_lxml, parse_lxml, select_lxml = python(filename, "lxml")
assert count_rs == count_lxml
print(f"Parse lxml {parse_lxml:6f}s {parse_lxml/parse_rs:6.3f}x")
print(f"Select lxml {select_lxml:6f}s {select_lxml/select_rs:6.3f}x")
print(f"Parse lxml {parse_lxml:6f}s {parse_lxml / parse_rs:6.3f}x")
print(f"Select lxml {select_lxml:6f}s {select_lxml / select_rs:6.3f}x")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_crossenv(session: nox.Session):
@nox.session()
def ruff(session: nox.Session):
session.install("ruff")
session.run("ruff", "format", "--check", ".")
session.run("ruff", "format", "--diff", ".")
session.run("ruff", "check", ".")


Expand Down
24 changes: 24 additions & 0 deletions setuptools_rust/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@
import subprocess
from typing import Optional


class Env:
"""Allow using ``functools.lru_cache`` with an environment variable dictionary.

Dictionaries are unhashable, but ``functools.lru_cache`` needs all parameters to
be hashable, which we solve which a custom ``__hash__``."""

env: Optional[dict[str, str]]

def __init__(self, env: Optional[dict[str, str]]):
self.env = env

def __eq__(self, other: object) -> bool:
if not isinstance(other, Env):
return False
return self.env == other.env

def __hash__(self) -> int:
if self.env is not None:
return hash(tuple(sorted(self.env.items())))
else:
return hash(None)


def format_called_process_error(
Expand Down
42 changes: 20 additions & 22 deletions setuptools_rust/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from setuptools.command.build_ext import get_abi3_suffix
from setuptools.command.install_scripts import install_scripts as CommandInstallScripts

from ._utils import format_called_process_error
from ._utils import format_called_process_error, Env
from .command import RustCommand
from .extension import Binding, RustBin, RustExtension, Strip
from .rustc_info import (
Expand All @@ -45,8 +45,8 @@
from setuptools import Command as CommandBdistWheel # type: ignore[assignment]


def _check_cargo_supports_crate_type_option() -> bool:
version = get_rust_version()
def _check_cargo_supports_crate_type_option(env: Optional[Env]) -> bool:
version = get_rust_version(env)

if version is None:
return False
Expand Down Expand Up @@ -144,10 +144,10 @@ def run_for_extension(self, ext: RustExtension) -> None:
def build_extension(
self, ext: RustExtension, forced_target_triple: Optional[str] = None
) -> List["_BuiltModule"]:
target_triple = self._detect_rust_target(forced_target_triple)
rustc_cfgs = get_rustc_cfgs(target_triple)
target_triple = self._detect_rust_target(forced_target_triple, ext.env)
rustc_cfgs = get_rustc_cfgs(target_triple, ext.env)

env = _prepare_build_environment()
env = _prepare_build_environment(ext.env)

if not os.path.exists(ext.path):
raise FileError(
Expand All @@ -156,7 +156,7 @@ def build_extension(

quiet = self.qbuild or ext.quiet
debug = self._is_debug_build(ext)
use_cargo_crate_type = _check_cargo_supports_crate_type_option()
use_cargo_crate_type = _check_cargo_supports_crate_type_option(ext.env)

package_id = ext.metadata(quiet=quiet)["resolve"]["root"]
if package_id is None:
Expand Down Expand Up @@ -477,7 +477,7 @@ def _py_limited_api(self) -> _PyLimitedApi:
return cast(_PyLimitedApi, bdist_wheel.py_limited_api)

def _detect_rust_target(
self, forced_target_triple: Optional[str] = None
self, forced_target_triple: Optional[str], env: Env
) -> Optional[str]:
assert self.plat_name is not None
if forced_target_triple is not None:
Expand All @@ -486,14 +486,14 @@ def _detect_rust_target(
return forced_target_triple

# Determine local rust target which needs to be "forced" if necessary
local_rust_target = _adjusted_local_rust_target(self.plat_name)
local_rust_target = _adjusted_local_rust_target(self.plat_name, env)

# Match cargo's behaviour of not using an explicit target if the
# target we're compiling for is the host
if (
local_rust_target is not None
# check for None first to avoid calling to rustc if not needed
and local_rust_target != get_rust_host()
and local_rust_target != get_rust_host(env)
):
return local_rust_target

Expand Down Expand Up @@ -609,7 +609,7 @@ def _replace_vendor_with_unknown(target: str) -> Optional[str]:
return "-".join(components)


def _prepare_build_environment() -> Dict[str, str]:
def _prepare_build_environment(env: Env) -> Dict[str, str]:
"""Prepares environment variables to use when executing cargo build."""

base_executable = None
Expand All @@ -625,20 +625,18 @@ def _prepare_build_environment() -> Dict[str, str]:
# executing python interpreter.
bindir = os.path.dirname(executable)

env = os.environ.copy()
env.update(
env_vars = (env.env or os.environ).copy()
env_vars.update(
{
# disables rust's pkg-config seeking for specified packages,
# which causes pythonXX-sys to fall back to detecting the
# interpreter from the path.
"PATH": os.path.join(bindir, os.environ.get("PATH", "")),
"PYTHON_SYS_EXECUTABLE": os.environ.get(
"PYTHON_SYS_EXECUTABLE", executable
),
"PYO3_PYTHON": os.environ.get("PYO3_PYTHON", executable),
"PATH": os.path.join(bindir, env_vars.get("PATH", "")),
"PYTHON_SYS_EXECUTABLE": env_vars.get("PYTHON_SYS_EXECUTABLE", executable),
"PYO3_PYTHON": env_vars.get("PYO3_PYTHON", executable),
}
)
return env
return env_vars


def _is_py_limited_api(
Expand Down Expand Up @@ -692,19 +690,19 @@ def _binding_features(
_PyLimitedApi = Literal["cp37", "cp38", "cp39", "cp310", "cp311", "cp312", True, False]


def _adjusted_local_rust_target(plat_name: str) -> Optional[str]:
def _adjusted_local_rust_target(plat_name: str, env: Env) -> Optional[str]:
"""Returns the local rust target for the given `plat_name`, if it is
necessary to 'force' a specific target for correctness."""

# If we are on a 64-bit machine, but running a 32-bit Python, then
# we'll target a 32-bit Rust build.
if plat_name == "win32":
if get_rustc_cfgs(None).get("target_env") == "gnu":
if get_rustc_cfgs(None, env).get("target_env") == "gnu":
return "i686-pc-windows-gnu"
else:
return "i686-pc-windows-msvc"
elif plat_name == "win-amd64":
if get_rustc_cfgs(None).get("target_env") == "gnu":
if get_rustc_cfgs(None, env).get("target_env") == "gnu":
return "x86_64-pc-windows-gnu"
else:
return "x86_64-pc-windows-msvc"
Expand Down
2 changes: 1 addition & 1 deletion setuptools_rust/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ def run_for_extension(self, ext: RustExtension) -> None:

# Execute cargo command
try:
subprocess.check_output(args)
subprocess.check_output(args, env=ext.env.env)
except Exception:
pass
10 changes: 9 additions & 1 deletion setuptools_rust/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,16 @@ def run(self) -> None:
return

all_optional = all(ext.optional for ext in self.extensions)
# Use the environment of the first non-optional extension, or the first optional
# extension if there is no non-optional extension.
env = None
for ext in self.extensions:
if ext.env:
env = ext.env
if not ext.optional:
break
try:
version = get_rust_version()
version = get_rust_version(env)
if version is None:
min_version = max( # type: ignore[type-var]
filter(
Expand Down
11 changes: 9 additions & 2 deletions setuptools_rust/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
if TYPE_CHECKING:
from semantic_version import SimpleSpec

from ._utils import format_called_process_error
from ._utils import format_called_process_error, Env


class Binding(IntEnum):
Expand Down Expand Up @@ -112,6 +112,9 @@ class RustExtension:
abort the build process, and instead simply not install the failing
extension.
py_limited_api: Deprecated.
env: Environment variables to use when calling cargo or rustc (``env=``
in ``subprocess.Popen``). setuptools-rust may add additional
variables or modify ``PATH``.
"""

def __init__(
Expand All @@ -131,6 +134,7 @@ def __init__(
native: bool = False,
optional: bool = False,
py_limited_api: Literal["auto", True, False] = "auto",
env: Optional[Dict[str, str]] = None,
):
if isinstance(target, dict):
name = "; ".join("%s=%s" % (key, val) for key, val in target.items())
Expand All @@ -153,6 +157,7 @@ def __init__(
self.script = script
self.optional = optional
self.py_limited_api = py_limited_api
self.env = Env(env)

if native:
warnings.warn(
Expand Down Expand Up @@ -261,7 +266,7 @@ def _metadata(self, cargo: str, quiet: bool) -> "CargoMetadata":
# If not quiet, let stderr be inherited
stderr = subprocess.PIPE if quiet else None
payload = subprocess.check_output(
metadata_command, stderr=stderr, encoding="latin-1"
metadata_command, stderr=stderr, encoding="latin-1", env=self.env.env
)
except subprocess.CalledProcessError as e:
raise SetupError(format_called_process_error(e))
Expand Down Expand Up @@ -319,6 +324,7 @@ def __init__(
debug: Optional[bool] = None,
strip: Strip = Strip.No,
optional: bool = False,
env: Optional[dict[str, str]] = None,
):
super().__init__(
target=target,
Expand All @@ -333,6 +339,7 @@ def __init__(
optional=optional,
strip=strip,
py_limited_api=False,
env=env,
)

def entry_points(self) -> List[str]:
Expand Down
32 changes: 18 additions & 14 deletions setuptools_rust/rustc_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,31 @@
from functools import lru_cache
from typing import Dict, List, NewType, Optional, TYPE_CHECKING

from ._utils import Env

if TYPE_CHECKING:
from semantic_version import Version


def get_rust_version() -> Optional[Version]: # type: ignore[no-any-unimported]
def get_rust_version(env: Optional[Env]) -> Optional[Version]: # type: ignore[no-any-unimported]
try:
# first line of rustc -Vv is something like
# rustc 1.61.0 (fe5b13d68 2022-05-18)
from semantic_version import Version

return Version(_rust_version().split(" ")[1])
return Version(_rust_version(env).split(" ")[1])
except (subprocess.CalledProcessError, OSError):
return None


_HOST_LINE_START = "host: "


def get_rust_host() -> str:
def get_rust_host(env: Optional[Env]) -> str:
# rustc -Vv has a line denoting the host which cargo uses to decide the
# default target, e.g.
# host: aarch64-apple-darwin
for line in _rust_version_verbose().splitlines():
for line in _rust_version_verbose(env).splitlines():
if line.startswith(_HOST_LINE_START):
return line[len(_HOST_LINE_START) :].strip()
raise PlatformError("Could not determine rust host")
Expand All @@ -36,9 +38,9 @@ def get_rust_host() -> str:
RustCfgs = NewType("RustCfgs", Dict[str, Optional[str]])


def get_rustc_cfgs(target_triple: Optional[str]) -> RustCfgs:
def get_rustc_cfgs(target_triple: Optional[str], env: Env) -> RustCfgs:
cfgs = RustCfgs({})
for entry in get_rust_target_info(target_triple):
for entry in get_rust_target_info(target_triple, env):
maybe_split = entry.split("=", maxsplit=1)
if len(maybe_split) == 2:
cfgs[maybe_split[0]] = maybe_split[1].strip('"')
Expand All @@ -49,25 +51,27 @@ def get_rustc_cfgs(target_triple: Optional[str]) -> RustCfgs:


@lru_cache()
def get_rust_target_info(target_triple: Optional[str] = None) -> List[str]:
def get_rust_target_info(target_triple: Optional[str], env: Env) -> List[str]:
cmd = ["rustc", "--print", "cfg"]
if target_triple:
cmd.extend(["--target", target_triple])
output = subprocess.check_output(cmd, text=True)
output = subprocess.check_output(cmd, text=True, env=env.env)
return output.splitlines()


@lru_cache()
def get_rust_target_list() -> List[str]:
output = subprocess.check_output(["rustc", "--print", "target-list"], text=True)
def get_rust_target_list(env: Env) -> List[str]:
output = subprocess.check_output(
["rustc", "--print", "target-list"], text=True, env=env.env
)
return output.splitlines()


@lru_cache()
def _rust_version() -> str:
return subprocess.check_output(["rustc", "-V"], text=True)
def _rust_version(env: Env) -> str:
return subprocess.check_output(["rustc", "-V"], text=True, env=env.env)


@lru_cache()
def _rust_version_verbose() -> str:
return subprocess.check_output(["rustc", "-Vv"], text=True)
def _rust_version_verbose(env: Env) -> str:
return subprocess.check_output(["rustc", "-Vv"], text=True, env=env.env)
Loading
Loading