Skip to content

Add fun from inferred clauses #14597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/elixir/lib/module/types/apply.ex
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ defmodule Module.Types.Apply do
{union(type, fun_from_non_overlapping_clauses(clauses)), fallback?, context}

{{:infer, _, clauses}, context} when length(clauses) <= @max_clauses ->
{union(type, fun_from_overlapping_clauses(clauses)), fallback?, context}
{union(type, fun_from_inferred_clauses(clauses)), fallback?, context}

{_, context} ->
{type, true, context}
Expand Down Expand Up @@ -705,7 +705,7 @@ defmodule Module.Types.Apply do
result =
case info do
{:infer, _, clauses} when length(clauses) <= @max_clauses ->
fun_from_overlapping_clauses(clauses)
fun_from_inferred_clauses(clauses)

_ ->
dynamic(fun(arity))
Expand Down
25 changes: 13 additions & 12 deletions lib/elixir/lib/module/types/descr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,17 @@ defmodule Module.Types.Descr do
@doc """
Creates a function from overlapping function clauses.
"""
def fun_from_overlapping_clauses(args_clauses) do
def fun_from_inferred_clauses(args_clauses) do
domain_clauses =
Enum.reduce(args_clauses, [], fn {args, return}, acc ->
pivot_overlapping_clause(args_to_domain(args), return, acc)
domain = args |> Enum.map(&upper_bound/1) |> args_to_domain()
pivot_overlapping_clause(domain, upper_bound(return), acc)
end)

funs =
for {domain, return} <- domain_clauses,
args <- domain_to_args(domain),
do: fun(args, return)
do: fun(args, dynamic(return))

Enum.reduce(funs, &intersection/2)
end
Expand Down Expand Up @@ -200,19 +201,19 @@ defmodule Module.Types.Descr do
def domain_to_args(descr) do
case :maps.take(:dynamic, descr) do
:error ->
tuple_elim_negations_static(descr, &Function.identity/1)
unwrap_domain_tuple(descr, fn {:closed, elems} -> elems end)

{dynamic, static} ->
tuple_elim_negations_static(static, &Function.identity/1) ++
tuple_elim_negations_static(dynamic, fn elems -> Enum.map(elems, &dynamic/1) end)
unwrap_domain_tuple(static, fn {:closed, elems} -> elems end) ++
unwrap_domain_tuple(dynamic, fn {:closed, elems} -> Enum.map(elems, &dynamic/1) end)
end
end

defp tuple_elim_negations_static(%{tuple: dnf} = descr, transform) when map_size(descr) == 1 do
Enum.map(dnf, fn {:closed, elements} -> transform.(elements) end)
defp unwrap_domain_tuple(%{tuple: dnf} = descr, transform) when map_size(descr) == 1 do
Enum.map(dnf, transform)
end

defp tuple_elim_negations_static(descr, _transform) when descr == %{}, do: []
defp unwrap_domain_tuple(descr, _transform) when descr == %{}, do: []

defp domain_to_flat_args(domain, arity) do
case domain_to_args(domain) do
Expand Down Expand Up @@ -2115,9 +2116,6 @@ defmodule Module.Types.Descr do

defp dynamic_to_quoted(descr, opts) do
cond do
descr == %{} ->
[]

# We check for :term literally instead of using term_type?
# because we check for term_type? in to_quoted before we
# compute the difference(dynamic, static).
Expand All @@ -2127,6 +2125,9 @@ defmodule Module.Types.Descr do
single = indivisible_bitmap(descr, opts) ->
[single]

empty?(descr) ->
[]

true ->
case non_term_type_to_quoted(descr, opts) do
{:none, _meta, []} = none -> [none]
Expand Down
8 changes: 6 additions & 2 deletions lib/elixir/lib/module/types/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ defmodule Module.Types.Expr do
add_inferred(acc, args, body)
end)

{fun_from_overlapping_clauses(acc), context}
{fun_from_inferred_clauses(acc), context}
end
end

Expand Down Expand Up @@ -461,7 +461,11 @@ defmodule Module.Types.Expr do
{args_types, context} =
Enum.map_reduce(args, context, &of_expr(&1, @pending, &1, stack, &2))

Apply.fun_apply(fun_type, args_types, call, stack, context)
if stack.mode == :traversal do
{dynamic(), context}
else
Apply.fun_apply(fun_type, args_types, call, stack, context)
end
end

def of_expr({{:., _, [callee, key_or_fun]}, meta, []} = call, expected, expr, stack, context)
Expand Down
68 changes: 43 additions & 25 deletions lib/elixir/test/elixir/module/types/descr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -767,54 +767,72 @@ defmodule Module.Types.DescrTest do
intersection(fun([integer()], atom()), fun([float()], binary()))
end

test "fun_from_overlapping_clauses" do
test "fun_from_inferred_clauses" do
# No overlap
assert fun_from_overlapping_clauses([{[integer()], atom()}, {[float()], binary()}])
assert fun_from_inferred_clauses([{[integer()], atom()}, {[float()], binary()}])
|> equal?(
fun_from_non_overlapping_clauses([{[integer()], atom()}, {[float()], binary()}])
intersection(
fun_from_non_overlapping_clauses([{[integer()], atom()}, {[float()], binary()}]),
fun([number()], dynamic())
)
)

# Subsets
assert fun_from_overlapping_clauses([{[integer()], atom()}, {[number()], binary()}])
assert fun_from_inferred_clauses([{[integer()], atom()}, {[number()], binary()}])
|> equal?(
fun_from_non_overlapping_clauses([
{[integer()], union(atom(), binary())},
{[float()], binary()}
])
intersection(
fun_from_non_overlapping_clauses([
{[integer()], union(atom(), binary())},
{[float()], binary()}
]),
fun([number()], dynamic())
)
)

assert fun_from_overlapping_clauses([{[number()], binary()}, {[integer()], atom()}])
assert fun_from_inferred_clauses([{[number()], binary()}, {[integer()], atom()}])
|> equal?(
fun_from_non_overlapping_clauses([
{[integer()], union(atom(), binary())},
{[float()], binary()}
])
intersection(
fun_from_non_overlapping_clauses([
{[integer()], union(atom(), binary())},
{[float()], binary()}
]),
fun([number()], dynamic())
)
)

# Partial
assert fun_from_overlapping_clauses([
assert fun_from_inferred_clauses([
{[union(integer(), pid())], atom()},
{[union(float(), pid())], binary()}
])
|> equal?(
fun_from_non_overlapping_clauses([
{[integer()], atom()},
{[float()], binary()},
{[pid()], union(atom(), binary())}
])
intersection(
fun_from_non_overlapping_clauses([
{[integer()], atom()},
{[float()], binary()},
{[pid()], union(atom(), binary())}
]),
fun([union(number(), pid())], dynamic())
)
)

# Difference
assert fun_from_overlapping_clauses([
assert fun_from_inferred_clauses([
{[integer(), union(pid(), atom())], atom()},
{[number(), pid()], binary()}
])
|> equal?(
fun_from_non_overlapping_clauses([
{[float(), pid()], binary()},
{[integer(), atom()], atom()},
{[integer(), pid()], union(atom(), binary())}
])
intersection(
fun_from_non_overlapping_clauses([
{[float(), pid()], binary()},
{[integer(), atom()], atom()},
{[integer(), pid()], union(atom(), binary())}
]),
fun_from_non_overlapping_clauses([
{[integer(), union(pid(), atom())], dynamic()},
{[number(), pid()], dynamic()}
])
)
)
end
end
Expand Down
31 changes: 22 additions & 9 deletions lib/elixir/test/elixir/module/types/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,37 @@ defmodule Module.Types.ExprTest do
end

test "infers functions" do
assert typecheck!(& &1) == fun([dynamic()], dynamic())
assert typecheck!(fn -> :ok end) == fun([], atom([:ok]))
assert typecheck!(& &1) |> equal?(fun([term()], dynamic()))

assert typecheck!(fn -> :ok end) |> equal?(fun([], dynamic(atom([:ok]))))

assert typecheck!(fn
<<"ok">>, {} -> :ok
<<"error">>, {} -> :error
[_ | _], %{} -> :list
end) ==
end)
|> equal?(
intersection(
fun(
[dynamic(non_empty_list(term(), term())), dynamic(open_map())],
atom([:list])
[non_empty_list(term(), term()), open_map()],
dynamic(atom([:list]))
),
fun(
[dynamic(binary()), dynamic(tuple([]))],
atom([:ok, :error])
[binary(), tuple([])],
dynamic(atom([:ok, :error]))
)
)
)
end

test "application" do
assert typecheck!(
[map],
(fn
%{a: a} = data -> %{data | b: a}
data -> data
end).(map)
) == dynamic()
end

test "bad function" do
Expand Down Expand Up @@ -253,7 +266,7 @@ defmodule Module.Types.ExprTest do

but function has type:

(dynamic(map()) -> :map)
(map() -> dynamic(:map))
"""
end

Expand All @@ -265,7 +278,7 @@ defmodule Module.Types.ExprTest do

because the right-hand side has type:

(dynamic() -> dynamic({:ok, term()}))
(term() -> dynamic({:ok, term()}))
"""
end
end
Expand Down
4 changes: 2 additions & 2 deletions lib/elixir/test/elixir/module/types/integration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ defmodule Module.Types.IntegrationTest do
assert return.(:captured, 0)
|> equal?(
fun_from_non_overlapping_clauses([
{[dynamic(binary())], atom([:ok, :error])},
{[dynamic(non_empty_list(term(), term()))], atom([:list])}
{[binary()], dynamic(atom([:ok, :error]))},
{[non_empty_list(term(), term())], dynamic(atom([:list]))}
])
)
end
Expand Down
Loading