Skip to content

Commit

Permalink
Fix Enum.Value identity under pickling. (#2597)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsirois authored Nov 15, 2024
1 parent a430f09 commit 93f6430
Show file tree
Hide file tree
Showing 38 changed files with 487 additions and 34 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Noop
run: "true"
checks:
name: tox -e format-check,lint-check,typecheck,vendor-check,package,docs
name: tox -e format-check,lint-check,typecheck,enum-check,vendor-check,package,docs
needs: org-check
runs-on: ubuntu-24.04
steps:
Expand All @@ -42,6 +42,9 @@ jobs:
uses: pantsbuild/actions/run-tox@b16b9cf47cd566acfe217b1dafc5b452e27e6fd7
with:
tox-env: format-check,lint-check,typecheck
- name: Check Enum Types
run: |
BASE_MODE=pull ./dtox.sh -e enum-check -- -v --require-py27
- name: Check Vendoring
uses: pantsbuild/actions/run-tox@b16b9cf47cd566acfe217b1dafc5b452e27e6fd7
with:
Expand Down
3 changes: 3 additions & 0 deletions pex/atomic_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ class Value(Enum.Value):
POSIX = Value("posix")


FileLockStyle.seal()


def _is_bsd_lock(lock_style=None):
# type: (Optional[FileLockStyle.Value]) -> bool

Expand Down
3 changes: 3 additions & 0 deletions pex/bin/pex.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,9 @@ class Value(Enum.Value):
VERBOSE = Value("verbose")


Seed.seal()


class HandleSeedAction(Action):
def __init__(self, *args, **kwargs):
kwargs["nargs"] = "?"
Expand Down
2 changes: 2 additions & 0 deletions pex/cache/dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def iter_transitive_dependents(self):
)


CacheDir.seal()

if TYPE_CHECKING:
_AtomicCacheDir = TypeVar("_AtomicCacheDir", bound="AtomicCacheDir")

Expand Down
3 changes: 3 additions & 0 deletions pex/cli/commands/cache/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def render(self, total_bytes):
PB = Value("PB", 1000 * TB.multiple)


ByteUnits.seal()


@attr.s(frozen=True)
class ByteAmount(object):
@classmethod
Expand Down
19 changes: 15 additions & 4 deletions pex/cli/commands/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class Value(Enum.Value):
ERROR = Value("error")


FingerprintMismatch.seal()


class ExportFormat(Enum["ExportFormat.Value"]):
class Value(Enum.Value):
pass
Expand All @@ -98,6 +101,9 @@ class Value(Enum.Value):
PEP_665 = Value("pep-665")


ExportFormat.seal()


class ExportSortBy(Enum["ExportSortBy.Value"]):
class Value(Enum.Value):
pass
Expand All @@ -106,6 +112,9 @@ class Value(Enum.Value):
PROJECT_NAME = Value("project-name")


ExportSortBy.seal()


class DryRunStyle(Enum["DryRunStyle.Value"]):
class Value(Enum.Value):
pass
Expand All @@ -114,6 +123,9 @@ class Value(Enum.Value):
CHECK = Value("check")


DryRunStyle.seal()


class HandleDryRunAction(Action):
def __init__(self, *args, **kwargs):
kwargs["nargs"] = "?"
Expand Down Expand Up @@ -1270,10 +1282,9 @@ def _process_lock_update(
project_name not in requirements_by_project_name,
"Deletes should have been unconditionally removed from requirements "
"earlier. Found deleted project {project_name} in updated requirements:\n"
"{requirements}".format(
project_name=project_name,
requirements="\n".join(map(str, requirements_by_project_name.values())),
),
"{requirements}",
project_name=project_name,
requirements="\n".join(map(str, requirements_by_project_name.values())),
)
constraints_by_project_name.pop(project_name, None)
elif isinstance(update, VersionUpdate):
Expand Down
3 changes: 3 additions & 0 deletions pex/cli/commands/venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class Value(Enum.Value):
FLAT_ZIPPED = Value("flat-zipped")


InstallLayout.seal()


class Venv(OutputMixin, JsonMixin, BuildTimeCommand):
@classmethod
def _add_inspect_arguments(cls, parser):
Expand Down
3 changes: 3 additions & 0 deletions pex/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,9 @@ class Value(Enum.Value):
SYMLINK = Value("symlink")


CopyMode.seal()


def iter_copytree(
src, # type: Text
dst, # type: Text
Expand Down
6 changes: 3 additions & 3 deletions pex/dependency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def configure(
len(root_reqs) > 0,
"The deep --exclude mechanism failed to exclude {dist} from transitive "
"requirements. It should have been excluded by configured excludes: "
"{excludes} but was not.".format(
dist=fingerprinted_dist.distribution, excludes=excludes
),
"{excludes} but was not.",
dist=fingerprinted_dist.distribution,
excludes=excludes,
)
pex_warnings.warn(
"The distribution {dist} was required by the input {requirements} "
Expand Down
6 changes: 6 additions & 0 deletions pex/dist_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def load_metadata(
PKG_INFO = Value("PKG-INFO")


MetadataType.seal()


@attr.s(frozen=True)
class MetadataKey(object):
metadata_type = attr.ib() # type: MetadataType.Value
Expand Down Expand Up @@ -963,6 +966,9 @@ def of(cls, location):
return cls.SDIST


DistributionType.seal()


@attr.s(frozen=True)
class Distribution(object):
@staticmethod
Expand Down
82 changes: 81 additions & 1 deletion pex/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,34 @@

from __future__ import absolute_import, print_function

import sys
import weakref
from collections import defaultdict
from functools import total_ordering

from _weakref import ReferenceType

from pex.exceptions import production_assert
from pex.typing import TYPE_CHECKING, Generic, cast

if TYPE_CHECKING:
from typing import Any, DefaultDict, List, Optional, Tuple, Type, TypeVar
from typing import Any, DefaultDict, Iterator, List, Optional, Tuple, Type, TypeVar

_V = TypeVar("_V", bound="Enum.Value")


def _get_or_create(
module, # type: str
enum_type, # type: str
enum_value_type, # type: str
enum_value_value, # type: str
):
# type: (...) -> Enum.Value
enum_class = getattr(sys.modules[module], enum_type)
enum_value_class = getattr(enum_class, enum_value_type)
return cast("Enum.Value", enum_value_class._get_or_create(enum_value_value))


class Enum(Generic["_V"]):
@total_ordering
class Value(object):
Expand All @@ -26,11 +40,37 @@ class Value(object):

@classmethod
def _iter_values(cls):
# type: () -> Iterator[Enum.Value]
for ref in cls._values_by_type[cls]:
value = ref()
if value:
yield value

@classmethod
def _get_or_create(cls, value):
# type: (str) -> Enum.Value
for existing_value in cls._iter_values():
if existing_value.value == value:
return existing_value
return cls(value)

def __reduce__(self):
if sys.version_info[0] >= 3:
return self._get_or_create, (self.value,)

# N.B.: Python 2.7 does not handle pickling nested classes; so we go through some
# hoops here and in `Enum.seal`.
module = self.__module__
enum_type = getattr(self, "_enum_type", None)
production_assert(
isinstance(enum_type, str),
"The Enum subclass in the {module} module containing value {self} was not "
"`seal`ed.",
module=module,
self=self,
)
return _get_or_create, (module, enum_type, type(self).__name__, self.value)

def __init__(self, value):
# type: (str) -> None
values = Enum.Value._values_by_type[type(self)]
Expand Down Expand Up @@ -78,6 +118,46 @@ def __le__(self, other):
raise self._create_type_error(other)
return self is other or self < other

@classmethod
def seal(cls):
if sys.version_info[0] >= 3:
return

# N.B.: Python 2.7 does not handle pickling nested classes; so we go through some
# hoops here and in `Enum.Value.__reduce__`.

enum_type_name, _, enum_value_type_name = cls.type_var.partition(".")
if enum_value_type_name:
production_assert(
cls.__name__ == enum_type_name,
"Expected Enum subclass {cls} to have a type parameter of the form `{name}.Value` "
"where `Value` is a subclass of `Enum.Value`. Instead found: {type_var}",
cls=cls,
name=cls.__name__,
type_var=cls.type_var,
)
enum_value_type = getattr(cls, enum_value_type_name, None)
else:
enum_value_type = getattr(sys.modules[cls.__module__], enum_type_name, None)

production_assert(
enum_type_name is not None,
"Failed to find Enum.Value type {type_var} for Enum {cls} in module {module}",
type_var=cls.type_var,
cls=cls,
module=cls.__module__,
)
production_assert(
issubclass(enum_value_type, Enum.Value),
"Expected Enum subclass {cls} to have a type parameter that is a subclass of "
"`Enum.Value`. Instead found {type_var} was of type: {enum_value_type}",
cls=cls,
name=cls.__name__,
type_var=cls.type_var,
enum_value_type=enum_value_type,
)
setattr(enum_value_type, "_enum_type", cls.__name__)

_values = None # type: Optional[Tuple[_V, ...]]

@classmethod
Expand Down
19 changes: 15 additions & 4 deletions pex/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@
import sys
from textwrap import dedent

from pex.typing import TYPE_CHECKING
from pex.version import __version__

if TYPE_CHECKING:
from typing import Any


_ASSERT_DETAILS = (
dedent(
"""\
Expand Down Expand Up @@ -39,10 +44,14 @@
).strip()


def reportable_unexpected_error_msg(msg=""):
# type: (str) -> str
def reportable_unexpected_error_msg(
msg="", # type: str
*args, # type: Any
**kwargs # type: Any
):
# type: (...) -> str

message = [msg, "---", _ASSERT_DETAILS]
message = [msg.format(*args, **kwargs), "---", _ASSERT_DETAILS]
pex = os.environ.get("PEX")
if pex:
try:
Expand All @@ -67,8 +76,10 @@ def reportable_unexpected_error_msg(msg=""):
def production_assert(
condition, # type: bool
msg="", # type: str
*args, # type: Any
**kwargs # type: Any
):
# type: (...) -> None

if not condition:
raise AssertionError(reportable_unexpected_error_msg(msg=msg))
raise AssertionError(reportable_unexpected_error_msg(msg, *args, **kwargs))
10 changes: 4 additions & 6 deletions pex/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,10 @@ def create_ssl_context(self):
# [^7]: https://gitlab.com/redhat-crypto/fedora-crypto-policies/-/merge_requests/110/diffs#269a48e71ac25ad1d07ff00db2390834c8ba7596_11_16
production_assert(
in_main_thread(),
msg=(
"An SSLContext must be initialized from the main thread. An attempt was made to "
"initialize an SSLContext for {cert_config} from thread {thread}.".format(
cert_config=self, thread=threading.current_thread()
)
),
"An SSLContext must be initialized from the main thread. An attempt was made to "
"initialize an SSLContext for {cert_config} from thread {thread}.",
cert_config=self,
thread=threading.current_thread(),
)
with guard_stdout():
# We import ssl lazily as an affordance to PEXes that use gevent SSL monkeypatching,
Expand Down
3 changes: 3 additions & 0 deletions pex/inherit_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ def for_value(cls, value):
value, type(value)
)
)


InheritPath.seal()
2 changes: 2 additions & 0 deletions pex/interpreter_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ class Value(Enum.Value):
EOL = Value("eol")


Lifecycle.seal()

# This value is based off of:
# 1. Past releases: https://www.python.org/downloads/ where the max patch level was achieved by
# 2.7.18.
Expand Down
3 changes: 3 additions & 0 deletions pex/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def identify_original(cls, pex):
return cls.Value.try_load(pex) or Layout.LOOSE


Layout.seal()


class _Layout(object):
def __init__(
self,
Expand Down
3 changes: 3 additions & 0 deletions pex/pep_427.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class Value(Enum.Value):
WHEEL_FILE = Value(".whl file")


InstallableType.seal()


@attr.s(frozen=True)
class InstallPaths(object):

Expand Down
3 changes: 3 additions & 0 deletions pex/pex_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def perform_check(
ERROR = Value("error")


Check.seal()


class PEXBuilder(object):
"""Helper for building PEX environments."""

Expand Down
Loading

0 comments on commit 93f6430

Please sign in to comment.