From 20b0070e005cd4937c7a8d9d3d88ee59073b4e3f Mon Sep 17 00:00:00 2001 From: Steven Troxler Date: Thu, 28 Oct 2021 13:47:42 -0400 Subject: [PATCH] Support relative imports in ATAV qualifier handling Based on diff review of https://github.com/Instagram/LibCST/pull/536, I investigated relatvie import handling and realized that with minor changes we can now handle them correctly. Relative imports aren't likely in code coming from an automated tool, but they could happen in hand-written stubs if anyone tries to use this codemod tool to merge stubs with code. Added a new test: ``` > python -m unittest libcst.codemod.visitors.tests.test_apply_type_annotations ............................................. ---------------------------------------------------------------------- Ran 45 tests in 2.195s OK ``` --- .../visitors/_apply_type_annotations.py | 19 +++++++++++--- .../tests/test_apply_type_annotations.py | 26 +++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index 2ac44c023..a459d300e 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -145,15 +145,28 @@ def _get_qualified_name_and_dequalified_node( dequalified_node = node.attr if isinstance(node, cst.Attribute) else node return qualified_name, dequalified_node + def _module_and_target(self, qualified_name: str) -> Tuple[str, str]: + relative_prefix = "" + while qualified_name.startswith("."): + relative_prefix += "." + qualified_name = qualified_name[1:] + split = qualified_name.rsplit(".", 1) + if len(split) == 1: + qualifier, target = "", split[0] + else: + qualifier, target = split + return (relative_prefix + qualifier, target) + def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool: """ Basd on a qualified name and the existing module imports, record that we need to add an import if necessary and return whether or not we should use the qualified name due to a preexisting import. """ - split_name = qualified_name.split(".") - if len(split_name) > 1 and qualified_name not in self.existing_imports: - module, target = ".".join(split_name[:-1]), split_name[-1] + module, target = self._module_and_target(qualified_name) + if module in ("", "builtins"): + return False + elif qualified_name not in self.existing_imports: if module == "builtins": return False elif module in self.existing_imports: diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index 4a63ab29e..54aec5d4e 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -123,6 +123,32 @@ def run_test_case_with_flags( FOO: Union[Example, int] = bar() """, ), + "with_relative_imports": ( + """ + from .relative0 import T0 + from ..relative1 import T1 + from . import relative2 + + x0: typing.Optional[T0] + x1: typing.Optional[T1] + x2: typing.Optional[relative2.T2] + """, + """ + x0 = None + x1 = None + x2 = None + """, + """ + from ..relative1 import T1 + from .relative0 import T0 + from .relative2 import T2 + from typing import Optional + + x0: Optional[T0] = None + x1: Optional[T1] = None + x2: Optional[T2] = None + """, + ), } ) def test_annotate_globals(self, stub: str, before: str, after: str) -> None: