From 9e2ab27fd395d72eaf2731298064fa63c05acb50 Mon Sep 17 00:00:00 2001 From: Pranav Rajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Tue, 20 Jul 2021 13:45:16 -0700 Subject: [PATCH 1/8] Add test Add a test that makes sure that our implementation goes up the MRO of the dispatch type and uses the registered implementation that comes first, instead of going through implementations in the reverse order. --- mypyc/test-data/run-singledispatch.test | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mypyc/test-data/run-singledispatch.test b/mypyc/test-data/run-singledispatch.test index cd5d48f32aaf..8e278c52fef6 100644 --- a/mypyc/test-data/run-singledispatch.test +++ b/mypyc/test-data/run-singledispatch.test @@ -465,3 +465,25 @@ def h(arg: str) -> int: def test_singledispatch(): assert f('a') == 35 assert f(A()) == 10 + +[case testMoreSpecificTypeBeforeLessSpecificType] +from functools import singledispatch +class A: pass +class B(A): pass + +@singledispatch +def f(arg) -> str: + return 'default' + +@f.register +def g(arg: B) -> str: + return 'b' + +@f.register +def h(arg: A) -> str: + return 'a' + +def test_singledispatch(): + assert f(B()) == 'b' + assert f(A()) == 'a' + assert f(5) == 'default' From c03ef11fb92b62ebaee2240f8562136a57dd9c81 Mon Sep 17 00:00:00 2001 From: Pranav Rajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Tue, 20 Jul 2021 14:00:07 -0700 Subject: [PATCH 2/8] Check subclasses before their superclasses When creating a dispatch function for a singledispatch function, avoid checking any classes before we've checked all of their subclasses. That makes sure we use the implementation with the dispatch type that appears first in the actual argument's type's MRO, which is what the standard library implementation of singledispatch does (it goes up the argument's MRO, looking for classes that have registered implementations associated with them). --- mypyc/irbuild/function.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 1b68e5b59029..ba469cc7c40f 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -10,7 +10,7 @@ instance of the callable class. """ -from typing import NamedTuple, Optional, List, Sequence, Tuple, Union, Dict +from typing import Iterable, NamedTuple, Optional, List, Sequence, Tuple, Union, Dict from mypy.nodes import ( ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr, @@ -822,9 +822,10 @@ def gen_func_call_and_return(func_name: str) -> None: coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line) builder.nonlocal_control[-1].gen_return(builder, coerced, line) - # Reverse the list of registered implementations so we use the implementations defined later - # if there are multiple overlapping implementations - for dispatch_type, impl in reversed(impls): + # Sort the list of implementations so that we check any subclasses before we check the classes + # they inherit from, to better match singledispatch's behavior of going through the argument's + # MRO, and using the first implementation it finds + for dispatch_type, impl in sort_with_subclasses_first(impls): call_impl, next_impl = BasicBlock(), BasicBlock() should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line) builder.add_bool_branch(should_call_impl, call_impl, next_impl) @@ -857,3 +858,24 @@ def gen_dispatch_func_ir( func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig) dispatch_func_ir = FuncIR(func_decl, args, blocks) return dispatch_func_ir + + +def sort_with_subclasses_first( + impls: List[Tuple[TypeInfo, FuncDef]] +) -> Iterable[Tuple[TypeInfo, FuncDef]]: + def is_subclass(typ1: TypeInfo, typ2: TypeInfo) -> bool: + return typ2 in typ1.mro + dispatch_types: List[TypeInfo] = [] + funcs: List[FuncDef] = [] + for typ, impl in impls: + # If this type is a subclass of anything we've seen previously, put it in the front so it + # gets checked first + if any(is_subclass(typ, prev_type) for prev_type in dispatch_types): + dispatch_types.insert(0, typ) + funcs.insert(0, impl) + # Otherwise, this type isn't related to any of the other dispatch types, so we can just put + # it at the back + else: + dispatch_types.append(typ) + funcs.append(impl) + return zip(dispatch_types, funcs) From d5f9f36bac2c11a36052d6518d628e891b62cf93 Mon Sep 17 00:00:00 2001 From: Pranav Rajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Thu, 22 Jul 2021 13:09:43 -0700 Subject: [PATCH 3/8] Add test to reproduce incorrect class sorting Add a test to reproduce the bug in the class sorting implementation mentioned in https://github.com/python/mypy/pull/10844#discussion_r674721866. --- mypyc/test-data/run-singledispatch.test | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mypyc/test-data/run-singledispatch.test b/mypyc/test-data/run-singledispatch.test index 8e278c52fef6..5d38e27272af 100644 --- a/mypyc/test-data/run-singledispatch.test +++ b/mypyc/test-data/run-singledispatch.test @@ -487,3 +487,28 @@ def test_singledispatch(): assert f(B()) == 'b' assert f(A()) == 'a' assert f(5) == 'default' + +[case testMultipleRelatedClassesBeingRegistered] +from functools import singledispatch + +class A: pass +class B(A): pass +class C(B): pass + +@singledispatch +def f(arg) -> str: return 'default' + +@f.register +def _(arg: A) -> str: return 'a' + +@f.register +def _(arg: C) -> str: return 'c' + +@f.register +def _(arg: B) -> str: return 'b' + +def test_singledispatch(): + assert f(A()) == 'a' + assert f(B()) == 'b' + assert f(C()) == 'c' + assert f(1) == 'default' From abfe44c05810f350e318345844de258297d00dae Mon Sep 17 00:00:00 2001 From: Pranav Rajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Thu, 22 Jul 2021 13:13:40 -0700 Subject: [PATCH 4/8] Improve class sorting implementation Improve the sorting of classes to: - always put subclasses before any of their superclasses, even when there are multiple related classes (see testMultipleRelatedClassesBeingRegistered for an example of this) - preserve the order of implementations with unrelated dispatch types --- mypyc/irbuild/function.py | 52 +++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index ba469cc7c40f..dfebe0228b55 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -862,20 +862,40 @@ def gen_dispatch_func_ir( def sort_with_subclasses_first( impls: List[Tuple[TypeInfo, FuncDef]] -) -> Iterable[Tuple[TypeInfo, FuncDef]]: - def is_subclass(typ1: TypeInfo, typ2: TypeInfo) -> bool: - return typ2 in typ1.mro - dispatch_types: List[TypeInfo] = [] - funcs: List[FuncDef] = [] +) -> List[Tuple[TypeInfo, FuncDef]]: + def is_related(typ1: TypeInfo, typ2: TypeInfo) -> bool: + return typ2 in typ1.mro or typ2 in typ1.mro + + def overlapping_with_related_impl_list(typ: TypeInfo, func: FuncDef) -> bool: + """Place the dispatch type and registered function in the correct related_types list + + Returns True if the type and function were placed in a related_types list, and False if + they weren't + """ + for group in related_types: + if any(is_related(prev_type, typ) for prev_type, _ in group): + group.append((typ, func)) + return True + for i, (prev_type, impl) in enumerate(unrelated_types): + if is_related(typ, prev_type): + related_types.append([(prev_type, impl), (typ, func)]) + del unrelated_types[i] + return True + return False + + # a list of impls with dispatch types that are unrelated (none of the dispatch types overlap) + unrelated_types: List[Tuple[TypeInfo, FuncDef]] = [] + # each inner list of impls is a collection of impls that have overlapping dispatch types + related_types: List[List[Tuple[TypeInfo, FuncDef]]] = [] for typ, impl in impls: - # If this type is a subclass of anything we've seen previously, put it in the front so it - # gets checked first - if any(is_subclass(typ, prev_type) for prev_type in dispatch_types): - dispatch_types.insert(0, typ) - funcs.insert(0, impl) - # Otherwise, this type isn't related to any of the other dispatch types, so we can just put - # it at the back - else: - dispatch_types.append(typ) - funcs.append(impl) - return zip(dispatch_types, funcs) + if not overlapping_with_related_impl_list(typ, impl): + unrelated_types.append((typ, impl)) + sorted_impls = unrelated_types + for group in related_types: + # Classes with the longest MRO should be checked first because they have more superclasses, + # meaning they are probably subclasses of classes with shorter MROs + # TODO: check if that's always true (especially when dealing with multiple classes having a + # common base class) + group.sort(key=lambda impl: len(impl[0].mro), reverse=True) + sorted_impls.extend(group) + return sorted_impls From d5921e4e24649770b0848c6bb68a800d72b9d8a5 Mon Sep 17 00:00:00 2001 From: Pranav Rajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Thu, 22 Jul 2021 13:34:45 -0700 Subject: [PATCH 5/8] Remove unused Iterable import The one use of Iterable was removed in the previous commit, so this is unnecessary. --- mypyc/irbuild/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index dfebe0228b55..6810435795c0 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -10,7 +10,7 @@ instance of the callable class. """ -from typing import Iterable, NamedTuple, Optional, List, Sequence, Tuple, Union, Dict +from typing import NamedTuple, Optional, List, Sequence, Tuple, Union, Dict from mypy.nodes import ( ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr, From ea12de0ddcad2624128beaf3558001319d2dba26 Mon Sep 17 00:00:00 2001 From: Pranav Rajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Thu, 22 Jul 2021 16:57:41 -0700 Subject: [PATCH 6/8] Change class sorting to use topological sort Change the sorting of classes that we use when generating dispatch functions to use topological sort on a graph made up of the classes with edges pointing from subclasses to the classes on their MRO. This also modifies the signature of topsort to take any type for the vertices of the graph. --- mypy/build.py | 17 ++++++------ mypyc/irbuild/function.py | 57 +++++++++++++-------------------------- 2 files changed, 28 insertions(+), 46 deletions(-) diff --git a/mypy/build.py b/mypy/build.py index bafdf6b86ba3..b8831b21a4ad 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -22,7 +22,7 @@ import types from typing import (AbstractSet, Any, Dict, Iterable, Iterator, List, Sequence, - Mapping, NamedTuple, Optional, Set, Tuple, Union, Callable, TextIO) + Mapping, NamedTuple, Optional, Set, Tuple, TypeVar, Union, Callable, TextIO) from typing_extensions import ClassVar, Final, TYPE_CHECKING from mypy_extensions import TypedDict @@ -3234,21 +3234,22 @@ def dfs(v: str) -> Iterator[Set[str]]: yield from dfs(v) -def topsort(data: Dict[AbstractSet[str], - Set[AbstractSet[str]]]) -> Iterable[Set[AbstractSet[str]]]: +T = TypeVar("T") + + +def topsort(data: Dict[T, Set[T]]) -> Iterable[Set[T]]: """Topological sort. Args: - data: A map from SCCs (represented as frozen sets of strings) to - sets of SCCs, its dependencies. NOTE: This data structure + data: A map from vertices to all vertices that it has an edge + connecting it to. NOTE: This data structure is modified in place -- for normalization purposes, self-dependencies are removed and entries representing orphans are added. Returns: - An iterator yielding sets of SCCs that have an equivalent - ordering. NOTE: The algorithm doesn't care about the internal - structure of SCCs. + An iterator yielding sets of vertices that have an equivalent + ordering. Example: Suppose the input has the following structure: diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 6810435795c0..0f5adb02ce10 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -10,7 +10,11 @@ instance of the callable class. """ -from typing import NamedTuple, Optional, List, Sequence, Tuple, Union, Dict +from collections import defaultdict +from mypy.build import topsort +from typing import ( + NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, DefaultDict, Iterator, Set, +) from mypy.nodes import ( ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr, @@ -862,40 +866,17 @@ def gen_dispatch_func_ir( def sort_with_subclasses_first( impls: List[Tuple[TypeInfo, FuncDef]] -) -> List[Tuple[TypeInfo, FuncDef]]: - def is_related(typ1: TypeInfo, typ2: TypeInfo) -> bool: - return typ2 in typ1.mro or typ2 in typ1.mro - - def overlapping_with_related_impl_list(typ: TypeInfo, func: FuncDef) -> bool: - """Place the dispatch type and registered function in the correct related_types list - - Returns True if the type and function were placed in a related_types list, and False if - they weren't - """ - for group in related_types: - if any(is_related(prev_type, typ) for prev_type, _ in group): - group.append((typ, func)) - return True - for i, (prev_type, impl) in enumerate(unrelated_types): - if is_related(typ, prev_type): - related_types.append([(prev_type, impl), (typ, func)]) - del unrelated_types[i] - return True - return False - - # a list of impls with dispatch types that are unrelated (none of the dispatch types overlap) - unrelated_types: List[Tuple[TypeInfo, FuncDef]] = [] - # each inner list of impls is a collection of impls that have overlapping dispatch types - related_types: List[List[Tuple[TypeInfo, FuncDef]]] = [] - for typ, impl in impls: - if not overlapping_with_related_impl_list(typ, impl): - unrelated_types.append((typ, impl)) - sorted_impls = unrelated_types - for group in related_types: - # Classes with the longest MRO should be checked first because they have more superclasses, - # meaning they are probably subclasses of classes with shorter MROs - # TODO: check if that's always true (especially when dealing with multiple classes having a - # common base class) - group.sort(key=lambda impl: len(impl[0].mro), reverse=True) - sorted_impls.extend(group) - return sorted_impls +) -> Iterator[Tuple[TypeInfo, FuncDef]]: + + # graph with edges pointing from every class to their subclasses + graph: DefaultDict[TypeInfo, Set[TypeInfo]] = defaultdict(set) + for impl in impls: + typ, _ = impl + for t in typ.mro[1:]: + graph[typ].add(t) + + dispatch_types = topsort(dict(graph)) + impl_dict = {typ: func for typ, func in impls} + + for group in reversed(list(dispatch_types)): + yield from ((typ, impl_dict[typ]) for typ in group if typ in impl_dict) From c6d55c99a5ded0e438e697712b82c5bba8394f81 Mon Sep 17 00:00:00 2001 From: Pranav Rajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Thu, 22 Jul 2021 17:39:52 -0700 Subject: [PATCH 7/8] Create graph with a dictionary comprehension Use a dictionary comprehension to make the graph creation nicer. --- mypyc/irbuild/function.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 0f5adb02ce10..14d319521763 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -869,13 +869,9 @@ def sort_with_subclasses_first( ) -> Iterator[Tuple[TypeInfo, FuncDef]]: # graph with edges pointing from every class to their subclasses - graph: DefaultDict[TypeInfo, Set[TypeInfo]] = defaultdict(set) - for impl in impls: - typ, _ = impl - for t in typ.mro[1:]: - graph[typ].add(t) + graph = {typ: set(typ.mro[1:]) for typ, _ in impls} - dispatch_types = topsort(dict(graph)) + dispatch_types = topsort(graph) impl_dict = {typ: func for typ, func in impls} for group in reversed(list(dispatch_types)): From cea7e917fcdf604fcd03ebe20788991065e809b4 Mon Sep 17 00:00:00 2001 From: Pranav Rajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Thu, 22 Jul 2021 17:51:29 -0700 Subject: [PATCH 8/8] Remove unnecessary imports --- mypyc/irbuild/function.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 14d319521763..c4e99646eb03 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -10,10 +10,9 @@ instance of the callable class. """ -from collections import defaultdict from mypy.build import topsort from typing import ( - NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, DefaultDict, Iterator, Set, + NamedTuple, Optional, List, Sequence, Tuple, Union, Dict, Iterator, ) from mypy.nodes import (