Skip to content

Commit

Permalink
Resolve attribute access on Any constants in pytd to just the constant.
Browse files Browse the repository at this point in the history
For a pyi file like:
```py
from typing import Any
dep: Any
x: dep.Thing
```
When processing local names, turn `x = pytd.Constant(NamedType("dep.Thing"))` into `x = pytd.Constant(NamedType("dep"))`. This makes it resolve later to `Any`.
PiperOrigin-RevId: 589298720
  • Loading branch information
Solumin authored and rchen152 committed Dec 14, 2023
1 parent 8687370 commit f607901
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pytype/pyi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ def post_process_ast(ast, src, name=None):

if name:
ast = ast.Replace(name=name)
ast = ast.Visit(visitors.AddNamePrefix())
ast = ast.Visit(visitors.ResolveLocalNames())
else:
# If there's no unique name, hash the sourcecode.
ast = ast.Replace(name=hashlib.md5(src.encode("utf-8")).hexdigest())
Expand Down
17 changes: 15 additions & 2 deletions pytype/pytd/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ def VisitNamedType(self, node):
return node.Replace(name=new_name)


class AddNamePrefix(Visitor):
class ResolveLocalNames(Visitor):
"""Visitor for making names fully qualified.
This will change
Expand All @@ -1252,6 +1252,10 @@ def bar(x: Foo) -> Foo
class baz.Foo:
pass
def bar(x: baz.Foo) -> baz.Foo
References to nested classes will be full resolved, e.g. if C is nested in
B is nested in A, then `x: C` becomes `x: foo.A.B.C`.
References to attributes of Any-typed constants will be resolved to Any.
"""

def __init__(self):
Expand All @@ -1266,6 +1270,8 @@ def _ClassStackString(self):

def EnterTypeDeclUnit(self, node):
self.classes = {cls.name for cls in node.classes}
self.any_constants = {const.name for const in node.constants
if const.type == pytd.AnythingType()}
self.name = node.name
self.prefix = node.name + "."

Expand All @@ -1287,10 +1293,17 @@ def VisitNamedType(self, node):
# This is an external type; do not prefix it. StripExternalNamePrefix will
# remove it later.
return node
if node.name.split(".")[0] in self.classes:
target = node.name.split(".")[0]
if target in self.classes:
# We need to check just the first part, in case we have a class constant
# like Foo.BAR, or some similarly nested name.
return node.Replace(name=self.prefix + node.name)
if target in self.any_constants:
# If we have a constant in module `foo` that's Any, i.e.
# mod: Any
# x: mod.Thing
# We resolve `mod.Thing` to Any.
return pytd.AnythingType()
if self.cls_stack:
if node.name == self.cls_stack[-1].name:
# We're referencing a class from within itself.
Expand Down
122 changes: 91 additions & 31 deletions pytype/pytd/visitors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,12 @@ def f(x: T) -> T: ...
B = A
""")
src2 = "from foo import *"
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast2 = self.Parse(src2).Replace(name="bar").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = (
self.Parse(src2).Replace(name="bar").Visit(visitors.ResolveLocalNames())
)
ast2 = ast2.Visit(visitors.LookupExternalTypes(
{"foo": ast1, "bar": ast2}, self_name="bar"))
self.assertEqual("bar", ast2.name)
Expand All @@ -252,7 +256,9 @@ def test_lookup_star_alias_in_unnamed_module(self):
class A: ...
""")
src2 = "from foo import *"
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = self.Parse(src2)
name = ast2.name
ast2 = ast2.Visit(visitors.LookupExternalTypes(
Expand All @@ -267,9 +273,15 @@ def test_lookup_two_star_aliases(self):
from foo import *
from bar import *
""")
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast2 = self.Parse(src2).Replace(name="bar").Visit(visitors.AddNamePrefix())
ast3 = self.Parse(src3).Replace(name="baz").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = (
self.Parse(src2).Replace(name="bar").Visit(visitors.ResolveLocalNames())
)
ast3 = (
self.Parse(src3).Replace(name="baz").Visit(visitors.ResolveLocalNames())
)
ast3 = ast3.Visit(visitors.LookupExternalTypes(
{"foo": ast1, "bar": ast2, "baz": ast3}, self_name="baz"))
self.assertSetEqual({a.name for a in ast3.aliases}, {"baz.A", "baz.B"})
Expand All @@ -281,9 +293,15 @@ def test_lookup_two_star_aliases_with_same_class(self):
from foo import *
from bar import *
""")
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast2 = self.Parse(src2).Replace(name="bar").Visit(visitors.AddNamePrefix())
ast3 = self.Parse(src3).Replace(name="baz").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = (
self.Parse(src2).Replace(name="bar").Visit(visitors.ResolveLocalNames())
)
ast3 = (
self.Parse(src3).Replace(name="baz").Visit(visitors.ResolveLocalNames())
)
self.assertRaises(KeyError, ast3.Visit, visitors.LookupExternalTypes(
{"foo": ast1, "bar": ast2, "baz": ast3}, self_name="baz"))

Expand All @@ -294,8 +312,12 @@ def test_lookup_star_alias_with_duplicate_class(self):
class A:
x = ... # type: int
""")
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast2 = self.Parse(src2).Replace(name="bar").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = (
self.Parse(src2).Replace(name="bar").Visit(visitors.ResolveLocalNames())
)
ast2 = ast2.Visit(visitors.LookupExternalTypes(
{"foo": ast1, "bar": ast2}, self_name="bar"))
self.assertMultiLineEqual(pytd_utils.Print(ast2), textwrap.dedent("""
Expand All @@ -310,9 +332,15 @@ def test_lookup_two_star_aliases_with_default_pyi(self):
from foo import *
from bar import *
""")
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast2 = self.Parse(src2).Replace(name="bar").Visit(visitors.AddNamePrefix())
ast3 = self.Parse(src3).Replace(name="baz").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = (
self.Parse(src2).Replace(name="bar").Visit(visitors.ResolveLocalNames())
)
ast3 = (
self.Parse(src3).Replace(name="baz").Visit(visitors.ResolveLocalNames())
)
ast3 = ast3.Visit(visitors.LookupExternalTypes(
{"foo": ast1, "bar": ast2, "baz": ast3}, self_name="baz"))
self.assertMultiLineEqual(pytd_utils.Print(ast3), textwrap.dedent("""
Expand All @@ -328,8 +356,12 @@ def test_lookup_star_alias_with_duplicate_getattr(self):
from foo import *
def __getattr__(name) -> Any: ...
""")
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast2 = self.Parse(src2).Replace(name="bar").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = (
self.Parse(src2).Replace(name="bar").Visit(visitors.ResolveLocalNames())
)
ast2 = ast2.Visit(visitors.LookupExternalTypes(
{"foo": ast1, "bar": ast2}, self_name="bar"))
self.assertMultiLineEqual(pytd_utils.Print(ast2), textwrap.dedent("""
Expand All @@ -345,9 +377,15 @@ def test_lookup_two_star_aliases_with_different_getattrs(self):
from foo import *
from bar import *
""")
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast2 = self.Parse(src2).Replace(name="bar").Visit(visitors.AddNamePrefix())
ast3 = self.Parse(src3).Replace(name="baz").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = (
self.Parse(src2).Replace(name="bar").Visit(visitors.ResolveLocalNames())
)
ast3 = (
self.Parse(src3).Replace(name="baz").Visit(visitors.ResolveLocalNames())
)
self.assertRaises(KeyError, ast3.Visit, visitors.LookupExternalTypes(
{"foo": ast1, "bar": ast2, "baz": ast3}, self_name="baz"))

Expand All @@ -357,8 +395,12 @@ def test_lookup_star_alias_with_different_getattr(self):
from foo import *
def __getattr__(name) -> str: ...
""")
ast1 = self.Parse(src1).Replace(name="foo").Visit(visitors.AddNamePrefix())
ast2 = self.Parse(src2).Replace(name="bar").Visit(visitors.AddNamePrefix())
ast1 = (
self.Parse(src1).Replace(name="foo").Visit(visitors.ResolveLocalNames())
)
ast2 = (
self.Parse(src2).Replace(name="bar").Visit(visitors.ResolveLocalNames())
)
ast2 = ast2.Visit(visitors.LookupExternalTypes(
{"foo": ast1, "bar": ast2}, self_name="bar"))
self.assertMultiLineEqual(pytd_utils.Print(ast2), textwrap.dedent("""
Expand Down Expand Up @@ -565,7 +607,7 @@ class X(Generic[T]):
self.assertIsNone(tree.Lookup("T").scope)
self.assertEqual("X",
tree.Lookup("X").template[0].type_param.scope)
tree = tree.Replace(name="foo").Visit(visitors.AddNamePrefix())
tree = tree.Replace(name="foo").Visit(visitors.ResolveLocalNames())
self.assertIsNotNone(tree.Lookup("foo.f"))
self.assertIsNotNone(tree.Lookup("foo.X"))
self.assertEqual("foo", tree.Lookup("foo.T").scope)
Expand All @@ -580,8 +622,8 @@ def test_add_name_prefix_twice(self):
class X(Generic[T]): ...
""")
tree = self.Parse(src)
tree = tree.Replace(name="foo").Visit(visitors.AddNamePrefix())
tree = tree.Replace(name="foo").Visit(visitors.AddNamePrefix())
tree = tree.Replace(name="foo").Visit(visitors.ResolveLocalNames())
tree = tree.Replace(name="foo").Visit(visitors.ResolveLocalNames())
self.assertIsNotNone(tree.Lookup("foo.foo.x"))
self.assertEqual("foo.foo", tree.Lookup("foo.foo.T").scope)
self.assertEqual("foo.foo.X",
Expand All @@ -596,7 +638,7 @@ class Y: ...
x = tree.Lookup("x")
x = x.Replace(type=pytd.ClassType("Y"))
tree = tree.Replace(constants=(x,), name="foo")
tree = tree.Visit(visitors.AddNamePrefix())
tree = tree.Visit(visitors.ResolveLocalNames())
self.assertEqual("foo.Y", tree.Lookup("foo.x").type.name)

def test_add_name_prefix_on_nested_class_alias(self):
Expand All @@ -614,8 +656,14 @@ class foo.A.B:
class foo.A.B.C: ...
D: Type[foo.A.B.C]
""").strip()
self.assertMultiLineEqual(expected, pytd_utils.Print(
self.Parse(src).Replace(name="foo").Visit(visitors.AddNamePrefix())))
self.assertMultiLineEqual(
expected,
pytd_utils.Print(
self.Parse(src)
.Replace(name="foo")
.Visit(visitors.ResolveLocalNames())
),
)

def test_add_name_prefix_on_nested_class_outside_ref(self):
src = textwrap.dedent("""
Expand Down Expand Up @@ -643,8 +691,14 @@ def f(self, x: foo.A.B) -> foo.A.B: ...
def foo.f(x: foo.A.B) -> foo.A.B: ...
""").strip()
self.assertMultiLineEqual(expected, pytd_utils.Print(
self.Parse(src).Replace(name="foo").Visit(visitors.AddNamePrefix())))
self.assertMultiLineEqual(
expected,
pytd_utils.Print(
self.Parse(src)
.Replace(name="foo")
.Visit(visitors.ResolveLocalNames())
),
)

def test_add_name_prefix_on_nested_class_method(self):
src = textwrap.dedent("""
Expand All @@ -657,8 +711,14 @@ class foo.A:
class foo.A.B:
def copy(self) -> foo.A.B: ...
""").strip()
self.assertMultiLineEqual(expected, pytd_utils.Print(
self.Parse(src).Replace(name="foo").Visit(visitors.AddNamePrefix())))
self.assertMultiLineEqual(
expected,
pytd_utils.Print(
self.Parse(src)
.Replace(name="foo")
.Visit(visitors.ResolveLocalNames())
),
)

def test_print_merge_types(self):
src = textwrap.dedent("""
Expand Down
16 changes: 16 additions & 0 deletions pytype/tests/test_import2.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,22 @@ def load() -> Component:
return Component(foos=foos)
""", module_name="loaders")

def test_import_any(self):
with self.DepTree([("foo.pyi", """
from typing import Any
dep: Any
x: dep.Thing
class A(dep.Base):
def get(self) -> dep.Got: ...
""")]):
self.Check("""
from typing import Any
import foo
assert_type(foo.dep, Any)
assert_type(foo.x, Any)
assert_type(foo.A(), foo.A)
assert_type(foo.A().get(), Any)
""")

if __name__ == "__main__":
test_base.main()

0 comments on commit f607901

Please sign in to comment.