diff --git a/src/unimport/analyzers/import_statement.py b/src/unimport/analyzers/import_statement.py index 245f7a84..79475271 100644 --- a/src/unimport/analyzers/import_statement.py +++ b/src/unimport/analyzers/import_statement.py @@ -103,6 +103,18 @@ def _is_type_checking_block(if_node: ast.If) -> bool: return True return False + @staticmethod + def _collect_import_names(nodes: list[ast.stmt], *, recursive: bool = True) -> set[str]: + names: set[str] = set() + for node in nodes: + if isinstance(node, (ast.Import, ast.ImportFrom)): + for alias in node.names: + names.add(alias.asname or alias.name) + elif recursive and isinstance(node, ast.If): + names |= ImportAnalyzer._collect_import_names(node.body) + names |= ImportAnalyzer._collect_import_names(node.orelse) + return names + def visit_If(self, if_node: ast.If) -> None: if self._is_type_checking_block(if_node): self._in_type_checking = True @@ -113,17 +125,9 @@ def visit_If(self, if_node: ast.If) -> None: self.visit(node) return - self.if_names = { - name.asname or name.name - for n in filter(lambda node: isinstance(node, (ast.Import, ast.ImportFrom)), if_node.body) - for name in n.names # type: ignore - } - - self.orelse_names = { - name.asname or name.name - for n in filter(lambda node: isinstance(node, (ast.Import, ast.ImportFrom)), if_node.orelse) - for name in n.names # type: ignore - } + self.if_names = self._collect_import_names(if_node.body) + + self.orelse_names = self._collect_import_names(if_node.orelse, recursive=False) self.generic_visit(if_node) diff --git a/tests/cases/analyzer/if_dispatch/case_7.py b/tests/cases/analyzer/if_dispatch/case_7.py new file mode 100644 index 00000000..7fa1f0c1 --- /dev/null +++ b/tests/cases/analyzer/if_dispatch/case_7.py @@ -0,0 +1,23 @@ +from typing import Union + +from unimport.statement import Import, ImportFrom, Name + +__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"] + + +NAMES: list[Name] = [ + Name(lineno=4, name="sys.version_info", is_all=False), + Name(lineno=5, name="TYPE_CHECKING", is_all=False), + Name(lineno=8, name="Any", is_all=False), + Name(lineno=11, name="Any", is_all=False), + Name(lineno=11, name="Any", is_all=False), + Name(lineno=11, name="Any", is_all=False), + Name(lineno=19, name="print", is_all=False), + Name(lineno=19, name="ForwardRef", is_all=False), +] +IMPORTS: list[Union[Import, ImportFrom]] = [ + Import(lineno=1, column=1, name="sys", package="sys"), + ImportFrom(lineno=2, column=1, name="TYPE_CHECKING", package="typing", star=False, suggestions=[]), + ImportFrom(lineno=2, column=2, name="Any", package="typing", star=False, suggestions=[]), +] +UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [] diff --git a/tests/cases/refactor/if_dispatch/case_7.py b/tests/cases/refactor/if_dispatch/case_7.py new file mode 100644 index 00000000..d4d44190 --- /dev/null +++ b/tests/cases/refactor/if_dispatch/case_7.py @@ -0,0 +1,19 @@ +import sys +from typing import TYPE_CHECKING, Any + +if sys.version_info < (3, 7): + if TYPE_CHECKING: + + class ForwardRef: + def __init__(self, arg: Any): + pass + + def _eval_type(self, globalns: Any, localns: Any) -> Any: + pass + + else: + from typing import _ForwardRef as ForwardRef +else: + from typing import ForwardRef + +print(ForwardRef) diff --git a/tests/cases/source/if_dispatch/case_7.py b/tests/cases/source/if_dispatch/case_7.py new file mode 100644 index 00000000..d4d44190 --- /dev/null +++ b/tests/cases/source/if_dispatch/case_7.py @@ -0,0 +1,19 @@ +import sys +from typing import TYPE_CHECKING, Any + +if sys.version_info < (3, 7): + if TYPE_CHECKING: + + class ForwardRef: + def __init__(self, arg: Any): + pass + + def _eval_type(self, globalns: Any, localns: Any) -> Any: + pass + + else: + from typing import _ForwardRef as ForwardRef +else: + from typing import ForwardRef + +print(ForwardRef)