diff --git a/src/python/pants/backend/python/target_types.py b/src/python/pants/backend/python/target_types.py index 009646dfee4..4e799ef8568 100644 --- a/src/python/pants/backend/python/target_types.py +++ b/src/python/pants/backend/python/target_types.py @@ -53,6 +53,7 @@ StringSequenceField, Target, TriBoolField, + ValidNumbers, generate_file_based_overrides_field_help_message, ) from pants.option.subsystem import Subsystem @@ -515,18 +516,11 @@ class PythonTestsTimeout(IntField): alias = "timeout" help = ( "A timeout (in seconds) used by each test file belonging to this target.\n\n" - "This only applies if the option `--pytest-timeouts` is set to True." + "If unset, will default to `[pytest].timeout_default`; if that option is also unset, " + "then the test will never time out. Will never exceed `[pytest].timeout_maximum`. Only " + "applies if the option `--pytest-timeouts` is set to true (the default)." ) - - @classmethod - def compute_value(cls, raw_value: Optional[int], address: Address) -> Optional[int]: - value = super().compute_value(raw_value, address) - if value is not None and value < 1: - raise InvalidFieldException( - f"The value for the `timeout` field in target {address} must be > 0, but was " - f"{value}." - ) - return value + valid_numbers = ValidNumbers.positive_only def calculate_from_global_options(self, pytest: PyTest) -> Optional[int]: """Determine the timeout (in seconds) after applying global `pytest` options.""" diff --git a/src/python/pants/backend/python/target_types_test.py b/src/python/pants/backend/python/target_types_test.py index af14e3f406d..ac6f06a7483 100644 --- a/src/python/pants/backend/python/target_types_test.py +++ b/src/python/pants/backend/python/target_types_test.py @@ -65,14 +65,6 @@ from pants.util.frozendict import FrozenDict -def test_timeout_validation() -> None: - with pytest.raises(InvalidFieldException): - PythonTestsTimeout(-100, Address("", target_name="tests")) - with pytest.raises(InvalidFieldException): - PythonTestsTimeout(0, Address("", target_name="tests")) - assert PythonTestsTimeout(5, Address("", target_name="tests")).value == 5 - - def test_pex_binary_validation() -> None: def create_tgt(*, script: str | None = None, entry_point: str | None = None) -> PexBinary: return PexBinary( diff --git a/src/python/pants/backend/shell/target_types.py b/src/python/pants/backend/shell/target_types.py index b2750182349..d1507a0ed4a 100644 --- a/src/python/pants/backend/shell/target_types.py +++ b/src/python/pants/backend/shell/target_types.py @@ -6,11 +6,9 @@ import re from enum import Enum from textwrap import dedent -from typing import Optional from pants.backend.shell.shell_setup import ShellSetup from pants.core.goals.test import RuntimePackageDependenciesField -from pants.engine.addresses import Address from pants.engine.fs import PathGlobs, Paths from pants.engine.process import BinaryPathTest from pants.engine.rules import Get, MultiGet, collect_rules, rule @@ -21,7 +19,6 @@ GeneratedTargets, GenerateTargetsRequest, IntField, - InvalidFieldException, MultipleSourcesField, OverridesField, SingleSourceField, @@ -30,6 +27,7 @@ StringField, StringSequenceField, Target, + ValidNumbers, generate_file_based_overrides_field_help_message, generate_file_level_targets, ) @@ -102,19 +100,10 @@ class Shunit2TestDependenciesField(Dependencies): class Shunit2TestTimeoutField(IntField): alias = "timeout" help = ( - "A timeout (in seconds) used by each test file belonging to this target. " + "A timeout (in seconds) used by each test file belonging to this target.\n\n" "If unset, the test will never time out." ) - - @classmethod - def compute_value(cls, raw_value: Optional[int], address: Address) -> Optional[int]: - value = super().compute_value(raw_value, address) - if value is not None and value < 1: - raise InvalidFieldException( - f"The value for the `timeout` field in target {address} must be > 0, but was " - f"{value}." - ) - return value + valid_numbers = ValidNumbers.positive_only class Shunit2TestSourceField(ShellSourceField): @@ -328,16 +317,7 @@ class ShellCommandTimeoutField(IntField): alias = "timeout" default = 30 help = "Command execution timeout (in seconds)." - - @classmethod - def compute_value(cls, raw_value: Optional[int], address: Address) -> Optional[int]: - value = super().compute_value(raw_value, address) - if value is not None and value < 1: - raise InvalidFieldException( - f"The value for the `timeout` field in target {address} must be > 0, but was " - f"{value}." - ) - return value + valid_numbers = ValidNumbers.positive_only class ShellCommandToolsField(StringSequenceField): diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index 96a3618246e..bce2431d7e5 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -4,6 +4,7 @@ from __future__ import annotations import collections.abc +import enum import itertools import logging import os.path @@ -1300,22 +1301,52 @@ def compute_value(cls, raw_value: Optional[bool], address: Address) -> Optional[ return super().compute_value(raw_value, address) +class ValidNumbers(Enum): + """What range of numbers are allowed for IntField and FloatField.""" + + positive_only = enum.auto() + positive_and_zero = enum.auto() + all = enum.auto() + + def validate(self, num: float | int | None, alias: str, address: Address) -> None: + if num is None or self == self.all: # type: ignore[comparison-overlap] + return + if self == self.positive_and_zero: # type: ignore[comparison-overlap] + if num < 0: + raise InvalidFieldException( + f"The {repr(alias)} field in target {address} must be greater than or equal to " + f"zero, but was set to `{num}`." + ) + return + if num <= 0: + raise InvalidFieldException( + f"The {repr(alias)} field in target {address} must be greater than zero, but was " + f"set to `{num}`." + ) + + class IntField(ScalarField[int]): expected_type = int expected_type_description = "an integer" + valid_numbers: ClassVar[ValidNumbers] = ValidNumbers.all @classmethod def compute_value(cls, raw_value: Optional[int], address: Address) -> Optional[int]: - return super().compute_value(raw_value, address) + value_or_default = super().compute_value(raw_value, address) + cls.valid_numbers.validate(value_or_default, cls.alias, address) + return value_or_default class FloatField(ScalarField[float]): expected_type = float expected_type_description = "a float" + valid_numbers: ClassVar[ValidNumbers] = ValidNumbers.all @classmethod def compute_value(cls, raw_value: Optional[float], address: Address) -> Optional[float]: - return super().compute_value(raw_value, address) + value_or_default = super().compute_value(raw_value, address) + cls.valid_numbers.validate(value_or_default, cls.alias, address) + return value_or_default class StringField(ScalarField[str]): diff --git a/src/python/pants/engine/target_test.py b/src/python/pants/engine/target_test.py index 8a418a467ca..52c4aff0122 100644 --- a/src/python/pants/engine/target_test.py +++ b/src/python/pants/engine/target_test.py @@ -19,6 +19,7 @@ ExplicitlyProvidedDependencies, Field, FieldSet, + FloatField, GeneratedTargets, GenerateSourcesRequest, IntField, @@ -38,6 +39,7 @@ StringSequenceField, Tags, Target, + ValidNumbers, generate_file_level_targets, targets_with_sources_types, ) @@ -810,6 +812,41 @@ class GivenEnum(StringField): GivenEnum("carrot", addr) +@pytest.mark.parametrize("field_cls", [IntField, FloatField]) +def test_int_float_fields_valid_numbers(field_cls: type) -> None: + class AllNums(field_cls): # type: ignore[valid-type,misc] + alias = "all_nums" + valid_numbers = ValidNumbers.all + + class PositiveAndZero(field_cls): # type: ignore[valid-type,misc] + alias = "positive_and_zero" + valid_numbers = ValidNumbers.positive_and_zero + + class PositiveOnly(field_cls): # type: ignore[valid-type,misc] + alias = "positive_only" + valid_numbers = ValidNumbers.positive_only + + addr = Address("nums") + neg = -1 if issubclass(field_cls, IntField) else -1.0 + zero = 0 if issubclass(field_cls, IntField) else 0.0 + pos = 1 if issubclass(field_cls, IntField) else 1.0 + + assert AllNums(neg, addr).value == neg + assert AllNums(zero, addr).value == zero + assert AllNums(pos, addr).value == pos + + with pytest.raises(InvalidFieldException): + PositiveAndZero(neg, addr) + assert PositiveAndZero(zero, addr).value == zero + assert PositiveAndZero(pos, addr).value == pos + + with pytest.raises(InvalidFieldException): + PositiveOnly(neg, addr) + with pytest.raises(InvalidFieldException): + PositiveOnly(zero, addr) + assert PositiveOnly(pos, addr).value == pos + + def test_sequence_field() -> None: @dataclass(frozen=True) class CustomObject: