From 115e8cdd93176b57174c6d57bac00681309bf992 Mon Sep 17 00:00:00 2001 From: Chenguang Zhu Date: Tue, 5 Jul 2022 13:12:01 -0400 Subject: [PATCH 1/9] Implement lazy loading mechanism for expensive metadata providers --- libcst/_metadata_dependent.py | 28 ++++++++++++++++++++++++++-- libcst/metadata/name_provider.py | 4 ++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/libcst/_metadata_dependent.py b/libcst/_metadata_dependent.py index 6a768270c..81cd649bd 100644 --- a/libcst/_metadata_dependent.py +++ b/libcst/_metadata_dependent.py @@ -7,6 +7,7 @@ from abc import ABC from contextlib import contextmanager from typing import ( + Callable, cast, ClassVar, Collection, @@ -31,6 +32,26 @@ _UNDEFINED_DEFAULT = object() +_SENTINEL = object() + +class LazyValue: + """ + The class for implementing a lazy metadata loading mechanism that improves the + performance when retriving expensive metadata (e.g., qualified names). Providers + including :class:`~libcst.metadata.QualifiedNameProvider` use this class to load + the metadata of a certain node lazily when calling + :func:`~libcst.MetadataDependent.get_metadata`. + """ + + def __init__(self, callable: Callable) -> None: + self.callable = callable + self.return_value = _SENTINEL + + def __call__(self) -> None: + if self.return_value is _SENTINEL: + self.return_value = self.callable() + return self.return_value + class MetadataDependent(ABC): """ @@ -107,6 +128,9 @@ def get_metadata( ) if default is not _UNDEFINED_DEFAULT: - return cast(_T, self.metadata[key].get(node, default)) + value = self.metadata[key].get(node, default) else: - return cast(_T, self.metadata[key][node]) + value = self.metadata[key][node] + if isinstance(value, LazyValue): + value = value() + return cast(_T, value) diff --git a/libcst/metadata/name_provider.py b/libcst/metadata/name_provider.py index 007535043..aa91b6cf3 100644 --- a/libcst/metadata/name_provider.py +++ b/libcst/metadata/name_provider.py @@ -8,7 +8,7 @@ from typing import Collection, List, Mapping, Optional, Union import libcst as cst -from libcst._metadata_dependent import MetadataDependent +from libcst._metadata_dependent import MetadataDependent, LazyValue from libcst.helpers.module import calculate_module_and_package, ModuleNameAndPackage from libcst.metadata.base_provider import BatchableMetadataProvider from libcst.metadata.scope_provider import ( @@ -78,7 +78,7 @@ def __init__(self, provider: "QualifiedNameProvider") -> None: def on_visit(self, node: cst.CSTNode) -> bool: scope = self.provider.get_metadata(ScopeProvider, node, None) if scope: - self.provider.set_metadata(node, scope.get_qualified_names_for(node)) + self.provider.set_metadata(node, LazyValue(lambda: scope.get_qualified_names_for(node))) else: self.provider.set_metadata(node, set()) super().on_visit(node) From 655147a63e56f9707c0ed812963232d5143acec1 Mon Sep 17 00:00:00 2001 From: Chenguang Zhu Date: Tue, 5 Jul 2022 20:22:01 -0400 Subject: [PATCH 2/9] Fix tests in test_name_providers.py --- libcst/_metadata_dependent.py | 1 + libcst/metadata/name_provider.py | 6 ++++-- libcst/metadata/tests/test_name_provider.py | 18 ++++++++++++++++-- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/libcst/_metadata_dependent.py b/libcst/_metadata_dependent.py index 81cd649bd..8137c55be 100644 --- a/libcst/_metadata_dependent.py +++ b/libcst/_metadata_dependent.py @@ -34,6 +34,7 @@ _SENTINEL = object() + class LazyValue: """ The class for implementing a lazy metadata loading mechanism that improves the diff --git a/libcst/metadata/name_provider.py b/libcst/metadata/name_provider.py index aa91b6cf3..60d8763e7 100644 --- a/libcst/metadata/name_provider.py +++ b/libcst/metadata/name_provider.py @@ -8,7 +8,7 @@ from typing import Collection, List, Mapping, Optional, Union import libcst as cst -from libcst._metadata_dependent import MetadataDependent, LazyValue +from libcst._metadata_dependent import LazyValue, MetadataDependent from libcst.helpers.module import calculate_module_and_package, ModuleNameAndPackage from libcst.metadata.base_provider import BatchableMetadataProvider from libcst.metadata.scope_provider import ( @@ -78,7 +78,9 @@ def __init__(self, provider: "QualifiedNameProvider") -> None: def on_visit(self, node: cst.CSTNode) -> bool: scope = self.provider.get_metadata(ScopeProvider, node, None) if scope: - self.provider.set_metadata(node, LazyValue(lambda: scope.get_qualified_names_for(node))) + self.provider.set_metadata( + node, LazyValue(lambda: scope.get_qualified_names_for(node)) + ) else: self.provider.set_metadata(node, set()) super().on_visit(node) diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index 9b0b409fc..d7f4b4a98 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -21,12 +21,26 @@ from libcst.metadata.name_provider import FullyQualifiedNameVisitor from libcst.testing.utils import data_provider, UnitTest +class QNameVisitor(cst.CSTVisitor): + + METADATA_DEPENDENCIES = (QualifiedNameProvider,) + + def __init__(self): + self.qnames = {} + + def on_visit(self, node: cst.CSTNode) -> bool: + qname = self.get_metadata(QualifiedNameProvider, node) + self.qnames[node] = qname + return True + def get_qualified_name_metadata_provider( module_str: str, ) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]: wrapper = MetadataWrapper(cst.parse_module(dedent(module_str))) - return wrapper.module, wrapper.resolve(QualifiedNameProvider) + visitor = QNameVisitor() + wrapper.visit(visitor) + return wrapper.module, visitor.qnames def get_qualified_names(module_str: str) -> Set[QualifiedName]: @@ -358,7 +372,7 @@ def f(): pass else: import f import a.b as f - + f() """ ) From dbbeb39dfa99917fba68b0c8cb105d73aaac91f1 Mon Sep 17 00:00:00 2001 From: Chenguang Zhu Date: Tue, 5 Jul 2022 23:53:28 -0400 Subject: [PATCH 3/9] Fix type check errors --- libcst/_metadata_dependent.py | 6 +++--- libcst/metadata/tests/test_name_provider.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/libcst/_metadata_dependent.py b/libcst/_metadata_dependent.py index 8137c55be..f73925782 100644 --- a/libcst/_metadata_dependent.py +++ b/libcst/_metadata_dependent.py @@ -44,11 +44,11 @@ class LazyValue: :func:`~libcst.MetadataDependent.get_metadata`. """ - def __init__(self, callable: Callable) -> None: + def __init__(self, callable: Callable[[], _T]) -> None: self.callable = callable - self.return_value = _SENTINEL + self.return_value: object = _SENTINEL - def __call__(self) -> None: + def __call__(self) -> object: if self.return_value is _SENTINEL: self.return_value = self.callable() return self.return_value diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index d7f4b4a98..ac76afe4c 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -19,14 +19,15 @@ ) from libcst.metadata.full_repo_manager import FullRepoManager from libcst.metadata.name_provider import FullyQualifiedNameVisitor +from libcst._nodes.base import CSTNode from libcst.testing.utils import data_provider, UnitTest class QNameVisitor(cst.CSTVisitor): METADATA_DEPENDENCIES = (QualifiedNameProvider,) - def __init__(self): - self.qnames = {} + def __init__(self) -> None: + self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {} def on_visit(self, node: cst.CSTNode) -> bool: qname = self.get_metadata(QualifiedNameProvider, node) From 7046c03346f04dbb81800a2080fd251f17b43d50 Mon Sep 17 00:00:00 2001 From: Chenguang Zhu Date: Wed, 6 Jul 2022 13:42:31 -0400 Subject: [PATCH 4/9] Add support for lazy values in metadata matchers --- libcst/matchers/_matcher_base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/libcst/matchers/_matcher_base.py b/libcst/matchers/_matcher_base.py index 64670be42..8fd45f662 100644 --- a/libcst/matchers/_matcher_base.py +++ b/libcst/matchers/_matcher_base.py @@ -31,7 +31,7 @@ import libcst import libcst.metadata as meta from libcst import FlattenSentinel, MaybeSentinel, RemovalSentinel - +from libcst._metadata_dependent import LazyValue class DoNotCareSentinel(Enum): """ @@ -1544,7 +1544,11 @@ def _fetch(provider: meta.ProviderT, node: libcst.CSTNode) -> object: if provider not in metadata: metadata[provider] = wrapper.resolve(provider) - return metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL) + node_metadata = metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL) + if isinstance(node_metadata, LazyValue): + node_metadata = node_metadata() + + return node_metadata return _fetch From 6d871effa26b11a7811bd46e7aa33b81848b217b Mon Sep 17 00:00:00 2001 From: Chenguang Zhu Date: Wed, 6 Jul 2022 14:15:10 -0400 Subject: [PATCH 5/9] Fix linting errors --- libcst/matchers/_matcher_base.py | 1 + libcst/metadata/tests/test_name_provider.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/libcst/matchers/_matcher_base.py b/libcst/matchers/_matcher_base.py index 8fd45f662..d8f69ec63 100644 --- a/libcst/matchers/_matcher_base.py +++ b/libcst/matchers/_matcher_base.py @@ -33,6 +33,7 @@ from libcst import FlattenSentinel, MaybeSentinel, RemovalSentinel from libcst._metadata_dependent import LazyValue + class DoNotCareSentinel(Enum): """ A sentinel that is used in matcher classes to indicate that a caller diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index ac76afe4c..9f3813687 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -10,6 +10,7 @@ import libcst as cst from libcst import ensure_type +from libcst._nodes.base import CSTNode from libcst.metadata import ( FullyQualifiedNameProvider, MetadataWrapper, @@ -19,9 +20,9 @@ ) from libcst.metadata.full_repo_manager import FullRepoManager from libcst.metadata.name_provider import FullyQualifiedNameVisitor -from libcst._nodes.base import CSTNode from libcst.testing.utils import data_provider, UnitTest + class QNameVisitor(cst.CSTVisitor): METADATA_DEPENDENCIES = (QualifiedNameProvider,) From 87ac7f5b75df5ea645b379fdd5656dff78f22e16 Mon Sep 17 00:00:00 2001 From: Chenguang Zhu Date: Wed, 6 Jul 2022 16:37:54 -0400 Subject: [PATCH 6/9] Fix type errors --- .../codemod/visitors/_apply_type_annotations.py | 11 ++++++++--- .../visitors/_gather_string_annotation_names.py | 6 ++++-- libcst/metadata/name_provider.py | 16 ++++++++++++---- libcst/metadata/tests/test_name_provider.py | 6 ++++-- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index 7811aa631..f821965e2 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -5,7 +5,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import cast, Collection, Dict, List, Optional, Sequence, Set, Tuple, Union import libcst as cst import libcst.matchers as m @@ -17,7 +17,7 @@ from libcst.codemod.visitors._gather_imports import GatherImportsVisitor from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_full_name_for_node -from libcst.metadata import PositionProvider, QualifiedNameProvider +from libcst.metadata import PositionProvider, QualifiedName, QualifiedNameProvider NameOrAttribute = Union[cst.Name, cst.Attribute] @@ -48,7 +48,12 @@ def _get_unique_qualified_name( visitor: m.MatcherDecoratableVisitor, node: cst.CSTNode ) -> str: name = None - names = [q.name for q in visitor.get_metadata(QualifiedNameProvider, node)] + names = [ + q.name + for q in cast( + Collection[QualifiedName], visitor.get_metadata(QualifiedNameProvider, node) + ) + ] if len(names) == 0: # we hit this branch if the stub is directly using a fully # qualified name, which is not technically valid python but is diff --git a/libcst/codemod/visitors/_gather_string_annotation_names.py b/libcst/codemod/visitors/_gather_string_annotation_names.py index 0f1b926b3..02392a938 100644 --- a/libcst/codemod/visitors/_gather_string_annotation_names.py +++ b/libcst/codemod/visitors/_gather_string_annotation_names.py @@ -9,7 +9,7 @@ import libcst.matchers as m from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareVisitor -from libcst.metadata import MetadataWrapper, QualifiedNameProvider +from libcst.metadata import MetadataWrapper, QualifiedName, QualifiedNameProvider FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS = {"typing.TypeVar"} @@ -45,7 +45,9 @@ def leave_Annotation(self, original_node: cst.Annotation) -> None: self._annotation_stack.pop() def visit_Call(self, node: cst.Call) -> bool: - qnames = self.get_metadata(QualifiedNameProvider, node) + qnames = cast( + Collection[QualifiedName], self.get_metadata(QualifiedNameProvider, node) + ) if any(qn.name in self._typing_functions for qn in qnames): self._annotation_stack.append(node) return True diff --git a/libcst/metadata/name_provider.py b/libcst/metadata/name_provider.py index 60d8763e7..571ace833 100644 --- a/libcst/metadata/name_provider.py +++ b/libcst/metadata/name_provider.py @@ -5,7 +5,7 @@ import dataclasses from pathlib import Path -from typing import Collection, List, Mapping, Optional, Union +from typing import cast, Collection, List, Mapping, Optional, Union import libcst as cst from libcst._metadata_dependent import LazyValue, MetadataDependent @@ -17,8 +17,10 @@ ScopeProvider, ) +_UNDEFINED_DEFAULT = object -class QualifiedNameProvider(BatchableMetadataProvider[Collection[QualifiedName]]): + +class QualifiedNameProvider(BatchableMetadataProvider[_UNDEFINED_DEFAULT]): """ Compute possible qualified names of a variable CSTNode (extends `PEP-3155 `_). @@ -64,7 +66,10 @@ def has_name( visitor: MetadataDependent, node: cst.CSTNode, name: Union[str, QualifiedName] ) -> bool: """Check if any of qualified name has the str name or :class:`~libcst.metadata.QualifiedName` name.""" - qualified_names = visitor.get_metadata(QualifiedNameProvider, node, set()) + qualified_names = cast( + Collection[QualifiedName], + visitor.get_metadata(QualifiedNameProvider, node, set()), + ) if isinstance(name, str): return any(qn.name == name for qn in qualified_names) else: @@ -173,7 +178,10 @@ def __init__( self.provider = provider def on_visit(self, node: cst.CSTNode) -> bool: - qnames = self.provider.get_metadata(QualifiedNameProvider, node) + qnames = cast( + Collection[QualifiedName], + self.provider.get_metadata(QualifiedNameProvider, node), + ) if qnames is not None: self.provider.set_metadata( node, diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index 9f3813687..9ccbcee38 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -6,7 +6,7 @@ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent -from typing import Collection, Dict, Mapping, Optional, Set, Tuple +from typing import cast, Collection, Dict, Mapping, Optional, Set, Tuple import libcst as cst from libcst import ensure_type @@ -31,7 +31,9 @@ def __init__(self) -> None: self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {} def on_visit(self, node: cst.CSTNode) -> bool: - qname = self.get_metadata(QualifiedNameProvider, node) + qname = cast( + Collection[QualifiedName], self.get_metadata(QualifiedNameProvider, node) + ) self.qnames[node] = qname return True From 5553a3bb10d0dca9bd0e44366cfffa1259508903 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Fri, 8 Jul 2022 12:35:27 +0100 Subject: [PATCH 7/9] Fix type issues and implement lazy value support in base metadata provider too --- libcst/_metadata_dependent.py | 16 +++++----- .../visitors/_apply_type_annotations.py | 11 ++----- .../_gather_string_annotation_names.py | 6 ++-- libcst/metadata/base_provider.py | 29 ++++++++++--------- libcst/metadata/name_provider.py | 16 +++------- libcst/metadata/tests/test_name_provider.py | 6 ++-- 6 files changed, 36 insertions(+), 48 deletions(-) diff --git a/libcst/_metadata_dependent.py b/libcst/_metadata_dependent.py index f73925782..4faf74727 100644 --- a/libcst/_metadata_dependent.py +++ b/libcst/_metadata_dependent.py @@ -11,11 +11,13 @@ cast, ClassVar, Collection, + Generic, Iterator, Mapping, Type, TYPE_CHECKING, TypeVar, + Union, ) if TYPE_CHECKING: @@ -30,12 +32,12 @@ _T = TypeVar("_T") -_UNDEFINED_DEFAULT = object() -_SENTINEL = object() +class _UNDEFINED_DEFAULT: + pass -class LazyValue: +class LazyValue(Generic[_T]): """ The class for implementing a lazy metadata loading mechanism that improves the performance when retriving expensive metadata (e.g., qualified names). Providers @@ -46,12 +48,12 @@ class LazyValue: def __init__(self, callable: Callable[[], _T]) -> None: self.callable = callable - self.return_value: object = _SENTINEL + self.return_value: Union[_T, Type[_UNDEFINED_DEFAULT]] = _UNDEFINED_DEFAULT - def __call__(self) -> object: - if self.return_value is _SENTINEL: + def __call__(self) -> _T: + if self.return_value is _UNDEFINED_DEFAULT: self.return_value = self.callable() - return self.return_value + return cast(_T, self.return_value) class MetadataDependent(ABC): diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index f821965e2..7811aa631 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -5,7 +5,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import cast, Collection, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union import libcst as cst import libcst.matchers as m @@ -17,7 +17,7 @@ from libcst.codemod.visitors._gather_imports import GatherImportsVisitor from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_full_name_for_node -from libcst.metadata import PositionProvider, QualifiedName, QualifiedNameProvider +from libcst.metadata import PositionProvider, QualifiedNameProvider NameOrAttribute = Union[cst.Name, cst.Attribute] @@ -48,12 +48,7 @@ def _get_unique_qualified_name( visitor: m.MatcherDecoratableVisitor, node: cst.CSTNode ) -> str: name = None - names = [ - q.name - for q in cast( - Collection[QualifiedName], visitor.get_metadata(QualifiedNameProvider, node) - ) - ] + names = [q.name for q in visitor.get_metadata(QualifiedNameProvider, node)] if len(names) == 0: # we hit this branch if the stub is directly using a fully # qualified name, which is not technically valid python but is diff --git a/libcst/codemod/visitors/_gather_string_annotation_names.py b/libcst/codemod/visitors/_gather_string_annotation_names.py index 02392a938..0f1b926b3 100644 --- a/libcst/codemod/visitors/_gather_string_annotation_names.py +++ b/libcst/codemod/visitors/_gather_string_annotation_names.py @@ -9,7 +9,7 @@ import libcst.matchers as m from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareVisitor -from libcst.metadata import MetadataWrapper, QualifiedName, QualifiedNameProvider +from libcst.metadata import MetadataWrapper, QualifiedNameProvider FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS = {"typing.TypeVar"} @@ -45,9 +45,7 @@ def leave_Annotation(self, original_node: cst.Annotation) -> None: self._annotation_stack.pop() def visit_Call(self, node: cst.Call) -> bool: - qnames = cast( - Collection[QualifiedName], self.get_metadata(QualifiedNameProvider, node) - ) + qnames = self.get_metadata(QualifiedNameProvider, node) if any(qn.name in self._typing_functions for qn in qnames): self._annotation_stack.append(node) return True diff --git a/libcst/metadata/base_provider.py b/libcst/metadata/base_provider.py index 69af2dcea..1c113f57a 100644 --- a/libcst/metadata/base_provider.py +++ b/libcst/metadata/base_provider.py @@ -7,7 +7,6 @@ from types import MappingProxyType from typing import ( Callable, - cast, Generic, List, Mapping, @@ -16,12 +15,14 @@ Type, TYPE_CHECKING, TypeVar, + Union, ) from libcst._batched_visitor import BatchableCSTVisitor from libcst._metadata_dependent import ( _T as _MetadataT, _UNDEFINED_DEFAULT, + LazyValue, MetadataDependent, ) from libcst._visitors import CSTVisitor @@ -36,6 +37,7 @@ # BaseMetadataProvider[int] would be a subtype of BaseMetadataProvider[object], so the # typevar is covariant. _ProvidedMetadataT = TypeVar("_ProvidedMetadataT", covariant=True) +MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT] # We can't use an ABCMeta here, because of metaclass conflicts @@ -52,16 +54,16 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]): # # N.B. This has some typing variance problems. See `set_metadata` for an # explanation. - _computed: MutableMapping["CSTNode", _ProvidedMetadataT] + _computed: MutableMapping["CSTNode", MaybeLazyMetadataT] - #: Implement gen_cache to indicate the matadata provider depends on cache from external + #: Implement gen_cache to indicate the metadata provider depends on cache from external #: system. This function will be called by :class:`~libcst.metadata.FullRepoManager` #: to compute required cache object per file path. gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None def __init__(self, cache: object = None) -> None: super().__init__() - self._computed = {} + self._computed: MutableMapping["CSTNode", MaybeLazyMetadataT] = {} if self.gen_cache and cache is None: # The metadata provider implementation is responsible to store and use cache. raise Exception( @@ -71,7 +73,7 @@ def __init__(self, cache: object = None) -> None: def _gen( self, wrapper: "MetadataWrapper" - ) -> Mapping["CSTNode", _ProvidedMetadataT]: + ) -> Mapping["CSTNode", MaybeLazyMetadataT]: """ Resolves and returns metadata mapping for the module in ``wrapper``. @@ -93,11 +95,7 @@ def _gen_impl(self, module: "Module") -> None: """ ... - # pyre-ignore[46]: The covariant `value` isn't type-safe because we write it to - # pyre: `self._computed`, however we assume that only one subclass in the MRO chain - # pyre: will ever call `set_metadata`, so it's okay for our purposes. There's no - # pyre: sane way to redesign this API so that it doesn't have this problem. - def set_metadata(self, node: "CSTNode", value: _ProvidedMetadataT) -> None: + def set_metadata(self, node: "CSTNode", value: MaybeLazyMetadataT) -> None: """ Record a metadata value ``value`` for ``node``. """ @@ -107,7 +105,9 @@ def get_metadata( self, key: Type["BaseMetadataProvider[_MetadataT]"], node: "CSTNode", - default: _MetadataT = _UNDEFINED_DEFAULT, + default: Union[ + MaybeLazyMetadataT, Type[_UNDEFINED_DEFAULT] + ] = _UNDEFINED_DEFAULT, ) -> _MetadataT: """ The same method as :func:`~libcst.MetadataDependent.get_metadata` except @@ -116,9 +116,12 @@ def get_metadata( """ if key is type(self): if default is not _UNDEFINED_DEFAULT: - return cast(_MetadataT, self._computed.get(node, default)) + ret = self._computed.get(node, default) else: - return cast(_MetadataT, self._computed[node]) + ret = self._computed[node] + if isinstance(ret, LazyValue): + return ret() + return ret return super().get_metadata(key, node, default) diff --git a/libcst/metadata/name_provider.py b/libcst/metadata/name_provider.py index 571ace833..60d8763e7 100644 --- a/libcst/metadata/name_provider.py +++ b/libcst/metadata/name_provider.py @@ -5,7 +5,7 @@ import dataclasses from pathlib import Path -from typing import cast, Collection, List, Mapping, Optional, Union +from typing import Collection, List, Mapping, Optional, Union import libcst as cst from libcst._metadata_dependent import LazyValue, MetadataDependent @@ -17,10 +17,8 @@ ScopeProvider, ) -_UNDEFINED_DEFAULT = object - -class QualifiedNameProvider(BatchableMetadataProvider[_UNDEFINED_DEFAULT]): +class QualifiedNameProvider(BatchableMetadataProvider[Collection[QualifiedName]]): """ Compute possible qualified names of a variable CSTNode (extends `PEP-3155 `_). @@ -66,10 +64,7 @@ def has_name( visitor: MetadataDependent, node: cst.CSTNode, name: Union[str, QualifiedName] ) -> bool: """Check if any of qualified name has the str name or :class:`~libcst.metadata.QualifiedName` name.""" - qualified_names = cast( - Collection[QualifiedName], - visitor.get_metadata(QualifiedNameProvider, node, set()), - ) + qualified_names = visitor.get_metadata(QualifiedNameProvider, node, set()) if isinstance(name, str): return any(qn.name == name for qn in qualified_names) else: @@ -178,10 +173,7 @@ def __init__( self.provider = provider def on_visit(self, node: cst.CSTNode) -> bool: - qnames = cast( - Collection[QualifiedName], - self.provider.get_metadata(QualifiedNameProvider, node), - ) + qnames = self.provider.get_metadata(QualifiedNameProvider, node) if qnames is not None: self.provider.set_metadata( node, diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index 9ccbcee38..9f3813687 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -6,7 +6,7 @@ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent -from typing import cast, Collection, Dict, Mapping, Optional, Set, Tuple +from typing import Collection, Dict, Mapping, Optional, Set, Tuple import libcst as cst from libcst import ensure_type @@ -31,9 +31,7 @@ def __init__(self) -> None: self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {} def on_visit(self, node: cst.CSTNode) -> bool: - qname = cast( - Collection[QualifiedName], self.get_metadata(QualifiedNameProvider, node) - ) + qname = self.get_metadata(QualifiedNameProvider, node) self.qnames[node] = qname return True From f6fb3c18a83c26cc528296c73bc57685ecacd9cb Mon Sep 17 00:00:00 2001 From: Chenguang Zhu Date: Fri, 8 Jul 2022 15:51:39 -0400 Subject: [PATCH 8/9] Add unit tests for BaseMetadataProvider --- libcst/metadata/tests/test_base_provider.py | 61 +++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/libcst/metadata/tests/test_base_provider.py b/libcst/metadata/tests/test_base_provider.py index 0bf4ca512..6e20edf58 100644 --- a/libcst/metadata/tests/test_base_provider.py +++ b/libcst/metadata/tests/test_base_provider.py @@ -13,6 +13,7 @@ VisitorMetadataProvider, ) from libcst.metadata.wrapper import _gen_batchable +from libcst._metadata_dependent import LazyValue from libcst.testing.utils import UnitTest @@ -75,3 +76,63 @@ def visit_Return(self, node: cst.Return) -> None: self.assertEqual(metadata[SimpleProvider][pass_], 1) self.assertEqual(metadata[SimpleProvider][return_], 2) self.assertEqual(metadata[SimpleProvider][pass_2], 1) + + def test_lazy_visitor_provider(self) -> None: + class SimpleLazyProvider(VisitorMetadataProvider[int]): + """ + Sets metadata on every node to a callable that returns 1. + """ + + def on_visit(self, node: cst.CSTNode) -> bool: + self.set_metadata(node, LazyValue(lambda: 1)) + return True + + wrapper = MetadataWrapper(parse_module("pass; return")) + module = wrapper.module + pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0] + return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1] + + provider = SimpleLazyProvider() + metadata = provider._gen(wrapper) + + # Check access on provider + self.assertEqual(provider.get_metadata(SimpleLazyProvider, module), 1) + self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1) + self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 1) + + # Check returned mapping + self.assertTrue(isinstance(metadata[module], LazyValue)) + self.assertTrue(isinstance(metadata[pass_], LazyValue)) + self.assertTrue(isinstance(metadata[return_], LazyValue)) + + def testlazy_batchable_provider(self) -> None: + class SimpleLazyProvider(BatchableMetadataProvider[int]): + """ + Sets metadata on every pass node to a callable that returns 1, + and every return node to a callable that returns 2. + """ + + def visit_Pass(self, node: cst.Pass) -> None: + self.set_metadata(node, LazyValue(lambda: 1)) + + def visit_Return(self, node: cst.Return) -> None: + self.set_metadata(node, LazyValue(lambda: 2)) + + wrapper = MetadataWrapper(parse_module("pass; return; pass")) + module = wrapper.module + pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0] + return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1] + pass_2 = cast(cst.SimpleStatementLine, module.body[0]).body[2] + + provider = SimpleLazyProvider() + metadata = _gen_batchable(wrapper, [provider]) + + # Check access on provider + self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1) + self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 2) + self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_2), 1) + + # Check returned mapping + self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_], LazyValue)) + self.assertTrue(isinstance(metadata[SimpleLazyProvider][return_], LazyValue)) + self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_2], LazyValue)) From 3f33af314f903c5fc1e8b44181ff9c88ddb268e9 Mon Sep 17 00:00:00 2001 From: Chenguang Zhu Date: Fri, 8 Jul 2022 15:56:22 -0400 Subject: [PATCH 9/9] Fix linting errors --- libcst/metadata/tests/test_base_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libcst/metadata/tests/test_base_provider.py b/libcst/metadata/tests/test_base_provider.py index 6e20edf58..26ebde701 100644 --- a/libcst/metadata/tests/test_base_provider.py +++ b/libcst/metadata/tests/test_base_provider.py @@ -7,13 +7,13 @@ import libcst as cst from libcst import parse_module +from libcst._metadata_dependent import LazyValue from libcst.metadata import ( BatchableMetadataProvider, MetadataWrapper, VisitorMetadataProvider, ) from libcst.metadata.wrapper import _gen_batchable -from libcst._metadata_dependent import LazyValue from libcst.testing.utils import UnitTest