@@ -347,9 +347,7 @@ def get_base_types(self, cdef: ClassDef) -> List[str]:
347
347
if base .name != 'object' :
348
348
base_types .append (base .name )
349
349
elif isinstance (base , MemberExpr ):
350
- modname = get_qualified_name (base .expr )
351
- base_types .append ('%s.%s' % (modname , base .name ))
352
- self .add_import_line ('import %s\n ' % modname )
350
+ base_types .append (get_qualified_name (base ))
353
351
return base_types
354
352
355
353
def visit_assignment_stmt (self , o : AssignmentStmt ) -> None :
@@ -437,8 +435,9 @@ def visit_import_from(self, o: ImportFrom) -> None:
437
435
exported_names .update (sub_names )
438
436
self .import_and_export_names (o .id , o .relative , sub_names )
439
437
# Import names used as base classes.
438
+ base_class_imports = [base_class .split ('.' )[0 ] for base_class in self ._base_classes ]
440
439
base_names = [(name , alias ) for name , alias in o .names
441
- if alias or name in self . _base_classes and name not in exported_names ]
440
+ if alias or name in base_class_imports and name not in exported_names ]
442
441
if base_names :
443
442
imp_names = [] # type: List[str]
444
443
for name , alias in base_names :
@@ -468,6 +467,12 @@ def visit_import(self, o: Import) -> None:
468
467
'.' not in id ):
469
468
self .add_import_line ('import %s as %s\n ' % (id , target_name ))
470
469
self .record_name (target_name )
470
+ base_class_imports = [base_class .split ('.' )[0 ] for base_class in self ._base_classes ]
471
+ if target_name in base_class_imports :
472
+ if as_id :
473
+ self .add_import_line ('import %s as %s\n ' % (id , as_id ))
474
+ else :
475
+ self .add_import_line ('import %s\n ' % (id , ))
471
476
472
477
def get_init (self , lvalue : str , rvalue : Expression ) -> Optional [str ]:
473
478
"""Return initializer for a variable.
0 commit comments