Skip to content

Commit 97b3b90

Browse files
authored
[mypyc] Fix order of dispatch type checking in singledispatch functions (#10844)
* Add test Add a test that makes sure that our implementation goes up the MRO of the dispatch type and uses the registered implementation that comes first, instead of going through implementations in the reverse order. * Check subclasses before their superclasses When creating a dispatch function for a singledispatch function, avoid checking any classes before we've checked all of their subclasses. That makes sure we use the implementation with the dispatch type that appears first in the actual argument's type's MRO, which is what the standard library implementation of singledispatch does (it goes up the argument's MRO, looking for classes that have registered implementations associated with them). * Add test to reproduce incorrect class sorting Add a test to reproduce the bug in the class sorting implementation mentioned in #10844 (comment). * Improve class sorting implementation Improve the sorting of classes to: - always put subclasses before any of their superclasses, even when there are multiple related classes (see testMultipleRelatedClassesBeingRegistered for an example of this) - preserve the order of implementations with unrelated dispatch types * Remove unused Iterable import The one use of Iterable was removed in the previous commit, so this is unnecessary. * Change class sorting to use topological sort Change the sorting of classes that we use when generating dispatch functions to use topological sort on a graph made up of the classes with edges pointing from subclasses to the classes on their MRO. This also modifies the signature of topsort to take any type for the vertices of the graph. * Create graph with a dictionary comprehension Use a dictionary comprehension to make the graph creation nicer. * Remove unnecessary imports
1 parent fb75cdc commit 97b3b90

File tree

3 files changed

+78
-12
lines changed

3 files changed

+78
-12
lines changed

mypy/build.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import types
2323

2424
from typing import (AbstractSet, Any, Dict, Iterable, Iterator, List, Sequence,
25-
Mapping, NamedTuple, Optional, Set, Tuple, Union, Callable, TextIO)
25+
Mapping, NamedTuple, Optional, Set, Tuple, TypeVar, Union, Callable, TextIO)
2626
from typing_extensions import ClassVar, Final, TYPE_CHECKING
2727
from mypy_extensions import TypedDict
2828

@@ -3234,21 +3234,22 @@ def dfs(v: str) -> Iterator[Set[str]]:
32343234
yield from dfs(v)
32353235

32363236

3237-
def topsort(data: Dict[AbstractSet[str],
3238-
Set[AbstractSet[str]]]) -> Iterable[Set[AbstractSet[str]]]:
3237+
T = TypeVar("T")
3238+
3239+
3240+
def topsort(data: Dict[T, Set[T]]) -> Iterable[Set[T]]:
32393241
"""Topological sort.
32403242
32413243
Args:
3242-
data: A map from SCCs (represented as frozen sets of strings) to
3243-
sets of SCCs, its dependencies. NOTE: This data structure
3244+
data: A map from vertices to all vertices that it has an edge
3245+
connecting it to. NOTE: This data structure
32443246
is modified in place -- for normalization purposes,
32453247
self-dependencies are removed and entries representing
32463248
orphans are added.
32473249
32483250
Returns:
3249-
An iterator yielding sets of SCCs that have an equivalent
3250-
ordering. NOTE: The algorithm doesn't care about the internal
3251-
structure of SCCs.
3251+
An iterator yielding sets of vertices that have an equivalent
3252+
ordering.
32523253
32533254
Example:
32543255
Suppose the input has the following structure:

mypyc/irbuild/function.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
instance of the callable class.
1111
"""
1212

13-
from typing import NamedTuple, Optional, List, Sequence, Tuple, Union, Dict
13+
from mypy.build import topsort
14+
from typing import (
15+
NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, Iterator,
16+
)
1417

1518
from mypy.nodes import (
1619
ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr,
@@ -822,9 +825,10 @@ def gen_func_call_and_return(func_name: str) -> None:
822825
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
823826
builder.nonlocal_control[-1].gen_return(builder, coerced, line)
824827

825-
# Reverse the list of registered implementations so we use the implementations defined later
826-
# if there are multiple overlapping implementations
827-
for dispatch_type, impl in reversed(impls):
828+
# Sort the list of implementations so that we check any subclasses before we check the classes
829+
# they inherit from, to better match singledispatch's behavior of going through the argument's
830+
# MRO, and using the first implementation it finds
831+
for dispatch_type, impl in sort_with_subclasses_first(impls):
828832
call_impl, next_impl = BasicBlock(), BasicBlock()
829833
should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line)
830834
builder.add_bool_branch(should_call_impl, call_impl, next_impl)
@@ -857,3 +861,17 @@ def gen_dispatch_func_ir(
857861
func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig)
858862
dispatch_func_ir = FuncIR(func_decl, args, blocks)
859863
return dispatch_func_ir
864+
865+
866+
def sort_with_subclasses_first(
867+
impls: List[Tuple[TypeInfo, FuncDef]]
868+
) -> Iterator[Tuple[TypeInfo, FuncDef]]:
869+
870+
# graph with edges pointing from every class to their subclasses
871+
graph = {typ: set(typ.mro[1:]) for typ, _ in impls}
872+
873+
dispatch_types = topsort(graph)
874+
impl_dict = {typ: func for typ, func in impls}
875+
876+
for group in reversed(list(dispatch_types)):
877+
yield from ((typ, impl_dict[typ]) for typ in group if typ in impl_dict)

mypyc/test-data/run-singledispatch.test

+47
Original file line numberDiff line numberDiff line change
@@ -465,3 +465,50 @@ def h(arg: str) -> int:
465465
def test_singledispatch():
466466
assert f('a') == 35
467467
assert f(A()) == 10
468+
469+
[case testMoreSpecificTypeBeforeLessSpecificType]
470+
from functools import singledispatch
471+
class A: pass
472+
class B(A): pass
473+
474+
@singledispatch
475+
def f(arg) -> str:
476+
return 'default'
477+
478+
@f.register
479+
def g(arg: B) -> str:
480+
return 'b'
481+
482+
@f.register
483+
def h(arg: A) -> str:
484+
return 'a'
485+
486+
def test_singledispatch():
487+
assert f(B()) == 'b'
488+
assert f(A()) == 'a'
489+
assert f(5) == 'default'
490+
491+
[case testMultipleRelatedClassesBeingRegistered]
492+
from functools import singledispatch
493+
494+
class A: pass
495+
class B(A): pass
496+
class C(B): pass
497+
498+
@singledispatch
499+
def f(arg) -> str: return 'default'
500+
501+
@f.register
502+
def _(arg: A) -> str: return 'a'
503+
504+
@f.register
505+
def _(arg: C) -> str: return 'c'
506+
507+
@f.register
508+
def _(arg: B) -> str: return 'b'
509+
510+
def test_singledispatch():
511+
assert f(A()) == 'a'
512+
assert f(B()) == 'b'
513+
assert f(C()) == 'c'
514+
assert f(1) == 'default'

0 commit comments

Comments
 (0)