Skip to content

Commit b4f9814

Browse files
authored
chore: add transaction support for SpannerLib (#530)
* chore: add transaction support for SpannerLib * chore: add ExecuteBatch to SpannerLib (#531) * chore: add ExecuteBatch to SpannerLib Adds an ExecuteBatch function to SpannerLib that supports executing DML or DDL statements as a single batch. The function accepts an ExecuteBatchDml request for both types of batches. The type of batch that is actually being executed is determined based on the statements in the batch. Mixing DML and DDL in the same batch is not supported. Queries are also not supported in batches. * chore: add WriteMutations function for SpannerLib (#532) Adds a WriteMutations function for SpannerLib. This function can be used to write mutations to Spanner in two ways: 1. In a transaction: The mutations are buffered in the current read/write transaction. The returned message is empty. 2. Outside a transaction: The mutations are written to Spanner directly in a new read/write transaction. The returned message contains the CommitResponse.
1 parent f331cf4 commit b4f9814

File tree

17 files changed

+2440
-7
lines changed

17 files changed

+2440
-7
lines changed

conn.go

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,9 @@ type SpannerConn interface {
222222
// return the same Spanner client.
223223
UnderlyingClient() (client *spanner.Client, err error)
224224

225+
// DetectStatementType returns the type of SQL statement.
226+
DetectStatementType(query string) parser.StatementType
227+
225228
// resetTransactionForRetry resets the current transaction after it has
226229
// been aborted by Spanner. Calling this function on a transaction that
227230
// has not been aborted is not supported and will cause an error to be
@@ -286,6 +289,11 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) {
286289
return c.client, nil
287290
}
288291

292+
func (c *conn) DetectStatementType(query string) parser.StatementType {
293+
info := c.parser.DetectStatementType(query)
294+
return info.StatementType
295+
}
296+
289297
func (c *conn) CommitTimestamp() (time.Time, error) {
290298
ts := propertyCommitTimestamp.GetValueOrDefault(c.state)
291299
if ts == nil {
@@ -675,6 +683,27 @@ func sum(affected []int64) int64 {
675683
return sum
676684
}
677685

686+
// WriteMutations is not part of the public API of the database/sql driver.
687+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
688+
//
689+
// WriteMutations writes mutations using this connection. The mutations are either buffered in the current transaction,
690+
// or written directly to Spanner using a new read/write transaction if the connection does not have a transaction.
691+
//
692+
// The function returns an error if the connection currently has a read-only transaction.
693+
//
694+
// The returned CommitResponse is nil if the connection currently has a transaction, as the mutations will only be
695+
// applied to Spanner when the transaction commits.
696+
func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) {
697+
if c.inTransaction() {
698+
return nil, c.BufferWrite(ms)
699+
}
700+
ts, err := c.Apply(ctx, ms)
701+
if err != nil {
702+
return nil, err
703+
}
704+
return &spanner.CommitResponse{CommitTs: ts}, nil
705+
}
706+
678707
func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) {
679708
if c.inTransaction() {
680709
return time.Time{}, spanner.ToSpannerError(
@@ -1071,6 +1100,34 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti
10711100
return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()}
10721101
}
10731102

1103+
// BeginReadOnlyTransaction is not part of the public API of the database/sql driver.
1104+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
1105+
//
1106+
// BeginReadOnlyTransaction starts a new read-only transaction on this connection.
1107+
func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) {
1108+
c.withTempReadOnlyTransactionOptions(options)
1109+
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true})
1110+
if err != nil {
1111+
c.withTempReadOnlyTransactionOptions(nil)
1112+
return nil, err
1113+
}
1114+
return tx, nil
1115+
}
1116+
1117+
// BeginReadWriteTransaction is not part of the public API of the database/sql driver.
1118+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
1119+
//
1120+
// BeginReadWriteTransaction starts a new read/write transaction on this connection.
1121+
func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) {
1122+
c.withTempTransactionOptions(options)
1123+
tx, err := c.BeginTx(ctx, driver.TxOptions{})
1124+
if err != nil {
1125+
c.withTempTransactionOptions(nil)
1126+
return nil, err
1127+
}
1128+
return tx, nil
1129+
}
1130+
10741131
func (c *conn) Begin() (driver.Tx, error) {
10751132
return c.BeginTx(context.Background(), driver.TxOptions{})
10761133
}
@@ -1254,18 +1311,29 @@ func (c *conn) inReadWriteTransaction() bool {
12541311
return false
12551312
}
12561313

1257-
func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
1314+
// Commit is not part of the public API of the database/sql driver.
1315+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
1316+
//
1317+
// Commit commits the current transaction on this connection.
1318+
func (c *conn) Commit(ctx context.Context) (*spanner.CommitResponse, error) {
12581319
if !c.inTransaction() {
12591320
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
12601321
}
12611322
// TODO: Pass in context to the tx.Commit() function.
12621323
if err := c.tx.Commit(); err != nil {
12631324
return nil, err
12641325
}
1265-
return c.CommitResponse()
1326+
1327+
// This will return either the commit response or nil, depending on whether the transaction was a
1328+
// read/write transaction or a read-only transaction.
1329+
return propertyCommitResponse.GetValueOrDefault(c.state), nil
12661330
}
12671331

1268-
func (c *conn) rollback(ctx context.Context) error {
1332+
// Rollback is not part of the public API of the database/sql driver.
1333+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
1334+
//
1335+
// Rollback rollbacks the current transaction on this connection.
1336+
func (c *conn) Rollback(ctx context.Context) error {
12691337
if !c.inTransaction() {
12701338
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
12711339
}

conn_with_mockserver_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,26 @@ func TestTwoTransactionsOnOneConn(t *testing.T) {
8282
}
8383
}
8484

85+
func TestTwoQueriesOnOneConn(t *testing.T) {
86+
t.Parallel()
87+
88+
db, _, teardown := setupTestDBConnection(t)
89+
defer teardown()
90+
ctx := context.Background()
91+
92+
c, _ := db.Conn(ctx)
93+
defer silentClose(c)
94+
95+
for range 2 {
96+
r, err := c.QueryContext(context.Background(), testutil.SelectFooFromBar)
97+
if err != nil {
98+
t.Fatal(err)
99+
}
100+
_ = r.Next()
101+
defer silentClose(r)
102+
}
103+
}
104+
85105
func TestExplicitBeginTx(t *testing.T) {
86106
t.Parallel()
87107

spannerlib/api/batch_test.go

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package api
16+
17+
import (
18+
"context"
19+
"fmt"
20+
"reflect"
21+
"testing"
22+
23+
"cloud.google.com/go/longrunning/autogen/longrunningpb"
24+
"cloud.google.com/go/spanner"
25+
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
26+
"cloud.google.com/go/spanner/apiv1/spannerpb"
27+
"github.com/google/go-cmp/cmp"
28+
"github.com/google/go-cmp/cmp/cmpopts"
29+
"github.com/googleapis/go-sql-spanner/testutil"
30+
"google.golang.org/grpc/codes"
31+
"google.golang.org/protobuf/proto"
32+
"google.golang.org/protobuf/types/known/anypb"
33+
"google.golang.org/protobuf/types/known/emptypb"
34+
)
35+
36+
func TestExecuteDmlBatch(t *testing.T) {
37+
t.Parallel()
38+
39+
ctx := context.Background()
40+
server, teardown := setupMockServer(t)
41+
defer teardown()
42+
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)
43+
44+
poolId, err := CreatePool(ctx, dsn)
45+
if err != nil {
46+
t.Fatalf("CreatePool returned unexpected error: %v", err)
47+
}
48+
connId, err := CreateConnection(ctx, poolId)
49+
if err != nil {
50+
t.Fatalf("CreateConnection returned unexpected error: %v", err)
51+
}
52+
53+
// Execute a DML batch.
54+
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
55+
{Sql: testutil.UpdateBarSetFoo},
56+
{Sql: testutil.UpdateBarSetFoo},
57+
}}
58+
resp, err := ExecuteBatch(ctx, poolId, connId, request)
59+
if err != nil {
60+
t.Fatalf("ExecuteBatch returned unexpected error: %v", err)
61+
}
62+
if g, w := len(resp.ResultSets), 2; g != w {
63+
t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w)
64+
}
65+
for i, result := range resp.ResultSets {
66+
if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w {
67+
t.Fatalf("%d: update count mismatch\n Got: %d\nWant: %d", i, g, w)
68+
}
69+
}
70+
71+
requests := server.TestSpanner.DrainRequestsFromServer()
72+
// There should be no ExecuteSql requests.
73+
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
74+
if g, w := len(executeRequests), 0; g != w {
75+
t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w)
76+
}
77+
batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{}))
78+
if g, w := len(batchRequests), 1; g != w {
79+
t.Fatalf("Execute batch request count mismatch\n Got: %v\nWant: %v", g, w)
80+
}
81+
82+
if err := CloseConnection(ctx, poolId, connId); err != nil {
83+
t.Fatalf("CloseConnection returned unexpected error: %v", err)
84+
}
85+
if err := ClosePool(ctx, poolId); err != nil {
86+
t.Fatalf("ClosePool returned unexpected error: %v", err)
87+
}
88+
}
89+
90+
func TestExecuteDdlBatch(t *testing.T) {
91+
t.Parallel()
92+
93+
ctx := context.Background()
94+
server, teardown := setupMockServer(t)
95+
defer teardown()
96+
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)
97+
// Set up a result for a DDL statement on the mock server.
98+
var expectedResponse = &emptypb.Empty{}
99+
anyMsg, _ := anypb.New(expectedResponse)
100+
server.TestDatabaseAdmin.SetResps([]proto.Message{
101+
&longrunningpb.Operation{
102+
Done: true,
103+
Result: &longrunningpb.Operation_Response{Response: anyMsg},
104+
Name: "test-operation",
105+
},
106+
})
107+
108+
poolId, err := CreatePool(ctx, dsn)
109+
if err != nil {
110+
t.Fatalf("CreatePool returned unexpected error: %v", err)
111+
}
112+
connId, err := CreateConnection(ctx, poolId)
113+
if err != nil {
114+
t.Fatalf("CreateConnection returned unexpected error: %v", err)
115+
}
116+
117+
// Execute a DDL batch. This also uses a DML batch request.
118+
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
119+
{Sql: "create table my_table (id int64 primary key, value string(100))"},
120+
{Sql: "create index my_index on my_table (value)"},
121+
}}
122+
resp, err := ExecuteBatch(ctx, poolId, connId, request)
123+
if err != nil {
124+
t.Fatalf("ExecuteBatch returned unexpected error: %v", err)
125+
}
126+
// The response should contain an 'update count' per DDL statement.
127+
if g, w := len(resp.ResultSets), 2; g != w {
128+
t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w)
129+
}
130+
// There is no update count for DDL statements.
131+
for i, result := range resp.ResultSets {
132+
emptyStats := &spannerpb.ResultSetStats{}
133+
if g, w := result.Stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) {
134+
t.Fatalf("%d: ResultSetStats mismatch\n Got: %v\nWant: %v", i, g, w)
135+
}
136+
}
137+
138+
requests := server.TestSpanner.DrainRequestsFromServer()
139+
// There should be no ExecuteSql requests.
140+
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
141+
if g, w := len(executeRequests), 0; g != w {
142+
t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w)
143+
}
144+
// There should also be no ExecuteBatchDml requests.
145+
batchDmlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{}))
146+
if g, w := len(batchDmlRequests), 0; g != w {
147+
t.Fatalf("ExecuteBatchDmlRequest count mismatch\n Got: %v\nWant: %v", g, w)
148+
}
149+
150+
adminRequests := server.TestDatabaseAdmin.Reqs()
151+
if g, w := len(adminRequests), 1; g != w {
152+
t.Fatalf("admin request count mismatch\n Got: %v\nWant: %v", g, w)
153+
}
154+
ddlRequest := adminRequests[0].(*databasepb.UpdateDatabaseDdlRequest)
155+
if g, w := len(ddlRequest.Statements), 2; g != w {
156+
t.Fatalf("DDL statement count mismatch\n Got: %v\nWant: %v", g, w)
157+
}
158+
159+
if err := CloseConnection(ctx, poolId, connId); err != nil {
160+
t.Fatalf("CloseConnection returned unexpected error: %v", err)
161+
}
162+
if err := ClosePool(ctx, poolId); err != nil {
163+
t.Fatalf("ClosePool returned unexpected error: %v", err)
164+
}
165+
}
166+
167+
func TestExecuteMixedBatch(t *testing.T) {
168+
t.Parallel()
169+
170+
ctx := context.Background()
171+
server, teardown := setupMockServer(t)
172+
defer teardown()
173+
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)
174+
175+
poolId, err := CreatePool(ctx, dsn)
176+
if err != nil {
177+
t.Fatalf("CreatePool returned unexpected error: %v", err)
178+
}
179+
connId, err := CreateConnection(ctx, poolId)
180+
if err != nil {
181+
t.Fatalf("CreateConnection returned unexpected error: %v", err)
182+
}
183+
184+
// Try to execute a batch with mixed DML and DDL statements. This should fail.
185+
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
186+
{Sql: "create table my_table (id int64 primary key, value string(100))"},
187+
{Sql: "update my_table set value = 100 where true"},
188+
}}
189+
_, err = ExecuteBatch(ctx, poolId, connId, request)
190+
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
191+
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
192+
}
193+
194+
if err := CloseConnection(ctx, poolId, connId); err != nil {
195+
t.Fatalf("CloseConnection returned unexpected error: %v", err)
196+
}
197+
if err := ClosePool(ctx, poolId); err != nil {
198+
t.Fatalf("ClosePool returned unexpected error: %v", err)
199+
}
200+
}
201+
202+
func TestExecuteDdlBatchInTransaction(t *testing.T) {
203+
t.Parallel()
204+
205+
ctx := context.Background()
206+
server, teardown := setupMockServer(t)
207+
defer teardown()
208+
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)
209+
210+
poolId, err := CreatePool(ctx, dsn)
211+
if err != nil {
212+
t.Fatalf("CreatePool returned unexpected error: %v", err)
213+
}
214+
connId, err := CreateConnection(ctx, poolId)
215+
if err != nil {
216+
t.Fatalf("CreateConnection returned unexpected error: %v", err)
217+
}
218+
if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil {
219+
t.Fatalf("BeginTransaction returned unexpected error: %v", err)
220+
}
221+
222+
// Try to execute a DDL batch in a transaction. This should fail.
223+
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
224+
{Sql: "create table my_table (id int64 primary key, value string(100))"},
225+
{Sql: "create index my_index on my_table (value)"},
226+
}}
227+
_, err = ExecuteBatch(ctx, poolId, connId, request)
228+
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
229+
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
230+
}
231+
232+
if err := CloseConnection(ctx, poolId, connId); err != nil {
233+
t.Fatalf("CloseConnection returned unexpected error: %v", err)
234+
}
235+
if err := ClosePool(ctx, poolId); err != nil {
236+
t.Fatalf("ClosePool returned unexpected error: %v", err)
237+
}
238+
}

0 commit comments

Comments
 (0)