Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,27 @@ func sum(affected []int64) int64 {
return sum
}

// WriteMutations is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// WriteMutations writes mutations using this connection. The mutations are either buffered in the current transaction,
// or written directly to Spanner using a new read/write transaction if the connection does not have a transaction.
//
// The function returns an error if the connection currently has a read-only transaction.
//
// The returned CommitResponse is nil if the connection currently has a transaction, as the mutations will only be
// applied to Spanner when the transaction commits.
func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) {
if c.inTransaction() {
return nil, c.BufferWrite(ms)
}
ts, err := c.Apply(ctx, ms)
if err != nil {
return nil, err
}
return &spanner.CommitResponse{CommitTs: ts}, nil
}

func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) {
if c.inTransaction() {
return time.Time{}, spanner.ToSpannerError(
Expand Down
45 changes: 45 additions & 0 deletions spannerlib/api/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,22 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error {
return conn.close(ctx)
}

// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in
// the current read/write transaction if the connection currently has a read/write transaction.
// The mutations are applied to the database in a new read/write transaction that is automatically
// committed if the connection currently does not have a transaction.
//
// The function returns an error if the connection is currently in a read-only transaction.
//
// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object.
func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) {
conn, err := findConnection(poolId, connId)
if err != nil {
return nil, err
}
return conn.writeMutations(ctx, mutations)
}

// BeginTransaction starts a new transaction on the given connection.
// A connection can have at most one transaction at any time. This function therefore returns an error if the
// connection has an active transaction.
Expand Down Expand Up @@ -104,6 +120,7 @@ type Connection struct {
// spannerConn is an internal interface that contains the internal functions that are used by this API.
// It is implemented by the spannerdriver.conn struct.
type spannerConn interface {
WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error)
BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error)
BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error)
Commit(ctx context.Context) (*spanner.CommitResponse, error)
Expand All @@ -127,6 +144,34 @@ func (conn *Connection) close(ctx context.Context) error {
return nil
}

func (conn *Connection) writeMutations(ctx context.Context, mutation *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) {
mutations := make([]*spanner.Mutation, 0, len(mutation.Mutations))
for _, m := range mutation.Mutations {
spannerMutation, err := spanner.WrapMutation(m)
if err != nil {
return nil, err
}
mutations = append(mutations, spannerMutation)
}
var commitResponse *spanner.CommitResponse
if err := conn.backend.Raw(func(driverConn any) (err error) {
sc, _ := driverConn.(spannerConn)
commitResponse, err = sc.WriteMutations(ctx, mutations)
return err
}); err != nil {
return nil, err
}

// The commit response is nil if the connection is currently in a transaction.
if commitResponse == nil {
return nil, nil
}
response := spannerpb.CommitResponse{
CommitTimestamp: timestamppb.New(commitResponse.CommitTs),
}
return &response, nil
}

func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error {
var err error
if txOpts.GetReadOnly() != nil {
Expand Down
157 changes: 157 additions & 0 deletions spannerlib/api/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/go-sql-spanner/testutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/types/known/structpb"
)

func TestCreateAndCloseConnection(t *testing.T) {
Expand Down Expand Up @@ -143,3 +144,159 @@ func TestCloseConnectionTwice(t *testing.T) {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestWriteMutations(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}

mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{
{Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{
Table: "my_table",
Columns: []string{"id", "value"},
Values: []*structpb.ListValue{
{Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}},
{Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}},
{Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}},
},
}}},
{Operation: &spannerpb.Mutation_Update{Update: &spannerpb.Mutation_Write{
Table: "my_table",
Columns: []string{"id", "value"},
Values: []*structpb.ListValue{
{Values: []*structpb.Value{structpb.NewStringValue("0"), structpb.NewStringValue("Zero")}},
},
}}},
}}
resp, err := WriteMutations(ctx, poolId, connId, mutations)
if err != nil {
t.Fatalf("WriteMutations returned unexpected error: %v", err)
}
if resp.CommitTimestamp == nil {
t.Fatalf("CommitTimestamp is nil")
}
requests := server.TestSpanner.DrainRequestsFromServer()
beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 1; g != w {
t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w)
}
commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
if g, w := len(commitRequests), 1; g != w {
t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w)
}
commitRequest := commitRequests[0].(*spannerpb.CommitRequest)
if g, w := len(commitRequest.Mutations), 2; g != w {
t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w)
}

// Write the same mutations in a transaction.
if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil {
t.Fatalf("BeginTransaction returned unexpected error: %v", err)
}
resp, err = WriteMutations(ctx, poolId, connId, mutations)
if err != nil {
t.Fatalf("WriteMutations returned unexpected error: %v", err)
}
if resp != nil {
t.Fatalf("WriteMutations returned unexpected response: %v", resp)
}
resp, err = Commit(ctx, poolId, connId)
if err != nil {
t.Fatalf("Commit returned unexpected error: %v", err)
}
if resp == nil {
t.Fatalf("Commit returned nil response")
}
if resp.CommitTimestamp == nil {
t.Fatalf("CommitTimestamp is nil")
}
requests = server.TestSpanner.DrainRequestsFromServer()
beginRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 1; g != w {
t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w)
}
commitRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
if g, w := len(commitRequests), 1; g != w {
t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w)
}
commitRequest = commitRequests[0].(*spannerpb.CommitRequest)
if g, w := len(commitRequest.Mutations), 2; g != w {
t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w)
}

if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestWriteMutationsInReadOnlyTx(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}

// Start a read-only transaction and try to write mutations to that transaction. That should return an error.
if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{
Mode: &spannerpb.TransactionOptions_ReadOnly_{ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}},
}); err != nil {
t.Fatalf("BeginTransaction returned unexpected error: %v", err)
}

mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{
{Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{
Table: "my_table",
Columns: []string{"id", "value"},
Values: []*structpb.ListValue{
{Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}},
},
}}},
}}
_, err = WriteMutations(ctx, poolId, connId, mutations)
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
t.Fatalf("WriteMutations error code mismatch\n Got: %d\nWant: %d", g, w)
}

// Committing the read-only transaction should not lead to any commits on Spanner.
_, err = Commit(ctx, poolId, connId)
if err != nil {
t.Fatalf("Commit returned unexpected error: %v", err)
}
requests := server.TestSpanner.DrainRequestsFromServer()
// There should also not be any BeginTransaction requests on Spanner, as the transaction was never really started
// by a query or other statement.
beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 0; g != w {
t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w)
}
commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
if g, w := len(commitRequests), 0; g != w {
t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w)
}

if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}
24 changes: 24 additions & 0 deletions spannerlib/lib/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,30 @@ func CloseConnection(ctx context.Context, poolId, connId int64) *Message {
return &Message{}
}

// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in
// the current read/write transaction if the connection currently has a read/write transaction.
// The mutations are applied to the database in a new read/write transaction that is automatically
// committed if the connection currently does not have a transaction.
//
// The function returns an error if the connection is currently in a read-only transaction.
//
// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object.
func WriteMutations(ctx context.Context, poolId, connId int64, mutationBytes []byte) *Message {
mutations := spannerpb.BatchWriteRequest_MutationGroup{}
if err := proto.Unmarshal(mutationBytes, &mutations); err != nil {
return errMessage(err)
}
response, err := api.WriteMutations(ctx, poolId, connId, &mutations)
if err != nil {
return errMessage(err)
}
res, err := proto.Marshal(response)
if err != nil {
return errMessage(err)
}
return &Message{Res: res}
}

// BeginTransaction starts a new transaction on the given connection. A connection can have at most one active
// transaction at any time. This function therefore returns an error if the connection has an active transaction.
func BeginTransaction(ctx context.Context, poolId, connId int64, txOptsBytes []byte) *Message {
Expand Down
60 changes: 60 additions & 0 deletions spannerlib/lib/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/googleapis/go-sql-spanner/testutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
)

func TestCreateAndCloseConnection(t *testing.T) {
Expand Down Expand Up @@ -262,3 +263,62 @@ func TestBeginAndRollback(t *testing.T) {
t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestWriteMutations(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolMsg := CreatePool(ctx, dsn)
if g, w := poolMsg.Code, int32(0); g != w {
t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w)
}
connMsg := CreateConnection(ctx, poolMsg.ObjectId)
if g, w := connMsg.Code, int32(0); g != w {
t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w)
}
mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{
{Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{
Table: "my_table",
Columns: []string{"id", "value"},
Values: []*structpb.ListValue{
{Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}},
{Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}},
{Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}},
},
}}},
}}
mutationBytes, err := proto.Marshal(mutations)
if err != nil {
t.Fatal(err)
}
mutationsMsg := WriteMutations(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes)
if g, w := mutationsMsg.Code, int32(0); g != w {
t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w)
}
if mutationsMsg.Length() == 0 {
t.Fatal("WriteMutations returned no data")
}

// Write mutations in a transaction.
mutationsMsg = BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes)
// The response should now be an empty message, as the mutations were only buffered in the transaction.
if g, w := mutationsMsg.Code, int32(0); g != w {
t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := mutationsMsg.Length(), int32(0); g != w {
t.Fatalf("WriteMutations data length mismatch\n Got: %v\nWant: %v", g, w)
}

closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId)
if g, w := closeMsg.Code, int32(0); g != w {
t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w)
}
closeMsg = ClosePool(ctx, poolMsg.ObjectId)
if g, w := closeMsg.Code, int32(0); g != w {
t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w)
}
}
16 changes: 16 additions & 0 deletions spannerlib/shared/shared_lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,22 @@ func CloseConnection(poolId, connId int64) (int64, int32, int64, int32, unsafe.P
return pin(msg)
}

// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in
// the current read/write transaction if the connection currently has a read/write transaction.
// The mutations are applied to the database in a new read/write transaction that is automatically
// committed if the connection currently does not have a transaction.
//
// The function returns an error if the connection is currently in a read-only transaction.
//
// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object.
//
//export WriteMutations
func WriteMutations(poolId, connectionId int64, mutationsBytes []byte) (int64, int32, int64, int32, unsafe.Pointer) {
ctx := context.Background()
msg := lib.WriteMutations(ctx, poolId, connectionId, mutationsBytes)
return pin(msg)
}

// Execute executes a SQL statement on the given connection.
// The return type is an identifier for a Rows object. This identifier can be used to
// call the functions Metadata and Next to get respectively the metadata of the result
Expand Down
Loading
Loading