diff --git a/integration_test/cases/transaction_test.exs b/integration_test/cases/transaction_test.exs index dd245b7..f77b1bd 100644 --- a/integration_test/cases/transaction_test.exs +++ b/integration_test/cases/transaction_test.exs @@ -3,6 +3,7 @@ defmodule TransactionTest do alias TestPool, as: P alias TestAgent, as: A + alias TestQuery, as: Q test "transaction returns result" do stack = [ @@ -1253,4 +1254,64 @@ defmodule TransactionTest do handle_status: [_, :newest_state] ] = A.record(agent) end + + test "log query from handle_begin" do + stack = [ + {:ok, :state}, + {:ok, %Q{statement: "custom begin"}, :begin, :new_state}, + {:ok, :committed, :newest_state} + ] + + {:ok, agent} = A.start_link(stack) + + parent = self() + opts = [agent: agent, parent: parent] + {:ok, pool} = P.start_link(opts) + + log = &send(parent, &1) + + assert P.transaction(pool, fn _ -> :result end, log: log) == {:ok, :result} + + assert_received %DBConnection.LogEntry{call: :begin} = entry + assert %{query: "custom begin"} = entry + + assert [ + connect: [_], + handle_begin: [_, :state], + handle_commit: [_, :new_state] + ] = A.record(agent) + end + + test "log query from handle_begin: transaction inside run" do + stack = [ + {:ok, :state}, + {:idle, :new_state}, + {:ok, %Q{statement: "custom begin"}, :begin, :newer_state}, + {:ok, :committed, :newest_state}, + {:idle, :newest_state} + ] + + {:ok, agent} = A.start_link(stack) + + parent = self() + opts = [agent: agent, parent: parent] + {:ok, pool} = P.start_link(opts) + + log = &send(parent, &1) + + assert P.run(pool, fn conn -> + P.transaction(conn, fn _ -> :result end, log: log) + end) == {:ok, :result} + + assert_received %DBConnection.LogEntry{call: :begin} = entry + assert %{query: "custom begin"} = entry + + assert [ + connect: [_], + handle_status: [_, :state], + handle_begin: [_, :new_state], + handle_commit: [_, :newer_state], + handle_status: [_, :newest_state] + ] = A.record(agent) + end end diff --git a/lib/db_connection.ex b/lib/db_connection.ex index e837c82..892b060 100644 --- a/lib/db_connection.ex +++ b/lib/db_connection.ex @@ -201,9 +201,12 @@ defmodule DBConnection do @doc """ Handle the beginning of a transaction. - Return `{:ok, result, state}` to continue, `{status, state}` to notify caller - that the transaction can not begin due to the transaction status `status`, - or `{:disconnect, exception, state}` to error and disconnect. + Return `{:ok, result, state}`/`{:ok, query, result, state}` to continue, + `{status, state}` to notify caller that the transaction can not begin due + to the transaction status `status`, or `{:disconnect, exception, state}` + to error and disconnect. If `{:ok, query, result, state}` is returned, + the query will be used to log the begin command. Otherwise, it will be + logged as `begin`. A callback implementation should only return `status` if it can determine the database's transaction status without side effect. @@ -212,6 +215,7 @@ defmodule DBConnection do """ @callback handle_begin(opts :: Keyword.t(), state :: any) :: {:ok, result, new_state :: any} + | {:ok, query, result, new_state :: any} | {status, new_state :: any} | {:disconnect, Exception.t(), new_state :: any} @@ -1762,9 +1766,18 @@ defmodule DBConnection do end defp begin(conn, run, opts) do - conn - |> run.(&run_begin/3, meter(opts), opts) - |> log(:begin, :begin, nil) + case run.(conn, &run_begin/3, meter(opts), opts) do + {:ok, conn, {query, result}, meter} -> + query = String.Chars.to_string(query) + log({:ok, conn, result, meter}, :begin, query, nil) + + {:ok, {query, result}, meter} -> + query = String.Chars.to_string(query) + log({:ok, result, meter}, :begin, query, nil) + + other -> + log(other, :begin, :begin, nil) + end end defp run_begin(conn, meter, opts) do @@ -1775,6 +1788,9 @@ defmodule DBConnection do {status, _conn_state} when status in [:idle, :transaction, :error] -> status_disconnect(conn, status, meter) + {:ok, query, result, _conn_status} -> + {:ok, {query, result}, meter} + other -> handle_common_result(other, conn, meter) end diff --git a/test/test_support.exs b/test/test_support.exs index 644132f..3c92ed1 100644 --- a/test/test_support.exs +++ b/test/test_support.exs @@ -145,7 +145,13 @@ defmodule TestConnection do end defmodule TestQuery do - defstruct [:state] + defstruct [:state, :statement] + + defimpl String.Chars do + def to_string(%{statement: statement}) do + IO.iodata_to_binary(statement) + end + end end defmodule TestCursor do