22from collections .abc import Iterable
33from functools import lru_cache
44from pathlib import Path
5- from typing import List , Set
5+ from typing import List , Set , Union
6+
7+ typing_ns = "typing"
8+ type_checking = "TYPE_CHECKING"
69
710
811def parse_import (node : ast .Import ) -> List [str ]:
@@ -31,6 +34,40 @@ def parse_imports(node: ast.AST) -> List[str]:
3134 return []
3235
3336
37+ def find_type_checking_body (node : ast .If ) -> List [ast .stmt ]:
38+ if isinstance (node .test , ast .Name ) and node .test .id == type_checking :
39+ return node .body
40+
41+ if isinstance (node .test , ast .Attribute ) and node .test .attr == type_checking :
42+ if isinstance (node .test .value , ast .Name ) and node .test .value .id == typing_ns :
43+ return node .body
44+
45+ return []
46+
47+
48+ def flatten (data : Iterable ) -> list :
49+ return [item for nested in data for item in nested ]
50+
51+
52+ def parse_node (node : ast .AST ) -> Union [dict , None ]:
53+ if isinstance (node , ast .Import ):
54+ return {"include" : parse_import (node )}
55+
56+ if isinstance (node , ast .ImportFrom ):
57+ return {"include" : parse_import_from (node )}
58+
59+ if isinstance (node , ast .If ):
60+ found = find_type_checking_body (node )
61+ parsed = flatten (parse_imports (f ) for f in found )
62+
63+ if not parsed :
64+ return None
65+
66+ return {"exclude" : parsed }
67+
68+ return None
69+
70+
3471def parse_module (path : Path ) -> ast .AST :
3572 with open (path .as_posix (), "r" , encoding = "utf-8" , errors = "ignore" ) as f :
3673 tree = ast .parse (f .read (), path .name )
@@ -42,14 +79,17 @@ def parse_module(path: Path) -> ast.AST:
4279def extract_imports (path : Path ) -> List [str ]:
4380 tree = parse_module (path )
4481
45- return [i for node in ast .walk (tree ) for i in parse_imports (node ) if i is not None ]
82+ nodes = (parse_node (n ) for n in ast .walk (tree ))
83+ parsed_nodes = [n for n in nodes if n is not None ]
4684
85+ includes = [i for n in parsed_nodes for i in n .get ("include" , [])]
86+ excludes = {i for n in parsed_nodes for i in n .get ("exclude" , [])}
4787
48- def extract_and_flatten (py_modules : Iterable ) -> Set [str ]:
49- extracted = (extract_imports (m ) for m in py_modules )
50- flattened = (i for imports in extracted for i in imports )
88+ return [i for i in includes if i not in excludes ]
5189
52- return set (flattened )
90+
91+ def extract_and_flatten (py_modules : Iterable ) -> Set [str ]:
92+ return {i for m in py_modules for i in extract_imports (m )}
5393
5494
5595def is_python_file (path : Path ) -> bool :
0 commit comments