Skip to content

Commit 92923b2

Browse files
authored
stubgen: properly convert overloaded functions (#9613)
1 parent e9edcb9 commit 92923b2

File tree

2 files changed

+253
-15
lines changed

2 files changed

+253
-15
lines changed

mypy/stubgen.py

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -559,10 +559,33 @@ def visit_mypy_file(self, o: MypyFile) -> None:
559559
for name in sorted(undefined_names):
560560
self.add('# %s\n' % name)
561561

562-
def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None:
562+
def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
563+
"""@property with setters and getters, or @overload chain"""
564+
overload_chain = False
565+
for item in o.items:
566+
if not isinstance(item, Decorator):
567+
continue
568+
569+
if self.is_private_name(item.func.name, item.func.fullname):
570+
continue
571+
572+
is_abstract, is_overload = self.process_decorator(item)
573+
574+
if not overload_chain:
575+
self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload)
576+
if is_overload:
577+
overload_chain = True
578+
elif overload_chain and is_overload:
579+
self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload)
580+
else:
581+
# skip the overload implementation and clear the decorator we just processed
582+
self.clear_decorators()
583+
584+
def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
585+
is_overload: bool = False) -> None:
563586
if (self.is_private_name(o.name, o.fullname)
564587
or self.is_not_in_all(o.name)
565-
or self.is_recorded_name(o.name)):
588+
or (self.is_recorded_name(o.name) and not is_overload)):
566589
self.clear_decorators()
567590
return
568591
if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine:
@@ -599,7 +622,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None:
599622
and not is_cls_arg):
600623
self.add_typing_import("Any")
601624
annotation = ": {}".format(self.typing_name("Any"))
602-
elif annotated_type and not is_self_arg:
625+
elif annotated_type and not is_self_arg and not is_cls_arg:
603626
annotation = ": {}".format(self.print_annotation(annotated_type))
604627
else:
605628
annotation = ""
@@ -642,24 +665,43 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None:
642665
def visit_decorator(self, o: Decorator) -> None:
643666
if self.is_private_name(o.func.name, o.func.fullname):
644667
return
668+
669+
is_abstract, _ = self.process_decorator(o)
670+
self.visit_func_def(o.func, is_abstract=is_abstract)
671+
672+
def process_decorator(self, o: Decorator) -> Tuple[bool, bool]:
673+
"""Process a series of decorataors.
674+
675+
Only preserve certain special decorators such as @abstractmethod.
676+
677+
Return a pair of booleans:
678+
- True if any of the decorators makes a method abstract.
679+
- True if any of the decorators is typing.overload.
680+
"""
645681
is_abstract = False
682+
is_overload = False
646683
for decorator in o.original_decorators:
647684
if isinstance(decorator, NameExpr):
648-
if self.process_name_expr_decorator(decorator, o):
649-
is_abstract = True
685+
i_is_abstract, i_is_overload = self.process_name_expr_decorator(decorator, o)
686+
is_abstract = is_abstract or i_is_abstract
687+
is_overload = is_overload or i_is_overload
650688
elif isinstance(decorator, MemberExpr):
651-
if self.process_member_expr_decorator(decorator, o):
652-
is_abstract = True
653-
self.visit_func_def(o.func, is_abstract=is_abstract)
689+
i_is_abstract, i_is_overload = self.process_member_expr_decorator(decorator, o)
690+
is_abstract = is_abstract or i_is_abstract
691+
is_overload = is_overload or i_is_overload
692+
return is_abstract, is_overload
654693

655-
def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> bool:
694+
def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tuple[bool, bool]:
656695
"""Process a function decorator of form @foo.
657696
658697
Only preserve certain special decorators such as @abstractmethod.
659698
660-
Return True if the decorator makes a method abstract.
699+
Return a pair of booleans:
700+
- True if the decorator makes a method abstract.
701+
- True if the decorator is typing.overload.
661702
"""
662703
is_abstract = False
704+
is_overload = False
663705
name = expr.name
664706
if name in ('property', 'staticmethod', 'classmethod'):
665707
self.add_decorator(name)
@@ -675,27 +717,35 @@ def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> boo
675717
self.add_decorator('property')
676718
self.add_decorator('abc.abstractmethod')
677719
is_abstract = True
678-
return is_abstract
720+
elif self.refers_to_fullname(name, 'typing.overload'):
721+
self.add_decorator(name)
722+
self.add_typing_import('overload')
723+
is_overload = True
724+
return is_abstract, is_overload
679725

680726
def refers_to_fullname(self, name: str, fullname: str) -> bool:
681727
module, short = fullname.rsplit('.', 1)
682728
return (self.import_tracker.module_for.get(name) == module and
683729
(name == short or
684730
self.import_tracker.reverse_alias.get(name) == short))
685731

686-
def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> bool:
732+
def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> Tuple[bool,
733+
bool]:
687734
"""Process a function decorator of form @foo.bar.
688735
689736
Only preserve certain special decorators such as @abstractmethod.
690737
691-
Return True if the decorator makes a method abstract.
738+
Return a pair of booleans:
739+
- True if the decorator makes a method abstract.
740+
- True if the decorator is typing.overload.
692741
"""
693742
is_abstract = False
743+
is_overload = False
694744
if expr.name == 'setter' and isinstance(expr.expr, NameExpr):
695745
self.add_decorator('%s.setter' % expr.expr.name)
696746
elif (isinstance(expr.expr, NameExpr) and
697747
(expr.expr.name == 'abc' or
698-
self.import_tracker.reverse_alias.get('abc')) and
748+
self.import_tracker.reverse_alias.get(expr.expr.name) == 'abc') and
699749
expr.name in ('abstractmethod', 'abstractproperty')):
700750
if expr.name == 'abstractproperty':
701751
self.import_tracker.require_name(expr.expr.name)
@@ -723,7 +773,14 @@ def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) ->
723773
self.add_coroutine_decorator(context.func,
724774
expr.expr.name + '.coroutine',
725775
expr.expr.name)
726-
return is_abstract
776+
elif (isinstance(expr.expr, NameExpr) and
777+
(expr.expr.name == 'typing' or
778+
self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and
779+
expr.name == 'overload'):
780+
self.import_tracker.require_name(expr.expr.name)
781+
self.add_decorator('%s.%s' % (expr.expr.name, 'overload'))
782+
is_overload = True
783+
return is_abstract, is_overload
727784

728785
def visit_class_def(self, o: ClassDef) -> None:
729786
self.method_names = find_method_names(o.defs.body)

test-data/unit/stubgen.test

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,20 @@ class A(metaclass=abc.ABCMeta):
14711471
@abc.abstractmethod
14721472
def meth(self): ...
14731473

1474+
[case testAbstractMethodMemberExpr2]
1475+
import abc as _abc
1476+
1477+
class A(metaclass=abc.ABCMeta):
1478+
@_abc.abstractmethod
1479+
def meth(self):
1480+
pass
1481+
[out]
1482+
import abc as _abc
1483+
1484+
class A(metaclass=abc.ABCMeta):
1485+
@_abc.abstractmethod
1486+
def meth(self): ...
1487+
14741488
[case testABCMeta_semanal]
14751489
from base import Base
14761490
from abc import abstractmethod
@@ -2288,3 +2302,170 @@ import p.a
22882302

22892303
x: a.X
22902304
y: p.a.Y
2305+
2306+
[case testOverload_fromTypingImport]
2307+
from typing import Tuple, Union, overload
2308+
2309+
class A:
2310+
@overload
2311+
def f(self, x: int, y: int) -> int:
2312+
...
2313+
2314+
@overload
2315+
def f(self, x: Tuple[int, int]) -> int:
2316+
...
2317+
2318+
def f(self, *args: Union[int, Tuple[int, int]]) -> int:
2319+
pass
2320+
2321+
@overload
2322+
def f(x: int, y: int) -> int:
2323+
...
2324+
2325+
@overload
2326+
def f(x: Tuple[int, int]) -> int:
2327+
...
2328+
2329+
def f(*args: Union[int, Tuple[int, int]]) -> int:
2330+
pass
2331+
2332+
2333+
[out]
2334+
from typing import Tuple, overload
2335+
2336+
class A:
2337+
@overload
2338+
def f(self, x: int, y: int) -> int: ...
2339+
@overload
2340+
def f(self, x: Tuple[int, int]) -> int: ...
2341+
2342+
2343+
@overload
2344+
def f(x: int, y: int) -> int: ...
2345+
@overload
2346+
def f(x: Tuple[int, int]) -> int: ...
2347+
2348+
[case testOverload_importTyping]
2349+
import typing
2350+
2351+
class A:
2352+
@typing.overload
2353+
def f(self, x: int, y: int) -> int:
2354+
...
2355+
2356+
@typing.overload
2357+
def f(self, x: typing.Tuple[int, int]) -> int:
2358+
...
2359+
2360+
def f(self, *args: typing.Union[int, typing.Tuple[int, int]]) -> int:
2361+
pass
2362+
2363+
@typing.overload
2364+
@classmethod
2365+
def g(cls, x: int, y: int) -> int:
2366+
...
2367+
2368+
@typing.overload
2369+
@classmethod
2370+
def g(cls, x: typing.Tuple[int, int]) -> int:
2371+
...
2372+
2373+
@classmethod
2374+
def g(self, *args: typing.Union[int, typing.Tuple[int, int]]) -> int:
2375+
pass
2376+
2377+
@typing.overload
2378+
def f(x: int, y: int) -> int:
2379+
...
2380+
2381+
@typing.overload
2382+
def f(x: typing.Tuple[int, int]) -> int:
2383+
...
2384+
2385+
def f(*args: typing.Union[int, typing.Tuple[int, int]]) -> int:
2386+
pass
2387+
2388+
2389+
[out]
2390+
import typing
2391+
2392+
class A:
2393+
@typing.overload
2394+
def f(self, x: int, y: int) -> int: ...
2395+
@typing.overload
2396+
def f(self, x: typing.Tuple[int, int]) -> int: ...
2397+
@typing.overload
2398+
@classmethod
2399+
def g(cls, x: int, y: int) -> int: ...
2400+
@typing.overload
2401+
@classmethod
2402+
def g(cls, x: typing.Tuple[int, int]) -> int: ...
2403+
2404+
2405+
@typing.overload
2406+
def f(x: int, y: int) -> int: ...
2407+
@typing.overload
2408+
def f(x: typing.Tuple[int, int]) -> int: ...
2409+
2410+
2411+
[case testOverload_importTypingAs]
2412+
import typing as t
2413+
2414+
class A:
2415+
@t.overload
2416+
def f(self, x: int, y: int) -> int:
2417+
...
2418+
2419+
@t.overload
2420+
def f(self, x: t.Tuple[int, int]) -> int:
2421+
...
2422+
2423+
def f(self, *args: typing.Union[int, t.Tuple[int, int]]) -> int:
2424+
pass
2425+
2426+
@t.overload
2427+
@classmethod
2428+
def g(cls, x: int, y: int) -> int:
2429+
...
2430+
2431+
@t.overload
2432+
@classmethod
2433+
def g(cls, x: t.Tuple[int, int]) -> int:
2434+
...
2435+
2436+
@classmethod
2437+
def g(self, *args: t.Union[int, t.Tuple[int, int]]) -> int:
2438+
pass
2439+
2440+
@t.overload
2441+
def f(x: int, y: int) -> int:
2442+
...
2443+
2444+
@t.overload
2445+
def f(x: t.Tuple[int, int]) -> int:
2446+
...
2447+
2448+
def f(*args: t.Union[int, t.Tuple[int, int]]) -> int:
2449+
pass
2450+
2451+
2452+
[out]
2453+
import typing as t
2454+
2455+
class A:
2456+
@t.overload
2457+
def f(self, x: int, y: int) -> int: ...
2458+
@t.overload
2459+
def f(self, x: t.Tuple[int, int]) -> int: ...
2460+
@t.overload
2461+
@classmethod
2462+
def g(cls, x: int, y: int) -> int: ...
2463+
@t.overload
2464+
@classmethod
2465+
def g(cls, x: t.Tuple[int, int]) -> int: ...
2466+
2467+
2468+
@t.overload
2469+
def f(x: int, y: int) -> int: ...
2470+
@t.overload
2471+
def f(x: t.Tuple[int, int]) -> int: ...

0 commit comments

Comments
 (0)