Skip to content

Commit 0b47f55

Browse files
committed
Infer types for fn
1 parent 73cf0d3 commit 0b47f55

File tree

2 files changed

+68
-30
lines changed

2 files changed

+68
-30
lines changed

lib/elixir/lib/module/types/expr.ex

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ defmodule Module.Types.Expr do
319319
else
320320
clauses
321321
end
322-
|> of_clauses([case_type], expected, expr, info, stack, {none(), context})
322+
|> of_clauses([case_type], expected, expr, info, stack, context, none())
323323
|> dynamic_unless_static(stack)
324324
end
325325

@@ -328,8 +328,20 @@ defmodule Module.Types.Expr do
328328
[{:->, _, [head, _]} | _] = clauses
329329
{patterns, _guards} = extract_head(head)
330330
domain = Enum.map(patterns, fn _ -> dynamic() end)
331-
{_acc, context} = of_clauses(clauses, domain, @pending, nil, :fn, stack, {none(), context})
332-
{dynamic(fun(length(patterns))), context}
331+
332+
if stack.mode == :traversal do
333+
{_acc, context} = of_clauses(clauses, domain, @pending, nil, :fn, stack, context, none())
334+
{dynamic(fun(length(patterns))), context}
335+
else
336+
{acc, context} =
337+
of_clauses_fun(clauses, domain, @pending, nil, :fn, stack, context, [], fn
338+
trees, body, context, acc ->
339+
args = Pattern.of_domain(trees, domain, context)
340+
add_inferred(acc, args, body)
341+
end)
342+
343+
{fun_from_overlapping_clauses(acc), context}
344+
end
333345
end
334346

335347
def of_expr({:try, _meta, [[do: body] ++ blocks]}, expected, expr, stack, original) do
@@ -340,7 +352,7 @@ defmodule Module.Types.Expr do
340352
if else_block do
341353
{type, context} = of_expr(body, @pending, body, stack, original)
342354
info = {:try_else, type}
343-
of_clauses(else_block, [type], expected, expr, info, stack, {none(), context})
355+
of_clauses(else_block, [type], expected, expr, info, stack, context, none())
344356
else
345357
of_expr(body, expected, expr, stack, original)
346358
end
@@ -364,15 +376,8 @@ defmodule Module.Types.Expr do
364376
end)
365377

366378
{:catch, clauses}, {acc, context} ->
367-
of_clauses(
368-
clauses,
369-
[@try_catch, dynamic()],
370-
expected,
371-
expr,
372-
:try_catch,
373-
stack,
374-
{acc, context}
375-
)
379+
args = [@try_catch, dynamic()]
380+
of_clauses(clauses, args, expected, expr, :try_catch, stack, context, acc)
376381
end)
377382
|> dynamic_unless_static(stack)
378383

@@ -392,8 +397,8 @@ defmodule Module.Types.Expr do
392397
{:do, {:__block__, _, []}}, acc_context ->
393398
acc_context
394399

395-
{:do, clauses}, acc_context ->
396-
of_clauses(clauses, [dynamic()], expected, expr, :receive, stack, acc_context)
400+
{:do, clauses}, {acc, context} ->
401+
of_clauses(clauses, [dynamic()], expected, expr, :receive, stack, context, acc)
397402

398403
{:after, [{:->, meta, [[timeout], body]}] = after_expr}, {acc, context} ->
399404
{timeout_type, context} = of_expr(timeout, @timeout_type, after_expr, stack, context)
@@ -420,7 +425,7 @@ defmodule Module.Types.Expr do
420425
{reduce_type, context} = of_expr(reduce, expected, expr, stack, context)
421426
# TODO: We need to type check against dynamic() instead of using reduce_type
422427
# because this is recursive. We need to infer the block type first.
423-
of_clauses(block, [dynamic()], expected, expr, :for_reduce, stack, {reduce_type, context})
428+
of_clauses(block, [dynamic()], expected, expr, :for_reduce, stack, context, reduce_type)
424429
else
425430
# TODO: Use the collectable protocol for the output
426431
into = Keyword.get(opts, :into, [])
@@ -665,7 +670,7 @@ defmodule Module.Types.Expr do
665670

666671
defp with_option({:else, clauses}, stack, context, _original) do
667672
{_, context} =
668-
of_clauses(clauses, [dynamic()], @pending, nil, :with_else, stack, {none(), context})
673+
of_clauses(clauses, [dynamic()], @pending, nil, :with_else, stack, context, none())
669674

670675
context
671676
end
@@ -723,22 +728,27 @@ defmodule Module.Types.Expr do
723728
defp dynamic_unless_static({_, _} = output, %{mode: :static}), do: output
724729
defp dynamic_unless_static({type, context}, %{mode: _}), do: {dynamic(type), context}
725730

726-
defp of_clauses(clauses, domain, expected, expr, info, %{mode: mode} = stack, {acc, original}) do
731+
defp of_clauses(clauses, domain, expected, expr, info, %{mode: mode} = stack, context, acc) do
732+
fun =
733+
if mode == :traversal do
734+
fn _, _, _, _ -> dynamic() end
735+
else
736+
fn _trees, result, _context, acc -> union(result, acc) end
737+
end
738+
739+
of_clauses_fun(clauses, domain, expected, expr, info, stack, context, acc, fun)
740+
end
741+
742+
defp of_clauses_fun(clauses, domain, expected, expr, info, stack, original, acc, fun) do
727743
%{failed: failed?} = original
728744

729745
Enum.reduce(clauses, {acc, original}, fn {:->, meta, [head, body]}, {acc, context} ->
730746
{failed?, context} = reset_failed(context, failed?)
731747
{patterns, guards} = extract_head(head)
732-
{_trees, context} = Pattern.of_head(patterns, guards, domain, info, meta, stack, context)
748+
{trees, context} = Pattern.of_head(patterns, guards, domain, info, meta, stack, context)
733749

734-
{body, context} = of_expr(body, expected, expr || body, stack, context)
735-
context = context |> set_failed(failed?) |> reset_vars(original)
736-
737-
if mode == :traversal do
738-
{dynamic(), context}
739-
else
740-
{union(acc, body), context}
741-
end
750+
{result, context} = of_expr(body, expected, expr || body, stack, context)
751+
{fun.(trees, result, context, acc), context |> set_failed(failed?) |> reset_vars(original)}
742752
end)
743753
end
744754

@@ -770,6 +780,15 @@ defmodule Module.Types.Expr do
770780
defp repack_match(left_expr, right_expr),
771781
do: {left_expr, right_expr}
772782

783+
defp add_inferred([{args, existing_return} | tail], args, return),
784+
do: [{args, union(existing_return, return)} | tail]
785+
786+
defp add_inferred([head | tail], args, return),
787+
do: [head | add_inferred(tail, args, return)]
788+
789+
defp add_inferred([], args, return),
790+
do: [{args, return}]
791+
773792
## Warning formatting
774793

775794
def format_diagnostic({:badmap, type, expr, context}) do

lib/elixir/test/elixir/module/types/expr_test.exs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ defmodule Module.Types.ExprTest do
2525
assert typecheck!("foo") == binary()
2626
assert typecheck!([]) == empty_list()
2727
assert typecheck!(%{}) == closed_map([])
28-
assert typecheck!(& &1) == dynamic(fun(1))
29-
assert typecheck!(fn -> :ok end) == dynamic(fun(0))
3028
end
3129

3230
test "generated" do
@@ -129,7 +127,7 @@ defmodule Module.Types.ExprTest do
129127
end
130128

131129
describe "funs" do
132-
test "infers funs" do
130+
test "infers calls" do
133131
assert typecheck!(
134132
[x],
135133
(
@@ -139,6 +137,27 @@ defmodule Module.Types.ExprTest do
139137
) == dynamic(fun(2))
140138
end
141139

140+
test "infers functions" do
141+
assert typecheck!(& &1) == fun([dynamic()], dynamic())
142+
assert typecheck!(fn -> :ok end) == fun([], atom([:ok]))
143+
144+
assert typecheck!(fn
145+
<<"ok">>, {} -> :ok
146+
<<"error">>, {} -> :error
147+
[_ | _], %{} -> :list
148+
end) ==
149+
intersection(
150+
fun(
151+
[dynamic(non_empty_list(term(), term())), dynamic(open_map())],
152+
atom([:list])
153+
),
154+
fun(
155+
[dynamic(binary()), dynamic(tuple([]))],
156+
atom([:ok, :error])
157+
)
158+
)
159+
end
160+
142161
test "bad function" do
143162
assert typeerror!([%x{}, a1, a2], x.(a1, a2)) == ~l"""
144163
expected a 2-arity function on call:

0 commit comments

Comments
 (0)