Skip to content

Commit a000ac8

Browse files
author
David Sanders
committed
Added basic support for importing base classes as members of an imported module
1 parent 7670ac9 commit a000ac8

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

mypy/stubgen.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,7 @@ def get_base_types(self, cdef: ClassDef) -> List[str]:
347347
if base.name != 'object':
348348
base_types.append(base.name)
349349
elif isinstance(base, MemberExpr):
350-
modname = get_qualified_name(base.expr)
351-
base_types.append('%s.%s' % (modname, base.name))
352-
self.add_import_line('import %s\n' % modname)
350+
base_types.append(get_qualified_name(base))
353351
return base_types
354352

355353
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
@@ -437,8 +435,9 @@ def visit_import_from(self, o: ImportFrom) -> None:
437435
exported_names.update(sub_names)
438436
self.import_and_export_names(o.id, o.relative, sub_names)
439437
# Import names used as base classes.
438+
base_class_imports = [base_class.split('.')[0] for base_class in self._base_classes]
440439
base_names = [(name, alias) for name, alias in o.names
441-
if alias or name in self._base_classes and name not in exported_names]
440+
if alias or name in base_class_imports and name not in exported_names]
442441
if base_names:
443442
imp_names = [] # type: List[str]
444443
for name, alias in base_names:
@@ -468,6 +467,12 @@ def visit_import(self, o: Import) -> None:
468467
'.' not in id):
469468
self.add_import_line('import %s as %s\n' % (id, target_name))
470469
self.record_name(target_name)
470+
base_class_imports = [base_class.split('.')[0] for base_class in self._base_classes]
471+
if target_name in base_class_imports:
472+
if as_id:
473+
self.add_import_line('import %s as %s\n' % (id, as_id))
474+
else:
475+
self.add_import_line('import %s\n' % (id, ))
471476

472477
def get_init(self, lvalue: str, rvalue: Expression) -> Optional[str]:
473478
"""Return initializer for a variable.

test-data/unit/stubgen.test

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,14 @@ import x.y
530530

531531
class D(x.y.C): ...
532532

533+
[case testArbitraryBaseClassWithAlias]
534+
import x as y
535+
class D(y.C): ...
536+
[out]
537+
import x as y
538+
539+
class D(y.C): ...
540+
533541
[case testUnqualifiedArbitraryBaseClassWithNoDef]
534542
class A(int): ...
535543
[out]
@@ -628,5 +636,24 @@ class A:
628636
x = ... # type: Any
629637
def __init__(self, a: Optional[Any] = ...) -> None: ...
630638

639+
[case testImportAddedForQualifiedBaseClass]
640+
from foo import bar
641+
642+
class A(bar.fuzz.Baz): ...
643+
[out]
644+
from foo import bar
645+
646+
class A(bar.fuzz.Baz): ...
647+
648+
[case testImportAddedForQualifiedBaseClassWithAlias]
649+
from foo import bar as baz
650+
651+
class A(baz.Baz): ...
652+
[out]
653+
from foo import bar as baz
654+
655+
class A(baz.Baz): ...
656+
657+
631658
-- More features/fixes:
632659
-- do not export deleted names

0 commit comments

Comments
 (0)