diff --git a/accounts/store_sql.go b/accounts/store_sql.go index 830f16587..4febdde62 100644 --- a/accounts/store_sql.go +++ b/accounts/store_sql.go @@ -34,6 +34,8 @@ const ( //nolint:lll type SQLQueries interface { AddAccountInvoice(ctx context.Context, arg sqlc.AddAccountInvoiceParams) error + CreditAccount(ctx context.Context, arg sqlc.CreditAccountParams) (int64, error) + DebitAccount(ctx context.Context, arg sqlc.DebitAccountParams) (int64, error) DeleteAccount(ctx context.Context, id int64) error DeleteAccountPayment(ctx context.Context, arg sqlc.DeleteAccountPaymentParams) error GetAccount(ctx context.Context, id int64) (sqlc.Account, error) @@ -394,17 +396,10 @@ func (s *SQLStore) CreditAccount(ctx context.Context, alias AccountID, return err } - acct, err := db.GetAccount(ctx, id) - if err != nil { - return err - } - - newBalance := acct.CurrentBalanceMsat + int64(amount) - - _, err = db.UpdateAccountBalance( - ctx, sqlc.UpdateAccountBalanceParams{ - ID: id, - CurrentBalanceMsat: newBalance, + _, err = db.CreditAccount( + ctx, sqlc.CreditAccountParams{ + ID: id, + Amount: int64(amount), }, ) if err != nil { @@ -429,26 +424,17 @@ func (s *SQLStore) DebitAccount(ctx context.Context, alias AccountID, return err } - acct, err := db.GetAccount(ctx, id) - if err != nil { - return err - } - - if acct.CurrentBalanceMsat-int64(amount) < 0 { + id, err = db.DebitAccount( + ctx, sqlc.DebitAccountParams{ + ID: id, + Amount: int64(amount), + }, + ) + if errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("cannot debit %v from the account "+ "balance, as the resulting balance would be "+ "below 0", int64(amount/1000)) - } - - newBalance := acct.CurrentBalanceMsat - int64(amount) - - _, err = db.UpdateAccountBalance( - ctx, sqlc.UpdateAccountBalanceParams{ - ID: id, - CurrentBalanceMsat: newBalance, - }, - ) - if err != nil { + } else if err != nil { return err } diff --git a/db/sqlc/accounts.sql.go b/db/sqlc/accounts.sql.go index 4deefdb88..1b991a183 100644 --- a/db/sqlc/accounts.sql.go +++ b/db/sqlc/accounts.sql.go @@ -26,6 +26,45 @@ func (q *Queries) AddAccountInvoice(ctx context.Context, arg AddAccountInvoicePa return err } +const creditAccount = `-- name: CreditAccount :one +UPDATE accounts +SET current_balance_msat = current_balance_msat + $2 +WHERE id = $1 +RETURNING id +` + +type CreditAccountParams struct { + ID int64 + Amount int64 +} + +func (q *Queries) CreditAccount(ctx context.Context, arg CreditAccountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, creditAccount, arg.ID, arg.Amount) + var id int64 + err := row.Scan(&id) + return id, err +} + +const debitAccount = `-- name: DebitAccount :one +UPDATE accounts +SET current_balance_msat = current_balance_msat - $2 +WHERE id = $1 +AND current_balance_msat >= $2 +RETURNING id +` + +type DebitAccountParams struct { + ID int64 + Amount int64 +} + +func (q *Queries) DebitAccount(ctx context.Context, arg DebitAccountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, debitAccount, arg.ID, arg.Amount) + var id int64 + err := row.Scan(&id) + return id, err +} + const deleteAccount = `-- name: DeleteAccount :exec DELETE FROM accounts WHERE id = $1 diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index 76355ed6f..7fad274aa 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -11,6 +11,8 @@ import ( type Querier interface { AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error + CreditAccount(ctx context.Context, arg CreditAccountParams) (int64, error) + DebitAccount(ctx context.Context, arg DebitAccountParams) (int64, error) DeleteAccount(ctx context.Context, id int64) error DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error DeleteSessionsWithState(ctx context.Context, state int16) error diff --git a/db/sqlc/queries/accounts.sql b/db/sqlc/queries/accounts.sql index 637a49727..1512a430b 100644 --- a/db/sqlc/queries/accounts.sql +++ b/db/sqlc/queries/accounts.sql @@ -9,6 +9,19 @@ SET current_balance_msat = $1 WHERE id = $2 RETURNING id; +-- name: CreditAccount :one +UPDATE accounts +SET current_balance_msat = current_balance_msat + sqlc.arg(amount) +WHERE id = $1 +RETURNING id; + +-- name: DebitAccount :one +UPDATE accounts +SET current_balance_msat = current_balance_msat - sqlc.arg(amount) +WHERE id = $1 +AND current_balance_msat >= sqlc.arg(amount) +RETURNING id; + -- name: UpdateAccountExpiry :one UPDATE accounts SET expiration = $1