Skip to content

Commit 0871c93

Browse files
authored
Add support for functools.partial (#16939)
Fixes #1484 Turns out that this is currently the second most popular mypy issue (and first most popular is a type system feature request that would need a PEP). I'm sure there's stuff missing, but this should handle most cases.
1 parent ca393dd commit 0871c93

File tree

9 files changed

+454
-27
lines changed

9 files changed

+454
-27
lines changed

mypy/checkexpr.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,14 +1229,14 @@ def apply_function_plugin(
12291229
assert callback is not None # Assume that caller ensures this
12301230
return callback(
12311231
FunctionContext(
1232-
formal_arg_types,
1233-
formal_arg_kinds,
1234-
callee.arg_names,
1235-
formal_arg_names,
1236-
callee.ret_type,
1237-
formal_arg_exprs,
1238-
context,
1239-
self.chk,
1232+
arg_types=formal_arg_types,
1233+
arg_kinds=formal_arg_kinds,
1234+
callee_arg_names=callee.arg_names,
1235+
arg_names=formal_arg_names,
1236+
default_return_type=callee.ret_type,
1237+
args=formal_arg_exprs,
1238+
context=context,
1239+
api=self.chk,
12401240
)
12411241
)
12421242
else:
@@ -1246,15 +1246,15 @@ def apply_function_plugin(
12461246
object_type = get_proper_type(object_type)
12471247
return method_callback(
12481248
MethodContext(
1249-
object_type,
1250-
formal_arg_types,
1251-
formal_arg_kinds,
1252-
callee.arg_names,
1253-
formal_arg_names,
1254-
callee.ret_type,
1255-
formal_arg_exprs,
1256-
context,
1257-
self.chk,
1249+
type=object_type,
1250+
arg_types=formal_arg_types,
1251+
arg_kinds=formal_arg_kinds,
1252+
callee_arg_names=callee.arg_names,
1253+
arg_names=formal_arg_names,
1254+
default_return_type=callee.ret_type,
1255+
args=formal_arg_exprs,
1256+
context=context,
1257+
api=self.chk,
12581258
)
12591259
)
12601260

mypy/fixup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ def visit_instance(self, inst: Instance) -> None:
239239
a.accept(self)
240240
if inst.last_known_value is not None:
241241
inst.last_known_value.accept(self)
242+
if inst.extra_attrs:
243+
for v in inst.extra_attrs.attrs.values():
244+
v.accept(self)
242245

243246
def visit_type_alias_type(self, t: TypeAliasType) -> None:
244247
type_ref = t.type_ref

mypy/plugins/default.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
4747
return ctypes.array_constructor_callback
4848
elif fullname == "functools.singledispatch":
4949
return singledispatch.create_singledispatch_function_callback
50+
elif fullname == "functools.partial":
51+
import mypy.plugins.functools
52+
53+
return mypy.plugins.functools.partial_new_callback
5054

5155
return None
5256

@@ -118,6 +122,10 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
118122
return singledispatch.singledispatch_register_callback
119123
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
120124
return singledispatch.call_singledispatch_function_after_register_argument
125+
elif fullname == "functools.partial.__call__":
126+
import mypy.plugins.functools
127+
128+
return mypy.plugins.functools.partial_call_callback
121129
return None
122130

123131
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
@@ -155,12 +163,13 @@ def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext],
155163
def get_class_decorator_hook_2(
156164
self, fullname: str
157165
) -> Callable[[ClassDefContext], bool] | None:
158-
from mypy.plugins import attrs, dataclasses, functools
166+
import mypy.plugins.functools
167+
from mypy.plugins import attrs, dataclasses
159168

160169
if fullname in dataclasses.dataclass_makers:
161170
return dataclasses.dataclass_class_maker_callback
162-
elif fullname in functools.functools_total_ordering_makers:
163-
return functools.functools_total_ordering_maker_callback
171+
elif fullname in mypy.plugins.functools.functools_total_ordering_makers:
172+
return mypy.plugins.functools.functools_total_ordering_maker_callback
164173
elif fullname in attrs.attr_class_makers:
165174
return attrs.attr_class_maker_callback
166175
elif fullname in attrs.attr_dataclass_makers:

mypy/plugins/functools.py

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,22 @@
44

55
from typing import Final, NamedTuple
66

7+
import mypy.checker
78
import mypy.plugin
8-
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
9+
from mypy.argmap import map_actuals_to_formals
10+
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
911
from mypy.plugins.common import add_method_to_class
10-
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
12+
from mypy.types import (
13+
AnyType,
14+
CallableType,
15+
Instance,
16+
Overloaded,
17+
Type,
18+
TypeOfAny,
19+
UnboundType,
20+
UninhabitedType,
21+
get_proper_type,
22+
)
1123

1224
functools_total_ordering_makers: Final = {"functools.total_ordering"}
1325

@@ -102,3 +114,131 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
102114
comparison_methods[name] = None
103115

104116
return comparison_methods
117+
118+
119+
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
120+
"""Infer a more precise return type for functools.partial"""
121+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
122+
return ctx.default_return_type
123+
if len(ctx.arg_types) != 3: # fn, *args, **kwargs
124+
return ctx.default_return_type
125+
if len(ctx.arg_types[0]) != 1:
126+
return ctx.default_return_type
127+
128+
if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded):
129+
# TODO: handle overloads, just fall back to whatever the non-plugin code does
130+
return ctx.default_return_type
131+
fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type)
132+
if fn_type is None:
133+
return ctx.default_return_type
134+
135+
defaulted = fn_type.copy_modified(
136+
arg_kinds=[
137+
(
138+
ArgKind.ARG_OPT
139+
if k == ArgKind.ARG_POS
140+
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
141+
)
142+
for k in fn_type.arg_kinds
143+
]
144+
)
145+
if defaulted.line < 0:
146+
# Make up a line number if we don't have one
147+
defaulted.set_line(ctx.default_return_type)
148+
149+
actual_args = [a for param in ctx.args[1:] for a in param]
150+
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
151+
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
152+
actual_types = [a for param in ctx.arg_types[1:] for a in param]
153+
154+
_, bound = ctx.api.expr_checker.check_call(
155+
callee=defaulted,
156+
args=actual_args,
157+
arg_kinds=actual_arg_kinds,
158+
arg_names=actual_arg_names,
159+
context=defaulted,
160+
)
161+
bound = get_proper_type(bound)
162+
if not isinstance(bound, CallableType):
163+
return ctx.default_return_type
164+
165+
formal_to_actual = map_actuals_to_formals(
166+
actual_kinds=actual_arg_kinds,
167+
actual_names=actual_arg_names,
168+
formal_kinds=fn_type.arg_kinds,
169+
formal_names=fn_type.arg_names,
170+
actual_arg_type=lambda i: actual_types[i],
171+
)
172+
173+
partial_kinds = []
174+
partial_types = []
175+
partial_names = []
176+
# We need to fully apply any positional arguments (they cannot be respecified)
177+
# However, keyword arguments can be respecified, so just give them a default
178+
for i, actuals in enumerate(formal_to_actual):
179+
if len(bound.arg_types) == len(fn_type.arg_types):
180+
arg_type = bound.arg_types[i]
181+
if isinstance(get_proper_type(arg_type), UninhabitedType):
182+
arg_type = fn_type.arg_types[i] # bit of a hack
183+
else:
184+
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
185+
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
186+
arg_type = fn_type.arg_types[i]
187+
188+
if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
189+
partial_kinds.append(fn_type.arg_kinds[i])
190+
partial_types.append(arg_type)
191+
partial_names.append(fn_type.arg_names[i])
192+
elif actuals:
193+
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
194+
continue
195+
kind = actual_arg_kinds[actuals[0]]
196+
if kind == ArgKind.ARG_NAMED:
197+
kind = ArgKind.ARG_NAMED_OPT
198+
partial_kinds.append(kind)
199+
partial_types.append(arg_type)
200+
partial_names.append(fn_type.arg_names[i])
201+
202+
ret_type = bound.ret_type
203+
if isinstance(get_proper_type(ret_type), UninhabitedType):
204+
ret_type = fn_type.ret_type # same kind of hack as above
205+
206+
partially_applied = fn_type.copy_modified(
207+
arg_types=partial_types,
208+
arg_kinds=partial_kinds,
209+
arg_names=partial_names,
210+
ret_type=ret_type,
211+
)
212+
213+
ret = ctx.api.named_generic_type("functools.partial", [ret_type])
214+
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
215+
return ret
216+
217+
218+
def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
219+
"""Infer a more precise return type for functools.partial.__call__."""
220+
if (
221+
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
222+
or not isinstance(ctx.type, Instance)
223+
or ctx.type.type.fullname != "functools.partial"
224+
or not ctx.type.extra_attrs
225+
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
226+
):
227+
return ctx.default_return_type
228+
229+
partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
230+
if len(ctx.arg_types) != 2: # *args, **kwargs
231+
return ctx.default_return_type
232+
233+
args = [a for param in ctx.args for a in param]
234+
arg_kinds = [a for param in ctx.arg_kinds for a in param]
235+
arg_names = [a for param in ctx.arg_names for a in param]
236+
237+
result = ctx.api.expr_checker.check_call(
238+
callee=partial_type,
239+
args=args,
240+
arg_kinds=arg_kinds,
241+
arg_names=arg_names,
242+
context=ctx.context,
243+
)
244+
return result[0]

mypy/server/astdiff.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,11 +378,20 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
378378
return snapshot_simple_type(typ)
379379

380380
def visit_instance(self, typ: Instance) -> SnapshotItem:
381+
extra_attrs: SnapshotItem
382+
if typ.extra_attrs:
383+
extra_attrs = (
384+
tuple(sorted((k, v.accept(self)) for k, v in typ.extra_attrs.attrs.items())),
385+
tuple(typ.extra_attrs.immutable),
386+
)
387+
else:
388+
extra_attrs = ()
381389
return (
382390
"Instance",
383391
encode_optional_str(typ.type.fullname),
384392
snapshot_types(typ.args),
385393
("None",) if typ.last_known_value is None else snapshot_type(typ.last_known_value),
394+
extra_attrs,
386395
)
387396

388397
def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:

mypy/types.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,23 @@ def copy(self) -> ExtraAttrs:
13221322
def __repr__(self) -> str:
13231323
return f"ExtraAttrs({self.attrs!r}, {self.immutable!r}, {self.mod_name!r})"
13241324

1325+
def serialize(self) -> JsonDict:
1326+
return {
1327+
".class": "ExtraAttrs",
1328+
"attrs": {k: v.serialize() for k, v in self.attrs.items()},
1329+
"immutable": list(self.immutable),
1330+
"mod_name": self.mod_name,
1331+
}
1332+
1333+
@classmethod
1334+
def deserialize(cls, data: JsonDict) -> ExtraAttrs:
1335+
assert data[".class"] == "ExtraAttrs"
1336+
return ExtraAttrs(
1337+
{k: deserialize_type(v) for k, v in data["attrs"].items()},
1338+
set(data["immutable"]),
1339+
data["mod_name"],
1340+
)
1341+
13251342

13261343
class Instance(ProperType):
13271344
"""An instance type of form C[T1, ..., Tn].
@@ -1434,6 +1451,7 @@ def serialize(self) -> JsonDict | str:
14341451
data["args"] = [arg.serialize() for arg in self.args]
14351452
if self.last_known_value is not None:
14361453
data["last_known_value"] = self.last_known_value.serialize()
1454+
data["extra_attrs"] = self.extra_attrs.serialize() if self.extra_attrs else None
14371455
return data
14381456

14391457
@classmethod
@@ -1452,6 +1470,8 @@ def deserialize(cls, data: JsonDict | str) -> Instance:
14521470
inst.type_ref = data["type_ref"] # Will be fixed up by fixup.py later.
14531471
if "last_known_value" in data:
14541472
inst.last_known_value = LiteralType.deserialize(data["last_known_value"])
1473+
if data.get("extra_attrs") is not None:
1474+
inst.extra_attrs = ExtraAttrs.deserialize(data["extra_attrs"])
14551475
return inst
14561476

14571477
def copy_modified(
@@ -1461,13 +1481,14 @@ def copy_modified(
14611481
last_known_value: Bogus[LiteralType | None] = _dummy,
14621482
) -> Instance:
14631483
new = Instance(
1464-
self.type,
1465-
args if args is not _dummy else self.args,
1466-
self.line,
1467-
self.column,
1484+
typ=self.type,
1485+
args=args if args is not _dummy else self.args,
1486+
line=self.line,
1487+
column=self.column,
14681488
last_known_value=(
14691489
last_known_value if last_known_value is not _dummy else self.last_known_value
14701490
),
1491+
extra_attrs=self.extra_attrs,
14711492
)
14721493
# We intentionally don't copy the extra_attrs here, so they will be erased.
14731494
new.can_be_true = self.can_be_true

0 commit comments

Comments
 (0)