Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement lazy loading mechanism for expensive metadata providers #720

Merged
merged 9 commits into from
Jul 9, 2022
33 changes: 30 additions & 3 deletions libcst/_metadata_dependent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from abc import ABC
from contextlib import contextmanager
from typing import (
Callable,
cast,
ClassVar,
Collection,
Generic,
Iterator,
Mapping,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)

if TYPE_CHECKING:
Expand All @@ -29,7 +32,28 @@

_T = TypeVar("_T")

_UNDEFINED_DEFAULT = object()

class _UNDEFINED_DEFAULT:
pass


class LazyValue(Generic[_T]):
"""
The class for implementing a lazy metadata loading mechanism that improves the
performance when retriving expensive metadata (e.g., qualified names). Providers
including :class:`~libcst.metadata.QualifiedNameProvider` use this class to load
the metadata of a certain node lazily when calling
:func:`~libcst.MetadataDependent.get_metadata`.
"""

def __init__(self, callable: Callable[[], _T]) -> None:
self.callable = callable
self.return_value: Union[_T, Type[_UNDEFINED_DEFAULT]] = _UNDEFINED_DEFAULT

def __call__(self) -> _T:
if self.return_value is _UNDEFINED_DEFAULT:
self.return_value = self.callable()
return cast(_T, self.return_value)


class MetadataDependent(ABC):
Expand Down Expand Up @@ -107,6 +131,9 @@ def get_metadata(
)

if default is not _UNDEFINED_DEFAULT:
return cast(_T, self.metadata[key].get(node, default))
value = self.metadata[key].get(node, default)
else:
return cast(_T, self.metadata[key][node])
value = self.metadata[key][node]
if isinstance(value, LazyValue):
value = value()
return cast(_T, value)
7 changes: 6 additions & 1 deletion libcst/matchers/_matcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import libcst
import libcst.metadata as meta
from libcst import FlattenSentinel, MaybeSentinel, RemovalSentinel
from libcst._metadata_dependent import LazyValue


class DoNotCareSentinel(Enum):
Expand Down Expand Up @@ -1544,7 +1545,11 @@ def _fetch(provider: meta.ProviderT, node: libcst.CSTNode) -> object:
if provider not in metadata:
metadata[provider] = wrapper.resolve(provider)

return metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL)
node_metadata = metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL)
if isinstance(node_metadata, LazyValue):
node_metadata = node_metadata()

return node_metadata

return _fetch

Expand Down
29 changes: 16 additions & 13 deletions libcst/metadata/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from types import MappingProxyType
from typing import (
Callable,
cast,
Generic,
List,
Mapping,
Expand All @@ -16,12 +15,14 @@
Type,
TYPE_CHECKING,
TypeVar,
Union,
)

from libcst._batched_visitor import BatchableCSTVisitor
from libcst._metadata_dependent import (
_T as _MetadataT,
_UNDEFINED_DEFAULT,
LazyValue,
MetadataDependent,
)
from libcst._visitors import CSTVisitor
Expand All @@ -36,6 +37,7 @@
# BaseMetadataProvider[int] would be a subtype of BaseMetadataProvider[object], so the
# typevar is covariant.
_ProvidedMetadataT = TypeVar("_ProvidedMetadataT", covariant=True)
MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT]


# We can't use an ABCMeta here, because of metaclass conflicts
Expand All @@ -52,16 +54,16 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
#
# N.B. This has some typing variance problems. See `set_metadata` for an
# explanation.
_computed: MutableMapping["CSTNode", _ProvidedMetadataT]
_computed: MutableMapping["CSTNode", MaybeLazyMetadataT]

#: Implement gen_cache to indicate the matadata provider depends on cache from external
#: Implement gen_cache to indicate the metadata provider depends on cache from external
#: system. This function will be called by :class:`~libcst.metadata.FullRepoManager`
#: to compute required cache object per file path.
gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None

def __init__(self, cache: object = None) -> None:
super().__init__()
self._computed = {}
self._computed: MutableMapping["CSTNode", MaybeLazyMetadataT] = {}
if self.gen_cache and cache is None:
# The metadata provider implementation is responsible to store and use cache.
raise Exception(
Expand All @@ -71,7 +73,7 @@ def __init__(self, cache: object = None) -> None:

def _gen(
self, wrapper: "MetadataWrapper"
) -> Mapping["CSTNode", _ProvidedMetadataT]:
) -> Mapping["CSTNode", MaybeLazyMetadataT]:
"""
Resolves and returns metadata mapping for the module in ``wrapper``.

Expand All @@ -93,11 +95,7 @@ def _gen_impl(self, module: "Module") -> None:
"""
...

# pyre-ignore[46]: The covariant `value` isn't type-safe because we write it to
# pyre: `self._computed`, however we assume that only one subclass in the MRO chain
# pyre: will ever call `set_metadata`, so it's okay for our purposes. There's no
# pyre: sane way to redesign this API so that it doesn't have this problem.
def set_metadata(self, node: "CSTNode", value: _ProvidedMetadataT) -> None:
def set_metadata(self, node: "CSTNode", value: MaybeLazyMetadataT) -> None:
"""
Record a metadata value ``value`` for ``node``.
"""
Expand All @@ -107,7 +105,9 @@ def get_metadata(
self,
key: Type["BaseMetadataProvider[_MetadataT]"],
node: "CSTNode",
default: _MetadataT = _UNDEFINED_DEFAULT,
default: Union[
MaybeLazyMetadataT, Type[_UNDEFINED_DEFAULT]
] = _UNDEFINED_DEFAULT,
) -> _MetadataT:
"""
The same method as :func:`~libcst.MetadataDependent.get_metadata` except
Expand All @@ -116,9 +116,12 @@ def get_metadata(
"""
if key is type(self):
if default is not _UNDEFINED_DEFAULT:
return cast(_MetadataT, self._computed.get(node, default))
ret = self._computed.get(node, default)
else:
return cast(_MetadataT, self._computed[node])
ret = self._computed[node]
if isinstance(ret, LazyValue):
return ret()
return ret

return super().get_metadata(key, node, default)

Expand Down
6 changes: 4 additions & 2 deletions libcst/metadata/name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Collection, List, Mapping, Optional, Union

import libcst as cst
from libcst._metadata_dependent import MetadataDependent
from libcst._metadata_dependent import LazyValue, MetadataDependent
from libcst.helpers.module import calculate_module_and_package, ModuleNameAndPackage
from libcst.metadata.base_provider import BatchableMetadataProvider
from libcst.metadata.scope_provider import (
Expand Down Expand Up @@ -78,7 +78,9 @@ def __init__(self, provider: "QualifiedNameProvider") -> None:
def on_visit(self, node: cst.CSTNode) -> bool:
scope = self.provider.get_metadata(ScopeProvider, node, None)
if scope:
self.provider.set_metadata(node, scope.get_qualified_names_for(node))
self.provider.set_metadata(
node, LazyValue(lambda: scope.get_qualified_names_for(node))
)
else:
self.provider.set_metadata(node, set())
super().on_visit(node)
Expand Down
61 changes: 61 additions & 0 deletions libcst/metadata/tests/test_base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import libcst as cst
from libcst import parse_module
from libcst._metadata_dependent import LazyValue
from libcst.metadata import (
BatchableMetadataProvider,
MetadataWrapper,
Expand Down Expand Up @@ -75,3 +76,63 @@ def visit_Return(self, node: cst.Return) -> None:
self.assertEqual(metadata[SimpleProvider][pass_], 1)
self.assertEqual(metadata[SimpleProvider][return_], 2)
self.assertEqual(metadata[SimpleProvider][pass_2], 1)

def test_lazy_visitor_provider(self) -> None:
class SimpleLazyProvider(VisitorMetadataProvider[int]):
"""
Sets metadata on every node to a callable that returns 1.
"""

def on_visit(self, node: cst.CSTNode) -> bool:
self.set_metadata(node, LazyValue(lambda: 1))
return True

wrapper = MetadataWrapper(parse_module("pass; return"))
module = wrapper.module
pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0]
return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1]

provider = SimpleLazyProvider()
metadata = provider._gen(wrapper)

# Check access on provider
self.assertEqual(provider.get_metadata(SimpleLazyProvider, module), 1)
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1)
self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 1)

# Check returned mapping
self.assertTrue(isinstance(metadata[module], LazyValue))
self.assertTrue(isinstance(metadata[pass_], LazyValue))
self.assertTrue(isinstance(metadata[return_], LazyValue))

def testlazy_batchable_provider(self) -> None:
class SimpleLazyProvider(BatchableMetadataProvider[int]):
"""
Sets metadata on every pass node to a callable that returns 1,
and every return node to a callable that returns 2.
"""

def visit_Pass(self, node: cst.Pass) -> None:
self.set_metadata(node, LazyValue(lambda: 1))

def visit_Return(self, node: cst.Return) -> None:
self.set_metadata(node, LazyValue(lambda: 2))

wrapper = MetadataWrapper(parse_module("pass; return; pass"))
module = wrapper.module
pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0]
return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1]
pass_2 = cast(cst.SimpleStatementLine, module.body[0]).body[2]

provider = SimpleLazyProvider()
metadata = _gen_batchable(wrapper, [provider])

# Check access on provider
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1)
self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 2)
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_2), 1)

# Check returned mapping
self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_], LazyValue))
self.assertTrue(isinstance(metadata[SimpleLazyProvider][return_], LazyValue))
self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_2], LazyValue))
20 changes: 18 additions & 2 deletions libcst/metadata/tests/test_name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import libcst as cst
from libcst import ensure_type
from libcst._nodes.base import CSTNode
from libcst.metadata import (
FullyQualifiedNameProvider,
MetadataWrapper,
Expand All @@ -22,11 +23,26 @@
from libcst.testing.utils import data_provider, UnitTest


class QNameVisitor(cst.CSTVisitor):

METADATA_DEPENDENCIES = (QualifiedNameProvider,)

def __init__(self) -> None:
self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {}

def on_visit(self, node: cst.CSTNode) -> bool:
qname = self.get_metadata(QualifiedNameProvider, node)
self.qnames[node] = qname
return True


def get_qualified_name_metadata_provider(
module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
return wrapper.module, wrapper.resolve(QualifiedNameProvider)
visitor = QNameVisitor()
wrapper.visit(visitor)
return wrapper.module, visitor.qnames


def get_qualified_names(module_str: str) -> Set[QualifiedName]:
Expand Down Expand Up @@ -358,7 +374,7 @@ def f(): pass
else:
import f
import a.b as f

f()
"""
)
Expand Down