Skip to content

Commit

Permalink
IntField and FloatField can ban negative numbers and zero (#13705)
Browse files Browse the repository at this point in the history
Prework for #13530. 

[ci skip-rust]
[ci skip-build-wheels]
  • Loading branch information
Eric-Arellano authored Nov 24, 2021
1 parent b9f97e5 commit 2704e3f
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 45 deletions.
16 changes: 5 additions & 11 deletions src/python/pants/backend/python/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
StringSequenceField,
Target,
TriBoolField,
ValidNumbers,
generate_file_based_overrides_field_help_message,
)
from pants.option.subsystem import Subsystem
Expand Down Expand Up @@ -515,18 +516,11 @@ class PythonTestsTimeout(IntField):
alias = "timeout"
help = (
"A timeout (in seconds) used by each test file belonging to this target.\n\n"
"This only applies if the option `--pytest-timeouts` is set to True."
"If unset, will default to `[pytest].timeout_default`; if that option is also unset, "
"then the test will never time out. Will never exceed `[pytest].timeout_maximum`. Only "
"applies if the option `--pytest-timeouts` is set to true (the default)."
)

@classmethod
def compute_value(cls, raw_value: Optional[int], address: Address) -> Optional[int]:
value = super().compute_value(raw_value, address)
if value is not None and value < 1:
raise InvalidFieldException(
f"The value for the `timeout` field in target {address} must be > 0, but was "
f"{value}."
)
return value
valid_numbers = ValidNumbers.positive_only

def calculate_from_global_options(self, pytest: PyTest) -> Optional[int]:
"""Determine the timeout (in seconds) after applying global `pytest` options."""
Expand Down
8 changes: 0 additions & 8 deletions src/python/pants/backend/python/target_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,6 @@
from pants.util.frozendict import FrozenDict


def test_timeout_validation() -> None:
with pytest.raises(InvalidFieldException):
PythonTestsTimeout(-100, Address("", target_name="tests"))
with pytest.raises(InvalidFieldException):
PythonTestsTimeout(0, Address("", target_name="tests"))
assert PythonTestsTimeout(5, Address("", target_name="tests")).value == 5


def test_pex_binary_validation() -> None:
def create_tgt(*, script: str | None = None, entry_point: str | None = None) -> PexBinary:
return PexBinary(
Expand Down
28 changes: 4 additions & 24 deletions src/python/pants/backend/shell/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
import re
from enum import Enum
from textwrap import dedent
from typing import Optional

from pants.backend.shell.shell_setup import ShellSetup
from pants.core.goals.test import RuntimePackageDependenciesField
from pants.engine.addresses import Address
from pants.engine.fs import PathGlobs, Paths
from pants.engine.process import BinaryPathTest
from pants.engine.rules import Get, MultiGet, collect_rules, rule
Expand All @@ -21,7 +19,6 @@
GeneratedTargets,
GenerateTargetsRequest,
IntField,
InvalidFieldException,
MultipleSourcesField,
OverridesField,
SingleSourceField,
Expand All @@ -30,6 +27,7 @@
StringField,
StringSequenceField,
Target,
ValidNumbers,
generate_file_based_overrides_field_help_message,
generate_file_level_targets,
)
Expand Down Expand Up @@ -102,19 +100,10 @@ class Shunit2TestDependenciesField(Dependencies):
class Shunit2TestTimeoutField(IntField):
alias = "timeout"
help = (
"A timeout (in seconds) used by each test file belonging to this target. "
"A timeout (in seconds) used by each test file belonging to this target.\n\n"
"If unset, the test will never time out."
)

@classmethod
def compute_value(cls, raw_value: Optional[int], address: Address) -> Optional[int]:
value = super().compute_value(raw_value, address)
if value is not None and value < 1:
raise InvalidFieldException(
f"The value for the `timeout` field in target {address} must be > 0, but was "
f"{value}."
)
return value
valid_numbers = ValidNumbers.positive_only


class Shunit2TestSourceField(ShellSourceField):
Expand Down Expand Up @@ -328,16 +317,7 @@ class ShellCommandTimeoutField(IntField):
alias = "timeout"
default = 30
help = "Command execution timeout (in seconds)."

@classmethod
def compute_value(cls, raw_value: Optional[int], address: Address) -> Optional[int]:
value = super().compute_value(raw_value, address)
if value is not None and value < 1:
raise InvalidFieldException(
f"The value for the `timeout` field in target {address} must be > 0, but was "
f"{value}."
)
return value
valid_numbers = ValidNumbers.positive_only


class ShellCommandToolsField(StringSequenceField):
Expand Down
35 changes: 33 additions & 2 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import collections.abc
import enum
import itertools
import logging
import os.path
Expand Down Expand Up @@ -1300,22 +1301,52 @@ def compute_value(cls, raw_value: Optional[bool], address: Address) -> Optional[
return super().compute_value(raw_value, address)


class ValidNumbers(Enum):
"""What range of numbers are allowed for IntField and FloatField."""

positive_only = enum.auto()
positive_and_zero = enum.auto()
all = enum.auto()

def validate(self, num: float | int | None, alias: str, address: Address) -> None:
if num is None or self == self.all: # type: ignore[comparison-overlap]
return
if self == self.positive_and_zero: # type: ignore[comparison-overlap]
if num < 0:
raise InvalidFieldException(
f"The {repr(alias)} field in target {address} must be greater than or equal to "
f"zero, but was set to `{num}`."
)
return
if num <= 0:
raise InvalidFieldException(
f"The {repr(alias)} field in target {address} must be greater than zero, but was "
f"set to `{num}`."
)


class IntField(ScalarField[int]):
expected_type = int
expected_type_description = "an integer"
valid_numbers: ClassVar[ValidNumbers] = ValidNumbers.all

@classmethod
def compute_value(cls, raw_value: Optional[int], address: Address) -> Optional[int]:
return super().compute_value(raw_value, address)
value_or_default = super().compute_value(raw_value, address)
cls.valid_numbers.validate(value_or_default, cls.alias, address)
return value_or_default


class FloatField(ScalarField[float]):
expected_type = float
expected_type_description = "a float"
valid_numbers: ClassVar[ValidNumbers] = ValidNumbers.all

@classmethod
def compute_value(cls, raw_value: Optional[float], address: Address) -> Optional[float]:
return super().compute_value(raw_value, address)
value_or_default = super().compute_value(raw_value, address)
cls.valid_numbers.validate(value_or_default, cls.alias, address)
return value_or_default


class StringField(ScalarField[str]):
Expand Down
37 changes: 37 additions & 0 deletions src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ExplicitlyProvidedDependencies,
Field,
FieldSet,
FloatField,
GeneratedTargets,
GenerateSourcesRequest,
IntField,
Expand All @@ -38,6 +39,7 @@
StringSequenceField,
Tags,
Target,
ValidNumbers,
generate_file_level_targets,
targets_with_sources_types,
)
Expand Down Expand Up @@ -810,6 +812,41 @@ class GivenEnum(StringField):
GivenEnum("carrot", addr)


@pytest.mark.parametrize("field_cls", [IntField, FloatField])
def test_int_float_fields_valid_numbers(field_cls: type) -> None:
class AllNums(field_cls): # type: ignore[valid-type,misc]
alias = "all_nums"
valid_numbers = ValidNumbers.all

class PositiveAndZero(field_cls): # type: ignore[valid-type,misc]
alias = "positive_and_zero"
valid_numbers = ValidNumbers.positive_and_zero

class PositiveOnly(field_cls): # type: ignore[valid-type,misc]
alias = "positive_only"
valid_numbers = ValidNumbers.positive_only

addr = Address("nums")
neg = -1 if issubclass(field_cls, IntField) else -1.0
zero = 0 if issubclass(field_cls, IntField) else 0.0
pos = 1 if issubclass(field_cls, IntField) else 1.0

assert AllNums(neg, addr).value == neg
assert AllNums(zero, addr).value == zero
assert AllNums(pos, addr).value == pos

with pytest.raises(InvalidFieldException):
PositiveAndZero(neg, addr)
assert PositiveAndZero(zero, addr).value == zero
assert PositiveAndZero(pos, addr).value == pos

with pytest.raises(InvalidFieldException):
PositiveOnly(neg, addr)
with pytest.raises(InvalidFieldException):
PositiveOnly(zero, addr)
assert PositiveOnly(pos, addr).value == pos


def test_sequence_field() -> None:
@dataclass(frozen=True)
class CustomObject:
Expand Down

0 comments on commit 2704e3f

Please sign in to comment.