Skip to content

Commit c57bd8b

Browse files
Allow handle_begin callbacks to return query (#297)
1 parent 82dbd10 commit c57bd8b

File tree

3 files changed

+90
-7
lines changed

3 files changed

+90
-7
lines changed

integration_test/cases/transaction_test.exs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ defmodule TransactionTest do
33

44
alias TestPool, as: P
55
alias TestAgent, as: A
6+
alias TestQuery, as: Q
67

78
test "transaction returns result" do
89
stack = [
@@ -1253,4 +1254,64 @@ defmodule TransactionTest do
12531254
handle_status: [_, :newest_state]
12541255
] = A.record(agent)
12551256
end
1257+
1258+
test "log query from handle_begin" do
1259+
stack = [
1260+
{:ok, :state},
1261+
{:ok, %Q{statement: "custom begin"}, :begin, :new_state},
1262+
{:ok, :committed, :newest_state}
1263+
]
1264+
1265+
{:ok, agent} = A.start_link(stack)
1266+
1267+
parent = self()
1268+
opts = [agent: agent, parent: parent]
1269+
{:ok, pool} = P.start_link(opts)
1270+
1271+
log = &send(parent, &1)
1272+
1273+
assert P.transaction(pool, fn _ -> :result end, log: log) == {:ok, :result}
1274+
1275+
assert_received %DBConnection.LogEntry{call: :begin} = entry
1276+
assert %{query: "custom begin"} = entry
1277+
1278+
assert [
1279+
connect: [_],
1280+
handle_begin: [_, :state],
1281+
handle_commit: [_, :new_state]
1282+
] = A.record(agent)
1283+
end
1284+
1285+
test "log query from handle_begin: transaction inside run" do
1286+
stack = [
1287+
{:ok, :state},
1288+
{:idle, :new_state},
1289+
{:ok, %Q{statement: "custom begin"}, :begin, :newer_state},
1290+
{:ok, :committed, :newest_state},
1291+
{:idle, :newest_state}
1292+
]
1293+
1294+
{:ok, agent} = A.start_link(stack)
1295+
1296+
parent = self()
1297+
opts = [agent: agent, parent: parent]
1298+
{:ok, pool} = P.start_link(opts)
1299+
1300+
log = &send(parent, &1)
1301+
1302+
assert P.run(pool, fn conn ->
1303+
P.transaction(conn, fn _ -> :result end, log: log)
1304+
end) == {:ok, :result}
1305+
1306+
assert_received %DBConnection.LogEntry{call: :begin} = entry
1307+
assert %{query: "custom begin"} = entry
1308+
1309+
assert [
1310+
connect: [_],
1311+
handle_status: [_, :state],
1312+
handle_begin: [_, :new_state],
1313+
handle_commit: [_, :newer_state],
1314+
handle_status: [_, :newest_state]
1315+
] = A.record(agent)
1316+
end
12561317
end

lib/db_connection.ex

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,12 @@ defmodule DBConnection do
201201
@doc """
202202
Handle the beginning of a transaction.
203203
204-
Return `{:ok, result, state}` to continue, `{status, state}` to notify caller
205-
that the transaction can not begin due to the transaction status `status`,
206-
or `{:disconnect, exception, state}` to error and disconnect.
204+
Return `{:ok, result, state}`/`{:ok, query, result, state}` to continue,
205+
`{status, state}` to notify caller that the transaction can not begin due
206+
to the transaction status `status`, or `{:disconnect, exception, state}`
207+
to error and disconnect. If `{:ok, query, result, state}` is returned,
208+
the query will be used to log the begin command. Otherwise, it will be
209+
logged as `begin`.
207210
208211
A callback implementation should only return `status` if it
209212
can determine the database's transaction status without side effect.
@@ -212,6 +215,7 @@ defmodule DBConnection do
212215
"""
213216
@callback handle_begin(opts :: Keyword.t(), state :: any) ::
214217
{:ok, result, new_state :: any}
218+
| {:ok, query, result, new_state :: any}
215219
| {status, new_state :: any}
216220
| {:disconnect, Exception.t(), new_state :: any}
217221

@@ -1762,9 +1766,18 @@ defmodule DBConnection do
17621766
end
17631767

17641768
defp begin(conn, run, opts) do
1765-
conn
1766-
|> run.(&run_begin/3, meter(opts), opts)
1767-
|> log(:begin, :begin, nil)
1769+
case run.(conn, &run_begin/3, meter(opts), opts) do
1770+
{:ok, conn, {query, result}, meter} ->
1771+
query = String.Chars.to_string(query)
1772+
log({:ok, conn, result, meter}, :begin, query, nil)
1773+
1774+
{:ok, {query, result}, meter} ->
1775+
query = String.Chars.to_string(query)
1776+
log({:ok, result, meter}, :begin, query, nil)
1777+
1778+
other ->
1779+
log(other, :begin, :begin, nil)
1780+
end
17681781
end
17691782

17701783
defp run_begin(conn, meter, opts) do
@@ -1775,6 +1788,9 @@ defmodule DBConnection do
17751788
{status, _conn_state} when status in [:idle, :transaction, :error] ->
17761789
status_disconnect(conn, status, meter)
17771790

1791+
{:ok, query, result, _conn_status} ->
1792+
{:ok, {query, result}, meter}
1793+
17781794
other ->
17791795
handle_common_result(other, conn, meter)
17801796
end

test/test_support.exs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,13 @@ defmodule TestConnection do
145145
end
146146

147147
defmodule TestQuery do
148-
defstruct [:state]
148+
defstruct [:state, :statement]
149+
150+
defimpl String.Chars do
151+
def to_string(%{statement: statement}) do
152+
IO.iodata_to_binary(statement)
153+
end
154+
end
149155
end
150156

151157
defmodule TestCursor do

0 commit comments

Comments
 (0)