Skip to content

[mypyc] Fix order of dispatch type checking in singledispatch functions #10844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import types

from typing import (AbstractSet, Any, Dict, Iterable, Iterator, List, Sequence,
Mapping, NamedTuple, Optional, Set, Tuple, Union, Callable, TextIO)
Mapping, NamedTuple, Optional, Set, Tuple, TypeVar, Union, Callable, TextIO)
from typing_extensions import ClassVar, Final, TYPE_CHECKING
from mypy_extensions import TypedDict

Expand Down Expand Up @@ -3234,21 +3234,22 @@ def dfs(v: str) -> Iterator[Set[str]]:
yield from dfs(v)


def topsort(data: Dict[AbstractSet[str],
Set[AbstractSet[str]]]) -> Iterable[Set[AbstractSet[str]]]:
T = TypeVar("T")


def topsort(data: Dict[T, Set[T]]) -> Iterable[Set[T]]:
"""Topological sort.

Args:
data: A map from SCCs (represented as frozen sets of strings) to
sets of SCCs, its dependencies. NOTE: This data structure
data: A map from vertices to all vertices that it has an edge
connecting it to. NOTE: This data structure
is modified in place -- for normalization purposes,
self-dependencies are removed and entries representing
orphans are added.

Returns:
An iterator yielding sets of SCCs that have an equivalent
ordering. NOTE: The algorithm doesn't care about the internal
structure of SCCs.
An iterator yielding sets of vertices that have an equivalent
ordering.

Example:
Suppose the input has the following structure:
Expand Down
26 changes: 22 additions & 4 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
instance of the callable class.
"""

from typing import NamedTuple, Optional, List, Sequence, Tuple, Union, Dict
from mypy.build import topsort
from typing import (
NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, Iterator,
)

from mypy.nodes import (
ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr,
Expand Down Expand Up @@ -822,9 +825,10 @@ def gen_func_call_and_return(func_name: str) -> None:
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
builder.nonlocal_control[-1].gen_return(builder, coerced, line)

# Reverse the list of registered implementations so we use the implementations defined later
# if there are multiple overlapping implementations
for dispatch_type, impl in reversed(impls):
# Sort the list of implementations so that we check any subclasses before we check the classes
# they inherit from, to better match singledispatch's behavior of going through the argument's
# MRO, and using the first implementation it finds
for dispatch_type, impl in sort_with_subclasses_first(impls):
call_impl, next_impl = BasicBlock(), BasicBlock()
should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line)
builder.add_bool_branch(should_call_impl, call_impl, next_impl)
Expand Down Expand Up @@ -857,3 +861,17 @@ def gen_dispatch_func_ir(
func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig)
dispatch_func_ir = FuncIR(func_decl, args, blocks)
return dispatch_func_ir


def sort_with_subclasses_first(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems plausible, but I think that it might be more straightforward to implement this as a topological sort, instead. I'll follow up via PM.

impls: List[Tuple[TypeInfo, FuncDef]]
) -> Iterator[Tuple[TypeInfo, FuncDef]]:

# graph with edges pointing from every class to their subclasses
graph = {typ: set(typ.mro[1:]) for typ, _ in impls}

dispatch_types = topsort(graph)
impl_dict = {typ: func for typ, func in impls}

for group in reversed(list(dispatch_types)):
yield from ((typ, impl_dict[typ]) for typ in group if typ in impl_dict)
47 changes: 47 additions & 0 deletions mypyc/test-data/run-singledispatch.test
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,50 @@ def h(arg: str) -> int:
def test_singledispatch():
assert f('a') == 35
assert f(A()) == 10

[case testMoreSpecificTypeBeforeLessSpecificType]
from functools import singledispatch
class A: pass
class B(A): pass

@singledispatch
def f(arg) -> str:
return 'default'

@f.register
def g(arg: B) -> str:
return 'b'

@f.register
def h(arg: A) -> str:
return 'a'

def test_singledispatch():
assert f(B()) == 'b'
assert f(A()) == 'a'
assert f(5) == 'default'

[case testMultipleRelatedClassesBeingRegistered]
from functools import singledispatch

class A: pass
class B(A): pass
class C(B): pass

@singledispatch
def f(arg) -> str: return 'default'

@f.register
def _(arg: A) -> str: return 'a'

@f.register
def _(arg: C) -> str: return 'c'

@f.register
def _(arg: B) -> str: return 'b'

def test_singledispatch():
assert f(A()) == 'a'
assert f(B()) == 'b'
assert f(C()) == 'c'
assert f(1) == 'default'