Skip to content

Commit 66b0f15

Browse files
authored
Merge pull request #986 from ellemouton/sql16Sessions8
[sql-16] sessions: update Store interface methods to take a context
2 parents bc4439f + 01c19c5 commit 66b0f15

File tree

11 files changed

+170
-120
lines changed

11 files changed

+170
-120
lines changed

firewall/privacy_mapper.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
190190
uri string, req proto.Message, sessionID session.ID) (proto.Message,
191191
error) {
192192

193-
session, err := p.sessionDB.GetSessionByID(sessionID)
193+
session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
194194
if err != nil {
195195
return nil, err
196196
}
@@ -220,7 +220,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
220220
func (p *PrivacyMapper) replaceOutgoingResponse(ctx context.Context, uri string,
221221
resp proto.Message, sessionID session.ID) (proto.Message, error) {
222222

223-
session, err := p.sessionDB.GetSessionByID(sessionID)
223+
session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
224224
if err != nil {
225225
return nil, err
226226
}

firewall/rule_enforcer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string,
386386
return nil, err
387387
}
388388

389-
session, err := r.sessionDB.GetSessionByID(sessionID)
389+
session, err := r.sessionDB.GetSessionByID(ctx, sessionID)
390390
if err != nil {
391391
return nil, err
392392
}

firewalldb/actions.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ func (db *DB) ListSessionActions(sessionID session.ID,
391391
// pass the filterFn requirements.
392392
//
393393
// TODO: update to allow for pagination.
394-
func (db *DB) ListGroupActions(groupID session.ID,
394+
func (db *DB) ListGroupActions(ctx context.Context, groupID session.ID,
395395
filterFn ListActionsFilterFn) ([]*Action, error) {
396396

397397
if filterFn == nil {
@@ -400,7 +400,7 @@ func (db *DB) ListGroupActions(groupID session.ID,
400400
}
401401
}
402402

403-
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(groupID)
403+
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(ctx, groupID)
404404
if err != nil {
405405
return nil, err
406406
}
@@ -629,11 +629,11 @@ type groupActionsReadDB struct {
629629
var _ ActionsDB = (*groupActionsReadDB)(nil)
630630

631631
// ListActions will return all the Actions for a particular group.
632-
func (s *groupActionsReadDB) ListActions(_ context.Context) ([]*RuleAction,
632+
func (s *groupActionsReadDB) ListActions(ctx context.Context) ([]*RuleAction,
633633
error) {
634634

635635
sessionActions, err := s.db.ListGroupActions(
636-
s.groupID, func(a *Action, _ bool) (bool, bool) {
636+
ctx, s.groupID, func(a *Action, _ bool) (bool, bool) {
637637
return a.State == ActionStateDone, true
638638
},
639639
)
@@ -660,11 +660,11 @@ var _ ActionsDB = (*groupFeatureActionsReadDB)(nil)
660660

661661
// ListActions will return all the Actions for a particular group that were
662662
// executed by a particular feature.
663-
func (a *groupFeatureActionsReadDB) ListActions(_ context.Context) (
663+
func (a *groupFeatureActionsReadDB) ListActions(ctx context.Context) (
664664
[]*RuleAction, error) {
665665

666666
featureActions, err := a.db.ListGroupActions(
667-
a.groupID, func(action *Action, _ bool) (bool, bool) {
667+
ctx, a.groupID, func(action *Action, _ bool) (bool, bool) {
668668
return action.State == ActionStateDone &&
669669
action.FeatureName == a.featureName, true
670670
},

firewalldb/actions_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package firewalldb
22

33
import (
4+
"context"
45
"fmt"
56
"testing"
67
"time"
@@ -342,6 +343,9 @@ func TestListActions(t *testing.T) {
342343
// TestListGroupActions tests that the ListGroupActions correctly returns all
343344
// actions in a particular session group.
344345
func TestListGroupActions(t *testing.T) {
346+
t.Parallel()
347+
ctx := context.Background()
348+
345349
group1 := intToSessionID(0)
346350

347351
// Link session 1 and session 2 to group 1.
@@ -356,7 +360,7 @@ func TestListGroupActions(t *testing.T) {
356360
})
357361

358362
// There should not be any actions in group 1 yet.
359-
al, err := db.ListGroupActions(group1, nil)
363+
al, err := db.ListGroupActions(ctx, group1, nil)
360364
require.NoError(t, err)
361365
require.Empty(t, al)
362366

@@ -365,7 +369,7 @@ func TestListGroupActions(t *testing.T) {
365369
require.NoError(t, err)
366370

367371
// There should now be one action in the group.
368-
al, err = db.ListGroupActions(group1, nil)
372+
al, err = db.ListGroupActions(ctx, group1, nil)
369373
require.NoError(t, err)
370374
require.Len(t, al, 1)
371375
require.Equal(t, sessionID1, al[0].SessionID)
@@ -375,7 +379,7 @@ func TestListGroupActions(t *testing.T) {
375379
require.NoError(t, err)
376380

377381
// There should now be actions in the group.
378-
al, err = db.ListGroupActions(group1, nil)
382+
al, err = db.ListGroupActions(ctx, group1, nil)
379383
require.NoError(t, err)
380384
require.Len(t, al, 2)
381385
require.Equal(t, sessionID1, al[0].SessionID)

firewalldb/interface.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
package firewalldb
22

3-
import "github.com/lightninglabs/lightning-terminal/session"
3+
import (
4+
"context"
5+
6+
"github.com/lightninglabs/lightning-terminal/session"
7+
)
48

59
// SessionDB is an interface that abstracts the database operations needed for
610
// the privacy mapper to function.
711
type SessionDB interface {
812
session.IDToGroupIndex
913

1014
// GetSessionByID returns the session for a specific id.
11-
GetSessionByID(session.ID) (*session.Session, error)
15+
GetSessionByID(context.Context, session.ID) (*session.Session, error)
1216
}

firewalldb/mock.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package firewalldb
22

33
import (
4+
"context"
45
"fmt"
56

67
"github.com/lightninglabs/lightning-terminal/session"
@@ -33,7 +34,9 @@ func (m *mockSessionDB) AddPair(sessionID, groupID session.ID) {
3334
}
3435

3536
// GetGroupID returns the group ID for the given session ID.
36-
func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
37+
func (m *mockSessionDB) GetGroupID(_ context.Context, sessionID session.ID) (
38+
session.ID, error) {
39+
3740
id, ok := m.sessionToGroupID[sessionID]
3841
if !ok {
3942
return session.ID{}, fmt.Errorf("no group ID found for " +
@@ -44,7 +47,9 @@ func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
4447
}
4548

4649
// GetSessionIDs returns the set of session IDs that are in the group
47-
func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error) {
50+
func (m *mockSessionDB) GetSessionIDs(_ context.Context, groupID session.ID) (
51+
[]session.ID, error) {
52+
4853
ids, ok := m.groupToSessionIDs[groupID]
4954
if !ok {
5055
return nil, fmt.Errorf("no session IDs found for group ID")
@@ -54,8 +59,8 @@ func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error)
5459
}
5560

5661
// GetSessionByID returns the session for a specific id.
57-
func (m *mockSessionDB) GetSessionByID(sessionID session.ID) (*session.Session,
58-
error) {
62+
func (m *mockSessionDB) GetSessionByID(_ context.Context,
63+
sessionID session.ID) (*session.Session, error) {
5964

6065
s, ok := m.sessionToGroupID[sessionID]
6166
if !ok {

session/interface.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package session
22

33
import (
4+
"context"
45
"fmt"
56
"time"
67

@@ -260,11 +261,11 @@ func WithMacaroonRecipe(caveats []macaroon.Caveat, perms []bakery.Op) Option {
260261
// IDToGroupIndex defines an interface for the session ID to group ID index.
261262
type IDToGroupIndex interface {
262263
// GetGroupID will return the group ID for the given session ID.
263-
GetGroupID(sessionID ID) (ID, error)
264+
GetGroupID(ctx context.Context, sessionID ID) (ID, error)
264265

265266
// GetSessionIDs will return the set of session IDs that are in the
266267
// group with the given ID.
267-
GetSessionIDs(groupID ID) ([]ID, error)
268+
GetSessionIDs(ctx context.Context, groupID ID) ([]ID, error)
268269
}
269270

270271
// Store is the interface a persistent storage must implement for storing and
@@ -273,37 +274,39 @@ type Store interface {
273274
// NewSession creates a new session with the given user-defined
274275
// parameters. The session will remain in the StateReserved state until
275276
// ShiftState is called to update the state.
276-
NewSession(label string, typ Type, expiry time.Time, serverAddr string,
277-
opts ...Option) (*Session, error)
277+
NewSession(ctx context.Context, label string, typ Type,
278+
expiry time.Time, serverAddr string, opts ...Option) (*Session,
279+
error)
278280

279281
// GetSession fetches the session with the given key.
280-
GetSession(key *btcec.PublicKey) (*Session, error)
282+
GetSession(ctx context.Context, key *btcec.PublicKey) (*Session, error)
281283

282284
// ListAllSessions returns all sessions currently known to the store.
283-
ListAllSessions() ([]*Session, error)
285+
ListAllSessions(ctx context.Context) ([]*Session, error)
284286

285287
// ListSessionsByType returns all sessions of the given type.
286-
ListSessionsByType(t Type) ([]*Session, error)
288+
ListSessionsByType(ctx context.Context, t Type) ([]*Session, error)
287289

288290
// ListSessionsByState returns all sessions currently known to the store
289291
// that are in the given states.
290-
ListSessionsByState(...State) ([]*Session, error)
292+
ListSessionsByState(ctx context.Context, state ...State) ([]*Session,
293+
error)
291294

292295
// UpdateSessionRemotePubKey can be used to add the given remote pub key
293296
// to the session with the given local pub key.
294-
UpdateSessionRemotePubKey(localPubKey,
297+
UpdateSessionRemotePubKey(ctx context.Context, localPubKey,
295298
remotePubKey *btcec.PublicKey) error
296299

297300
// GetSessionByID fetches the session with the given ID.
298-
GetSessionByID(id ID) (*Session, error)
301+
GetSessionByID(ctx context.Context, id ID) (*Session, error)
299302

300303
// DeleteReservedSessions deletes all sessions that are in the
301304
// StateReserved state.
302-
DeleteReservedSessions() error
305+
DeleteReservedSessions(ctx context.Context) error
303306

304307
// ShiftState updates the state of the session with the given ID to the
305308
// "dest" state.
306-
ShiftState(id ID, dest State) error
309+
ShiftState(ctx context.Context, id ID, dest State) error
307310

308311
IDToGroupIndex
309312
}

session/kvdb_store.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package session
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/binary"
67
"errors"
78
"fmt"
@@ -185,8 +186,8 @@ func getSessionKey(session *Session) []byte {
185186
// ShiftState is called with StateCreated.
186187
//
187188
// NOTE: this is part of the Store interface.
188-
func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
189-
serverAddr string, opts ...Option) (*Session, error) {
189+
func (db *BoltStore) NewSession(ctx context.Context, label string, typ Type,
190+
expiry time.Time, serverAddr string, opts ...Option) (*Session, error) {
190191

191192
var session *Session
192193
err := db.Update(func(tx *bbolt.Tx) error {
@@ -285,7 +286,7 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
285286
// to the session with the given local pub key.
286287
//
287288
// NOTE: this is part of the Store interface.
288-
func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey,
289+
func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, localPubKey,
289290
remotePubKey *btcec.PublicKey) error {
290291

291292
key := localPubKey.SerializeCompressed()
@@ -318,7 +319,9 @@ func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey,
318319
// GetSession fetches the session with the given key.
319320
//
320321
// NOTE: this is part of the Store interface.
321-
func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
322+
func (db *BoltStore) GetSession(_ context.Context, key *btcec.PublicKey) (
323+
*Session, error) {
324+
322325
var session *Session
323326
err := db.View(func(tx *bbolt.Tx) error {
324327
sessionBucket, err := getBucket(tx, sessionBucketKey)
@@ -348,7 +351,7 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
348351
// ListAllSessions returns all sessions currently known to the store.
349352
//
350353
// NOTE: this is part of the Store interface.
351-
func (db *BoltStore) ListAllSessions() ([]*Session, error) {
354+
func (db *BoltStore) ListAllSessions(_ context.Context) ([]*Session, error) {
352355
return db.listSessions(func(s *Session) bool {
353356
return true
354357
})
@@ -358,7 +361,9 @@ func (db *BoltStore) ListAllSessions() ([]*Session, error) {
358361
// have the given type.
359362
//
360363
// NOTE: this is part of the Store interface.
361-
func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
364+
func (db *BoltStore) ListSessionsByType(_ context.Context, t Type) ([]*Session,
365+
error) {
366+
362367
return db.listSessions(func(s *Session) bool {
363368
return s.Type == t
364369
})
@@ -368,7 +373,9 @@ func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
368373
// are in the given states.
369374
//
370375
// NOTE: this is part of the Store interface.
371-
func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) {
376+
func (db *BoltStore) ListSessionsByState(_ context.Context, states ...State) (
377+
[]*Session, error) {
378+
372379
return db.listSessions(func(s *Session) bool {
373380
for _, state := range states {
374381
if s.State == state {
@@ -429,7 +436,7 @@ func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,
429436
// state.
430437
//
431438
// NOTE: this is part of the Store interface.
432-
func (db *BoltStore) DeleteReservedSessions() error {
439+
func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
433440
return db.Update(func(tx *bbolt.Tx) error {
434441
sessionBucket, err := getBucket(tx, sessionBucketKey)
435442
if err != nil {
@@ -522,7 +529,7 @@ func (db *BoltStore) DeleteReservedSessions() error {
522529
// state.
523530
//
524531
// NOTE: this is part of the Store interface.
525-
func (db *BoltStore) ShiftState(id ID, dest State) error {
532+
func (db *BoltStore) ShiftState(_ context.Context, id ID, dest State) error {
526533
return db.Update(func(tx *bbolt.Tx) error {
527534
sessionBucket, err := getBucket(tx, sessionBucketKey)
528535
if err != nil {
@@ -562,7 +569,9 @@ func (db *BoltStore) ShiftState(id ID, dest State) error {
562569
// GetSessionByID fetches the session with the given ID.
563570
//
564571
// NOTE: this is part of the Store interface.
565-
func (db *BoltStore) GetSessionByID(id ID) (*Session, error) {
572+
func (db *BoltStore) GetSessionByID(_ context.Context, id ID) (*Session,
573+
error) {
574+
566575
var session *Session
567576
err := db.View(func(tx *bbolt.Tx) error {
568577
sessionBucket, err := getBucket(tx, sessionBucketKey)
@@ -615,7 +624,7 @@ func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey,
615624
// GetGroupID will return the group ID for the given session ID.
616625
//
617626
// NOTE: this is part of the IDToGroupIndex interface.
618-
func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
627+
func (db *BoltStore) GetGroupID(_ context.Context, sessionID ID) (ID, error) {
619628
var groupID ID
620629
err := db.View(func(tx *bbolt.Tx) error {
621630
sessionBkt, err := getBucket(tx, sessionBucketKey)
@@ -655,7 +664,9 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
655664
// group with the given ID.
656665
//
657666
// NOTE: this is part of the IDToGroupIndex interface.
658-
func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
667+
func (db *BoltStore) GetSessionIDs(_ context.Context, groupID ID) ([]ID,
668+
error) {
669+
659670
var (
660671
sessionIDs []ID
661672
err error

0 commit comments

Comments
 (0)