Skip to content

Commit

Permalink
Fix type issues and implement lazy value support in base metadata pro…
Browse files Browse the repository at this point in the history
…vider too
  • Loading branch information
zsol committed Jul 8, 2022
1 parent 87ac7f5 commit 1332b1a
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 48 deletions.
16 changes: 9 additions & 7 deletions libcst/_metadata_dependent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from contextlib import contextmanager
from typing import (
Callable,
Generic,
Union,
cast,
ClassVar,
Collection,
Expand All @@ -30,12 +32,12 @@

_T = TypeVar("_T")

_UNDEFINED_DEFAULT = object()

_SENTINEL = object()
class _UNDEFINED_DEFAULT:
pass


class LazyValue:
class LazyValue(Generic[_T]):
"""
The class for implementing a lazy metadata loading mechanism that improves the
performance when retriving expensive metadata (e.g., qualified names). Providers
Expand All @@ -46,12 +48,12 @@ class LazyValue:

def __init__(self, callable: Callable[[], _T]) -> None:
self.callable = callable
self.return_value: object = _SENTINEL
self.return_value: Union[_T, Type[_UNDEFINED_DEFAULT]] = _UNDEFINED_DEFAULT

def __call__(self) -> object:
if self.return_value is _SENTINEL:
def __call__(self) -> _T:
if self.return_value is _UNDEFINED_DEFAULT:
self.return_value = self.callable()
return self.return_value
return cast(_T, self.return_value)


class MetadataDependent(ABC):
Expand Down
11 changes: 3 additions & 8 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from collections import defaultdict
from dataclasses import dataclass
from typing import cast, Collection, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

import libcst as cst
import libcst.matchers as m
Expand All @@ -17,7 +17,7 @@
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_full_name_for_node
from libcst.metadata import PositionProvider, QualifiedName, QualifiedNameProvider
from libcst.metadata import PositionProvider, QualifiedNameProvider


NameOrAttribute = Union[cst.Name, cst.Attribute]
Expand Down Expand Up @@ -48,12 +48,7 @@ def _get_unique_qualified_name(
visitor: m.MatcherDecoratableVisitor, node: cst.CSTNode
) -> str:
name = None
names = [
q.name
for q in cast(
Collection[QualifiedName], visitor.get_metadata(QualifiedNameProvider, node)
)
]
names = [q.name for q in visitor.get_metadata(QualifiedNameProvider, node)]
if len(names) == 0:
# we hit this branch if the stub is directly using a fully
# qualified name, which is not technically valid python but is
Expand Down
6 changes: 2 additions & 4 deletions libcst/codemod/visitors/_gather_string_annotation_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import libcst.matchers as m
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor
from libcst.metadata import MetadataWrapper, QualifiedName, QualifiedNameProvider
from libcst.metadata import MetadataWrapper, QualifiedNameProvider

FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS = {"typing.TypeVar"}

Expand Down Expand Up @@ -45,9 +45,7 @@ def leave_Annotation(self, original_node: cst.Annotation) -> None:
self._annotation_stack.pop()

def visit_Call(self, node: cst.Call) -> bool:
qnames = cast(
Collection[QualifiedName], self.get_metadata(QualifiedNameProvider, node)
)
qnames = self.get_metadata(QualifiedNameProvider, node)
if any(qn.name in self._typing_functions for qn in qnames):
self._annotation_stack.append(node)
return True
Expand Down
29 changes: 16 additions & 13 deletions libcst/metadata/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from types import MappingProxyType
from typing import (
Callable,
cast,
Union,
Generic,
List,
Mapping,
Expand All @@ -23,6 +23,7 @@
_T as _MetadataT,
_UNDEFINED_DEFAULT,
MetadataDependent,
LazyValue,
)
from libcst._visitors import CSTVisitor

Expand All @@ -36,6 +37,7 @@
# BaseMetadataProvider[int] would be a subtype of BaseMetadataProvider[object], so the
# typevar is covariant.
_ProvidedMetadataT = TypeVar("_ProvidedMetadataT", covariant=True)
MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT]


# We can't use an ABCMeta here, because of metaclass conflicts
Expand All @@ -52,16 +54,16 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
#
# N.B. This has some typing variance problems. See `set_metadata` for an
# explanation.
_computed: MutableMapping["CSTNode", _ProvidedMetadataT]
_computed: MutableMapping["CSTNode", MaybeLazyMetadataT]

#: Implement gen_cache to indicate the matadata provider depends on cache from external
#: Implement gen_cache to indicate the metadata provider depends on cache from external
#: system. This function will be called by :class:`~libcst.metadata.FullRepoManager`
#: to compute required cache object per file path.
gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None

def __init__(self, cache: object = None) -> None:
super().__init__()
self._computed = {}
self._computed: MutableMapping["CSTNode", MaybeLazyMetadataT] = {}
if self.gen_cache and cache is None:
# The metadata provider implementation is responsible to store and use cache.
raise Exception(
Expand All @@ -71,7 +73,7 @@ def __init__(self, cache: object = None) -> None:

def _gen(
self, wrapper: "MetadataWrapper"
) -> Mapping["CSTNode", _ProvidedMetadataT]:
) -> Mapping["CSTNode", MaybeLazyMetadataT]:
"""
Resolves and returns metadata mapping for the module in ``wrapper``.
Expand All @@ -93,11 +95,7 @@ def _gen_impl(self, module: "Module") -> None:
"""
...

# pyre-ignore[46]: The covariant `value` isn't type-safe because we write it to
# pyre: `self._computed`, however we assume that only one subclass in the MRO chain
# pyre: will ever call `set_metadata`, so it's okay for our purposes. There's no
# pyre: sane way to redesign this API so that it doesn't have this problem.
def set_metadata(self, node: "CSTNode", value: _ProvidedMetadataT) -> None:
def set_metadata(self, node: "CSTNode", value: MaybeLazyMetadataT) -> None:
"""
Record a metadata value ``value`` for ``node``.
"""
Expand All @@ -107,7 +105,9 @@ def get_metadata(
self,
key: Type["BaseMetadataProvider[_MetadataT]"],
node: "CSTNode",
default: _MetadataT = _UNDEFINED_DEFAULT,
default: Union[
MaybeLazyMetadataT, Type[_UNDEFINED_DEFAULT]
] = _UNDEFINED_DEFAULT,
) -> _MetadataT:
"""
The same method as :func:`~libcst.MetadataDependent.get_metadata` except
Expand All @@ -116,9 +116,12 @@ def get_metadata(
"""
if key is type(self):
if default is not _UNDEFINED_DEFAULT:
return cast(_MetadataT, self._computed.get(node, default))
ret = self._computed.get(node, default)
else:
return cast(_MetadataT, self._computed[node])
ret = self._computed[node]
if isinstance(ret, LazyValue):
return ret()
return ret

return super().get_metadata(key, node, default)

Expand Down
16 changes: 4 additions & 12 deletions libcst/metadata/name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import dataclasses
from pathlib import Path
from typing import cast, Collection, List, Mapping, Optional, Union
from typing import Collection, List, Mapping, Optional, Union

import libcst as cst
from libcst._metadata_dependent import LazyValue, MetadataDependent
Expand All @@ -17,10 +17,8 @@
ScopeProvider,
)

_UNDEFINED_DEFAULT = object


class QualifiedNameProvider(BatchableMetadataProvider[_UNDEFINED_DEFAULT]):
class QualifiedNameProvider(BatchableMetadataProvider[Collection[QualifiedName]]):
"""
Compute possible qualified names of a variable CSTNode
(extends `PEP-3155 <https://www.python.org/dev/peps/pep-3155/>`_).
Expand Down Expand Up @@ -66,10 +64,7 @@ def has_name(
visitor: MetadataDependent, node: cst.CSTNode, name: Union[str, QualifiedName]
) -> bool:
"""Check if any of qualified name has the str name or :class:`~libcst.metadata.QualifiedName` name."""
qualified_names = cast(
Collection[QualifiedName],
visitor.get_metadata(QualifiedNameProvider, node, set()),
)
qualified_names = visitor.get_metadata(QualifiedNameProvider, node, set())
if isinstance(name, str):
return any(qn.name == name for qn in qualified_names)
else:
Expand Down Expand Up @@ -178,10 +173,7 @@ def __init__(
self.provider = provider

def on_visit(self, node: cst.CSTNode) -> bool:
qnames = cast(
Collection[QualifiedName],
self.provider.get_metadata(QualifiedNameProvider, node),
)
qnames = self.provider.get_metadata(QualifiedNameProvider, node)
if qnames is not None:
self.provider.set_metadata(
node,
Expand Down
6 changes: 2 additions & 4 deletions libcst/metadata/tests/test_name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import cast, Collection, Dict, Mapping, Optional, Set, Tuple
from typing import Collection, Dict, Mapping, Optional, Set, Tuple

import libcst as cst
from libcst import ensure_type
Expand All @@ -31,9 +31,7 @@ def __init__(self) -> None:
self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {}

def on_visit(self, node: cst.CSTNode) -> bool:
qname = cast(
Collection[QualifiedName], self.get_metadata(QualifiedNameProvider, node)
)
qname = self.get_metadata(QualifiedNameProvider, node)
self.qnames[node] = qname
return True

Expand Down

0 comments on commit 1332b1a

Please sign in to comment.