Skip to content

Commit 8264f8d

Browse files
committed
Minimal changes to make filter() test pass
1 parent 0d2eb06 commit 8264f8d

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

mypy/constraints.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,12 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
457457
for t, a in zip(template.arg_types, cactual.arg_types):
458458
# Negate direction due to function argument type contravariance.
459459
res.extend(infer_constraints(t, a, neg_op(self.direction)))
460-
res.extend(infer_constraints(template.ret_type, cactual.ret_type,
460+
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
461+
if template.type_guard is not None:
462+
template_ret_type = template.type_guard
463+
if cactual.type_guard is not None:
464+
cactual_ret_type = cactual.type_guard
465+
res.extend(infer_constraints(template_ret_type, cactual_ret_type,
461466
self.direction))
462467
return res
463468
elif isinstance(self.actual, AnyType):

mypy/expandtype.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,12 @@ def visit_type_var(self, t: TypeVarType) -> Type:
9696
return repl
9797

9898
def visit_callable_type(self, t: CallableType) -> Type:
99+
extra = {}
100+
if t.type_guard is not None:
101+
extra['type_guard'] = t.type_guard.accept(self)
99102
return t.copy_modified(arg_types=self.expand_types(t.arg_types),
100-
ret_type=t.ret_type.accept(self))
103+
ret_type=t.ret_type.accept(self),
104+
**extra)
101105

102106
def visit_overloaded(self, t: Overloaded) -> Type:
103107
items = [] # type: List[CallableType]

test-data/unit/check-typeguard.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ R = TypeVar('R')
118118
def filter(f: Callable[[T], TypeGuard[R]], it: Iterable[T]) -> Iterable[R]: pass
119119
def is_float(a: object) -> TypeGuard[float]: pass
120120
a: List[object] = ["a", 0, 0.0]
121-
# TODO: Make this pass
122-
##reveal_type(filter(is_float, a)) ## N: Revealed type is 'typing.Iterable[float]'
121+
b = filter(is_float, a)
122+
reveal_type(b) # N: Revealed type is 'typing.Iterable[builtins.float*]'
123123
[builtins fixtures/tuple.pyi]
124124

125125
[case testTypeGuardWalrus]

0 commit comments

Comments
 (0)