diff --git a/mypy/test/testtransform.py b/mypy/test/testtransform.py index 803f2dcd4035..d884fe9137ab 100644 --- a/mypy/test/testtransform.py +++ b/mypy/test/testtransform.py @@ -60,6 +60,7 @@ def test_transform(testcase: DataDrivenTestCase) -> None: and not os.path.splitext( os.path.basename(f.path))[0].endswith('_')): t = TypeAssertTransformVisitor() + t.test_only = True f = t.mypyfile(f) a += str(f).split('\n') except CompileError as e: diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 4191569995b0..bd8a623455f7 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -20,7 +20,7 @@ YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, AssignmentExpr, - OverloadPart, EnumCallExpr, REVEAL_TYPE + OverloadPart, EnumCallExpr, REVEAL_TYPE, GDEF ) from mypy.types import Type, FunctionLike, ProperType from mypy.traverser import TraverserVisitor @@ -37,6 +37,8 @@ class TransformVisitor(NodeVisitor[Node]): Notes: + * This can only be used to transform functions or classes, not top-level + statements, and/or modules as a whole. * Do not duplicate TypeInfo nodes. This would generally not be desirable. * Only update some name binding cross-references, but only those that refer to Var, Decorator or FuncDef nodes, not those targeting ClassDef or @@ -48,6 +50,9 @@ class TransformVisitor(NodeVisitor[Node]): """ def __init__(self) -> None: + # To simplify testing, set this flag to True if you want to transform + # all statements in a file (this is prohibited in normal mode). + self.test_only = False # There may be multiple references to a Var node. Keep track of # Var translations using a dictionary. self.var_map = {} # type: Dict[Var, Var] @@ -58,6 +63,7 @@ def __init__(self) -> None: self.func_placeholder_map = {} # type: Dict[FuncDef, FuncDef] def visit_mypy_file(self, node: MypyFile) -> MypyFile: + assert self.test_only, "This visitor should not be used for whole files." # NOTE: The 'names' and 'imports' instance variables will be empty! ignored_lines = {line: codes[:] for line, codes in node.ignored_lines.items()} @@ -358,7 +364,10 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: new.fullname = original.fullname target = original.node if isinstance(target, Var): - target = self.visit_var(target) + # Do not transform references to global variables. See + # testGenericFunctionAliasExpand for an example where this is important. + if original.kind != GDEF: + target = self.visit_var(target) elif isinstance(target, Decorator): target = self.visit_var(target.var) elif isinstance(target, FuncDef): diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index c4807c8fe1e7..654f33c590c2 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2402,3 +2402,31 @@ def func(tp: Type[C[S]]) -> S: else: return reg[tp.test].test[0] [builtins fixtures/dict.pyi] + +[case testGenericFunctionAliasExpand] +from typing import Optional, TypeVar + +T = TypeVar("T") +def gen(x: T) -> T: ... +gen_a = gen + +S = TypeVar("S", int, str) +class C: ... +def test() -> Optional[S]: + reveal_type(gen_a(C())) # N: Revealed type is '__main__.C*' + return None + +[case testGenericFunctionMemberExpand] +from typing import Optional, TypeVar, Callable + +T = TypeVar("T") + +class A: + def __init__(self) -> None: + self.gen: Callable[[T], T] + +S = TypeVar("S", int, str) +class C: ... +def test() -> Optional[S]: + reveal_type(A().gen(C())) # N: Revealed type is '__main__.C*' + return None