diff --git a/lib/flame/pool.ex b/lib/flame/pool.ex index 6a13397..37e10a5 100644 --- a/lib/flame/pool.ex +++ b/lib/flame/pool.ex @@ -514,30 +514,51 @@ defmodule FLAME.Pool do end end + defp await_downs(child_pids) do + if MapSet.size(child_pids) == 0 do + :ok + else + receive do + {:DOWN, _ref, :process, pid, _reason} -> await_downs(MapSet.delete(child_pids, pid)) + end + end + end + defp replace_caller(%Pool{} = state, checkout_ref, caller_pid, [_ | _] = child_pids) do - # replace caller with child pids and increase concurrency counts for the runner + # replace caller with child pid and do not inc concurrency counts since we are replacing %{^caller_pid => %Caller{checkout_ref: ^checkout_ref} = caller} = state.callers Process.demonitor(caller.monitor_ref, [:flush]) - new_callers = Map.delete(state.callers, caller_pid) + # if we have more than 1 child pid, such as for multiple trackables returned for a single + # call, we monitor all of them under a new process and the new process takes the slot in the + # pool. When all trackables are finished, the new process goes down and frees the slot. + child_pid = + case child_pids do + [child_pid] -> + child_pid + + [_ | _] -> + {:ok, child_pid} = + Task.Supervisor.start_child(state.task_sup, fn -> + Enum.each(child_pids, &Process.monitor(&1)) + await_downs(MapSet.new(child_pids)) + end) + + child_pid + end + + new_caller = %Caller{ + checkout_ref: checkout_ref, + monitor_ref: Process.monitor(child_pid), + runner_ref: caller.runner_ref + } new_callers = - Enum.reduce(child_pids, new_callers, fn child_pid, acc -> - new_caller = %Caller{ - checkout_ref: checkout_ref, - monitor_ref: Process.monitor(child_pid), - runner_ref: caller.runner_ref - } - - Map.put(acc, child_pid, new_caller) - end) + state.callers + |> Map.delete(caller_pid) + |> Map.put(child_pid, new_caller) - inc_runner_count( - %Pool{state | callers: new_callers}, - caller.runner_ref, - # subtract 1 because the caller we are replacing is already in the count - length(child_pids) - 1 - ) + %Pool{state | callers: new_callers} end defp checkin_runner(state, ref, caller_pid, reason) @@ -687,10 +708,10 @@ defmodule FLAME.Pool do {runner, new_state} end - defp inc_runner_count(%Pool{} = state, ref, amount \\ 1) do + defp inc_runner_count(%Pool{} = state, ref) do new_runners = Map.update!(state.runners, ref, fn %RunnerState{} = runner -> - %RunnerState{runner | count: runner.count + amount} + %RunnerState{runner | count: runner.count + 1} end) %Pool{state | runners: new_runners}