From 6660b9b2bf5e73fa598b7b54e925f409439dad14 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 10 Jun 2023 17:13:03 -0400 Subject: [PATCH] also perform PEP 563 rewrites for async functions --- pyupgrade/_plugins/typing_pep563.py | 23 ++++++++++++++++++++--- tests/features/typing_pep563_test.py | 11 +++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/pyupgrade/_plugins/typing_pep563.py b/pyupgrade/_plugins/typing_pep563.py index 7fb4e275..0eec8af1 100644 --- a/pyupgrade/_plugins/typing_pep563.py +++ b/pyupgrade/_plugins/typing_pep563.py @@ -133,10 +133,9 @@ def _process_args( yield from _replace_string_literal(arg.annotation) -@register(ast.FunctionDef) -def visit_FunctionDef( +def _visit_func( state: State, - node: ast.FunctionDef, + node: ast.AsyncFunctionDef | ast.FunctionDef, parent: ast.AST, ) -> Iterable[tuple[Offset, TokenFunc]]: if not _supported_version(state): @@ -150,6 +149,24 @@ def visit_FunctionDef( yield from _replace_string_literal(node.returns) +@register(ast.AsyncFunctionDef) +def visit_AsyncFunctionDef( + state: State, + node: ast.AsyncFunctionDef, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + yield from _visit_func(state, node, parent) + + +@register(ast.FunctionDef) +def visit_FunctionDef( + state: State, + node: ast.FunctionDef, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + yield from _visit_func(state, node, parent) + + @register(ast.AnnAssign) def visit_AnnAssign( state: State, diff --git a/tests/features/typing_pep563_test.py b/tests/features/typing_pep563_test.py index 53d504ca..85f7a419 100644 --- a/tests/features/typing_pep563_test.py +++ b/tests/features/typing_pep563_test.py @@ -74,6 +74,17 @@ def test_fix_typing_pep563_noop(s): id='Simple annotation', ), + pytest.param( + 'from __future__ import annotations\n' + 'async def foo(var: "MyClass") -> "MyClass":\n' + ' ...\n', + + 'from __future__ import annotations\n' + 'async def foo(var: MyClass) -> MyClass:\n' + ' ...\n', + + id='simple async annotation', + ), pytest.param( 'from __future__ import annotations\n' 'def foo(*, inplace: "bool"): ...\n',