Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore the 0.4.1 behavior for libcst.helpers.get_absolute_module #684

Merged
merged 1 commit into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions libcst/codemod/commands/remove_unused_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from libcst import Import, ImportFrom, ImportStar, Module
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import GatherCommentsVisitor, RemoveImportsVisitor
from libcst.helpers import get_absolute_module_for_import
from libcst.helpers import get_absolute_module_from_package_for_import
from libcst.metadata import PositionProvider, ProviderT

DEFAULT_SUPPRESS_COMMENT_REGEX = (
Expand Down Expand Up @@ -74,8 +74,8 @@ def _handle_import(self, node: Union[Import, ImportFrom]) -> None:
asname=alias.evaluated_alias,
)
else:
module_name = get_absolute_module_for_import(
self.context.full_module_name, node
module_name = get_absolute_module_from_package_for_import(
self.context.full_package_name, node
)
if module_name is None:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions libcst/codemod/visitors/_add_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_for_import
from libcst.helpers import get_absolute_module_from_package_for_import


class AddImportsVisitor(ContextAwareTransformer):
Expand Down Expand Up @@ -214,7 +214,7 @@ def leave_ImportFrom(
return updated_node

# Get the module we're importing as a string, see if we have work to do.
module = get_absolute_module_for_import(
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, updated_node
)
if (
Expand Down
6 changes: 4 additions & 2 deletions libcst/codemod/visitors/_gather_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_for_import
from libcst.helpers import get_absolute_module_from_package_for_import


class GatherImportsVisitor(ContextAwareVisitor):
Expand Down Expand Up @@ -85,7 +85,9 @@ def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
self.all_imports.append(node)

# Get the module we're importing as a string.
module = get_absolute_module_for_import(self.context.full_package_name, node)
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, node
)
if module is None:
# Can't get the absolute import from relative, so we can't
# support this.
Expand Down
6 changes: 4 additions & 2 deletions libcst/codemod/visitors/_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, replace
from typing import Optional

from libcst.helpers import get_absolute_module
from libcst.helpers import get_absolute_module_from_package


@dataclass(frozen=True)
Expand Down Expand Up @@ -39,5 +39,7 @@ def resolve_relative(self, package_name: Optional[str]) -> "ImportItem":
mod = replace(mod, module_name="", obj_name=mod.module_name)
if package_name is None:
return mod
m = get_absolute_module(package_name, mod.module_name or None, self.relative)
m = get_absolute_module_from_package(
package_name, mod.module_name or None, self.relative
)
return mod if m is None else replace(mod, module_name=m, relative=0)
11 changes: 7 additions & 4 deletions libcst/codemod/visitors/_remove_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer, ContextAwareVisitor
from libcst.codemod.visitors._gather_unused_imports import GatherUnusedImportsVisitor
from libcst.helpers import get_absolute_module_for_import, get_full_name_for_node
from libcst.helpers import (
get_absolute_module_from_package_for_import,
get_full_name_for_node,
)
from libcst.metadata import Assignment, ProviderT, ScopeProvider


Expand Down Expand Up @@ -38,7 +41,7 @@ def _remove_imports_from_importfrom_stmt(
# We don't handle removing this, so ignore it.
return

module_name = get_absolute_module_for_import(
module_name = get_absolute_module_from_package_for_import(
self.context.full_package_name, import_node
)
if module_name is None:
Expand Down Expand Up @@ -248,7 +251,7 @@ def remove_unused_import_by_node(
if isinstance(names, cst.ImportStar):
# We don't handle removing this, so ignore it.
return
module_name = get_absolute_module_for_import(
module_name = get_absolute_module_from_package_for_import(
context.full_package_name, node
)
if module_name is None:
Expand Down Expand Up @@ -415,7 +418,7 @@ def leave_ImportFrom(
return updated_node

# Make sure we actually know the absolute module.
module_name = get_absolute_module_for_import(
module_name = get_absolute_module_from_package_for_import(
self.context.full_package_name, updated_node
)
if module_name is None or module_name not in self.unused_obj_imports:
Expand Down
6 changes: 6 additions & 0 deletions libcst/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
get_absolute_module,
get_absolute_module_for_import,
get_absolute_module_for_import_or_raise,
get_absolute_module_from_package,
get_absolute_module_from_package_for_import,
get_absolute_module_from_package_for_import_or_raise,
insert_header_comments,
ModuleNameAndPackage,
)
Expand All @@ -28,6 +31,9 @@
"get_absolute_module",
"get_absolute_module_for_import",
"get_absolute_module_for_import_or_raise",
"get_absolute_module_from_package",
"get_absolute_module_from_package_for_import",
"get_absolute_module_from_package_for_import_or_raise",
"get_full_name_for_node",
"get_full_name_for_node_or_raise",
"ensure_type",
Expand Down
57 changes: 53 additions & 4 deletions libcst/helpers/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,55 @@ def insert_header_comments(node: Module, comments: List[str]) -> Module:


def get_absolute_module(
current_module: Optional[str], module_name: Optional[str], num_dots: int
) -> Optional[str]:
if num_dots == 0:
# This is an absolute import, so the module is correct.
return module_name
if current_module is None:
# We don't actually have the current module available, so we can't compute
# the absolute module from relative.
return None
# We have the current module, as well as the relative, let's compute the base.
modules = current_module.split(".")
if len(modules) < num_dots:
# This relative import goes past the base of the repository, so we can't calculate it.
return None
base_module = ".".join(modules[:-num_dots])
# Finally, if the module name was supplied, append it to the end.
if module_name is not None:
# If we went all the way to the top, the base module should be empty, so we
# should return the relative bit as absolute. Otherwise, combine the base
# module and module name using a dot separator.
base_module = (
f"{base_module}.{module_name}" if len(base_module) > 0 else module_name
)
# If they tried to import all the way to the root, return None. Otherwise,
# return the module itself.
return base_module if len(base_module) > 0 else None


def get_absolute_module_for_import(
current_module: Optional[str], import_node: ImportFrom
) -> Optional[str]:
# First, let's try to grab the module name, regardless of relative status.
module = import_node.module
module_name = get_full_name_for_node(module) if module is not None else None
# Now, get the relative import location if it exists.
num_dots = len(import_node.relative)
return get_absolute_module(current_module, module_name, num_dots)


def get_absolute_module_for_import_or_raise(
current_module: Optional[str], import_node: ImportFrom
) -> str:
module = get_absolute_module_for_import(current_module, import_node)
if module is None:
raise Exception(f"Unable to compute absolute module for {import_node}")
return module


def get_absolute_module_from_package(
current_package: Optional[str], module_name: Optional[str], num_dots: int
) -> Optional[str]:
if num_dots == 0:
Expand All @@ -55,21 +104,21 @@ def get_absolute_module(
return "{}.{}".format(base, module_name) if module_name else base


def get_absolute_module_for_import(
def get_absolute_module_from_package_for_import(
current_package: Optional[str], import_node: ImportFrom
) -> Optional[str]:
# First, let's try to grab the module name, regardless of relative status.
module = import_node.module
module_name = get_full_name_for_node(module) if module is not None else None
# Now, get the relative import location if it exists.
num_dots = len(import_node.relative)
return get_absolute_module(current_package, module_name, num_dots)
return get_absolute_module_from_package(current_package, module_name, num_dots)


def get_absolute_module_for_import_or_raise(
def get_absolute_module_from_package_for_import_or_raise(
current_package: Optional[str], import_node: ImportFrom
) -> str:
module = get_absolute_module_for_import(current_package, import_node)
module = get_absolute_module_from_package_for_import(current_package, import_node)
if module is None:
raise Exception(f"Unable to compute absolute module for {import_node}")
return module
Expand Down
55 changes: 51 additions & 4 deletions libcst/helpers/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
calculate_module_and_package,
get_absolute_module_for_import,
get_absolute_module_for_import_or_raise,
get_absolute_module_from_package_for_import,
get_absolute_module_from_package_for_import_or_raise,
insert_header_comments,
ModuleNameAndPackage,
)
Expand Down Expand Up @@ -67,6 +69,44 @@ def test_insert_header_comments(self) -> None:
insert_header_comments(node, inserted_comments).code, expected_code
)

@data_provider(
(
# Simple imports that are already absolute.
(None, "from a.b import c", "a.b"),
("x.y.z", "from a.b import c", "a.b"),
# Relative import that can't be resolved due to missing module.
(None, "from ..w import c", None),
# Relative import that goes past the module level.
("x", "from ...y import z", None),
("x.y.z", "from .....w import c", None),
("x.y.z", "from ... import c", None),
# Correct resolution of absolute from relative modules.
("x.y.z", "from . import c", "x.y"),
("x.y.z", "from .. import c", "x"),
("x.y.z", "from .w import c", "x.y.w"),
("x.y.z", "from ..w import c", "x.w"),
("x.y.z", "from ...w import c", "w"),
)
)
def test_get_absolute_module(
self,
module: Optional[str],
importfrom: str,
output: Optional[str],
) -> None:
node = ensure_type(cst.parse_statement(importfrom), cst.SimpleStatementLine)
assert len(node.body) == 1, "Unexpected number of statements!"
import_node = ensure_type(node.body[0], cst.ImportFrom)

self.assertEqual(get_absolute_module_for_import(module, import_node), output)
if output is None:
with self.assertRaises(Exception):
get_absolute_module_for_import_or_raise(module, import_node)
else:
self.assertEqual(
get_absolute_module_for_import_or_raise(module, import_node), output
)

@data_provider(
(
# Simple imports that are already absolute.
Expand Down Expand Up @@ -94,7 +134,7 @@ def test_insert_header_comments(self) -> None:
("x/y/z/__init__.py", "from ...w import c", "x.w"),
)
)
def test_get_absolute_module(
def test_get_absolute_module_from_package(
self,
filename: Optional[str],
importfrom: str,
Expand All @@ -108,13 +148,20 @@ def test_get_absolute_module(
assert len(node.body) == 1, "Unexpected number of statements!"
import_node = ensure_type(node.body[0], cst.ImportFrom)

self.assertEqual(get_absolute_module_for_import(package, import_node), output)
self.assertEqual(
get_absolute_module_from_package_for_import(package, import_node), output
)
if output is None:
with self.assertRaises(Exception):
get_absolute_module_for_import_or_raise(package, import_node)
get_absolute_module_from_package_for_import_or_raise(
package, import_node
)
else:
self.assertEqual(
get_absolute_module_for_import_or_raise(package, import_node), output
get_absolute_module_from_package_for_import_or_raise(
package, import_node
),
output,
)

@data_provider(
Expand Down