From c71df78c0fa4f2765dff7ffee3762528fdbca642 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 9 Feb 2025 12:12:20 +0200 Subject: [PATCH 1/6] session: add a variadic test session creation modifier So that we can add more modifiers for future tests. --- session/store_test.go | 46 ++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/session/store_test.go b/session/store_test.go index 18bd933d6..cb491f573 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -24,10 +24,10 @@ func TestBasicSessionStore(t *testing.T) { }) // Create a few sessions. - s1 := newSession(t, db, clock, "session 1", nil) - s2 := newSession(t, db, clock, "session 2", nil) - s3 := newSession(t, db, clock, "session 3", nil) - s4 := newSession(t, db, clock, "session 4", nil) + s1 := newSession(t, db, clock, "session 1") + s2 := newSession(t, db, clock, "session 2") + s3 := newSession(t, db, clock, "session 3") + s4 := newSession(t, db, clock, "session 4") // Persist session 1. This should now succeed. require.NoError(t, db.CreateSession(s1)) @@ -101,10 +101,10 @@ func TestLinkingSessions(t *testing.T) { }) // Create a new session with no previous link. - s1 := newSession(t, db, clock, "session 1", nil) + s1 := newSession(t, db, clock, "session 1") // Create another session and link it to the first. - s2 := newSession(t, db, clock, "session 2", &s1.GroupID) + s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID)) // Try to persist the second session and assert that it fails due to the // linked session not existing in the DB yet. @@ -141,9 +141,9 @@ func TestLinkedSessions(t *testing.T) { // after are all linked to the prior one. All these sessions belong to // the same group. The group ID is equivalent to the session ID of the // first session. - s1 := newSession(t, db, clock, "session 1", nil) - s2 := newSession(t, db, clock, "session 2", &s1.GroupID) - s3 := newSession(t, db, clock, "session 3", &s2.GroupID) + s1 := newSession(t, db, clock, "session 1") + s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID)) + s3 := newSession(t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID)) // Persist the sessions. require.NoError(t, db.CreateSession(s1)) @@ -169,8 +169,8 @@ func TestLinkedSessions(t *testing.T) { // To ensure that different groups don't interfere with each other, // let's add another set of linked sessions not linked to the first. - s4 := newSession(t, db, clock, "session 4", nil) - s5 := newSession(t, db, clock, "session 5", &s4.GroupID) + s4 := newSession(t, db, clock, "session 4") + s5 := newSession(t, db, clock, "session 5", withLinkedGroupID(&s4.GroupID)) require.NotEqual(t, s4.GroupID, s1.GroupID) @@ -209,7 +209,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { // function is checked correctly. // Add a new session to the DB. - s1 := newSession(t, db, clock, "label 1", nil) + s1 := newSession(t, db, clock, "label 1") require.NoError(t, db.CreateSession(s1)) // Check that the group passes against an appropriate predicate. @@ -234,7 +234,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { require.NoError(t, db.RevokeSession(s1.LocalPublicKey)) // Add a new session to the same group as the first one. - s2 := newSession(t, db, clock, "label 2", &s1.GroupID) + s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID)) require.NoError(t, db.CreateSession(s2)) // Check that the group passes against an appropriate predicate. @@ -256,7 +256,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { require.False(t, ok) // Add a new session that is not linked to the first one. - s3 := newSession(t, db, clock, "completely different", nil) + s3 := newSession(t, db, clock, "completely different") require.NoError(t, db.CreateSession(s3)) // Ensure that the first group is unaffected. @@ -286,8 +286,18 @@ func TestCheckSessionGroupPredicate(t *testing.T) { require.True(t, ok) } +// testSessionModifier is a functional option that can be used to modify the +// default test session created by newSession. +type testSessionModifier func(*Session) + +func withLinkedGroupID(groupID *ID) testSessionModifier { + return func(s *Session) { + s.GroupID = *groupID + } +} + func newSession(t *testing.T, db Store, clock clock.Clock, label string, - linkedGroupID *ID) *Session { + mods ...testSessionModifier) *Session { id, priv, err := db.GetUnusedIDAndKeyPair() require.NoError(t, err) @@ -296,11 +306,15 @@ func newSession(t *testing.T, db Store, clock clock.Clock, label string, id, priv, label, TypeMacaroonAdmin, clock.Now(), time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC), - "foo.bar.baz:1234", true, nil, nil, nil, true, linkedGroupID, + "foo.bar.baz:1234", true, nil, nil, nil, true, nil, []PrivacyFlag{ClearPubkeys}, ) require.NoError(t, err) + for _, mod := range mods { + mod(session) + } + return session } From 00230029f35444ec882744bb1862faa54e4a9834 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 9 Feb 2025 12:14:33 +0200 Subject: [PATCH 2/6] session: ensure listed sessions are sorted Sorted by creation time. Also add a test to cover this. --- session/kvdb_store.go | 14 ++++++++++++++ session/store_test.go | 16 +++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 19c7f7db2..e5862056f 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "path/filepath" + "sort" "time" "github.com/btcsuite/btcd/btcec/v2" @@ -367,6 +368,14 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) { // // NOTE: this is part of the Store interface. func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) { + return db.listSessions(filterFn) +} + +// listSessions returns all sessions currently known to the store that pass the +// given filter function. +func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session, + error) { + var sessions []*Session err := db.View(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) @@ -399,6 +408,11 @@ func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, e return nil, err } + // Make sure to sort the sessions by creation time. + sort.Slice(sessions, func(i, j int) bool { + return sessions[i].CreatedAt.Before(sessions[j].CreatedAt) + }) + return sessions, nil } diff --git a/session/store_test.go b/session/store_test.go index cb491f573..051e3187e 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -23,10 +23,16 @@ func TestBasicSessionStore(t *testing.T) { _ = db.Close() }) - // Create a few sessions. + // Create a few sessions. We increment the time by one second between + // each session to ensure that the created at time is unique and hence + // that the ListSessions method returns the sessions in a deterministic + // order. s1 := newSession(t, db, clock, "session 1") + clock.SetTime(testTime.Add(time.Second)) s2 := newSession(t, db, clock, "session 2") + clock.SetTime(testTime.Add(2 * time.Second)) s3 := newSession(t, db, clock, "session 3") + clock.SetTime(testTime.Add(3 * time.Second)) s4 := newSession(t, db, clock, "session 4") // Persist session 1. This should now succeed. @@ -50,6 +56,14 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, db.CreateSession(s2)) require.NoError(t, db.CreateSession(s3)) + // Check that all sessions are returned in ListSessions. + sessions, err := db.ListSessions(nil) + require.NoError(t, err) + require.Equal(t, 3, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + assertEqualSessions(t, s2, sessions[1]) + assertEqualSessions(t, s3, sessions[2]) + // Ensure that we can retrieve each session by both its local pub key // and by its ID. for _, s := range []*Session{s1, s2, s3} { From 6c36b01fd856ecbc3fa22cb3c783f3e57182bd6c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 9 Feb 2025 12:22:53 +0200 Subject: [PATCH 3/6] session: add ListSessionsByType method And use it to replace one call to ListSessions which uses a filter function which would be inefficient in SQL land. --- session/interface.go | 3 +++ session/kvdb_store.go | 10 ++++++++++ session/store_test.go | 24 +++++++++++++++++++++++- session_rpcserver.go | 4 +--- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/session/interface.go b/session/interface.go index 41bd354cd..08b3c0cb2 100644 --- a/session/interface.go +++ b/session/interface.go @@ -164,6 +164,9 @@ type Store interface { // ListSessions returns all sessions currently known to the store. ListSessions(filterFn func(s *Session) bool) ([]*Session, error) + // ListSessionsByType returns all sessions of the given type. + ListSessionsByType(t Type) ([]*Session, error) + // RevokeSession updates the state of the session with the given local // public key to be revoked. RevokeSession(*btcec.PublicKey) error diff --git a/session/kvdb_store.go b/session/kvdb_store.go index e5862056f..2fc714c26 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -371,6 +371,16 @@ func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, e return db.listSessions(filterFn) } +// ListSessionsByType returns all sessions currently known to the store that +// have the given type. +// +// NOTE: this is part of the Store interface. +func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) { + return db.listSessions(func(s *Session) bool { + return s.Type == t + }) +} + // listSessions returns all sessions currently known to the store that pass the // given filter function. func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session, diff --git a/session/store_test.go b/session/store_test.go index 051e3187e..6057e9fed 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -31,7 +31,7 @@ func TestBasicSessionStore(t *testing.T) { clock.SetTime(testTime.Add(time.Second)) s2 := newSession(t, db, clock, "session 2") clock.SetTime(testTime.Add(2 * time.Second)) - s3 := newSession(t, db, clock, "session 3") + s3 := newSession(t, db, clock, "session 3", withType(TypeAutopilot)) clock.SetTime(testTime.Add(3 * time.Second)) s4 := newSession(t, db, clock, "session 4") @@ -64,6 +64,22 @@ func TestBasicSessionStore(t *testing.T) { assertEqualSessions(t, s2, sessions[1]) assertEqualSessions(t, s3, sessions[2]) + // Test the ListSessionsByType method. + sessions, err = db.ListSessionsByType(TypeMacaroonAdmin) + require.NoError(t, err) + require.Equal(t, 2, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + assertEqualSessions(t, s2, sessions[1]) + + sessions, err = db.ListSessionsByType(TypeAutopilot) + require.NoError(t, err) + require.Equal(t, 1, len(sessions)) + assertEqualSessions(t, s3, sessions[0]) + + sessions, err = db.ListSessionsByType(TypeMacaroonReadonly) + require.NoError(t, err) + require.Empty(t, sessions) + // Ensure that we can retrieve each session by both its local pub key // and by its ID. for _, s := range []*Session{s1, s2, s3} { @@ -310,6 +326,12 @@ func withLinkedGroupID(groupID *ID) testSessionModifier { } } +func withType(t Type) testSessionModifier { + return func(s *Session) { + s.Type = t + } +} + func newSession(t *testing.T, db Store, clock clock.Clock, label string, mods ...testSessionModifier) *Session { diff --git a/session_rpcserver.go b/session_rpcserver.go index 666744cd0..d3a8b18a9 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -1259,9 +1259,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context, _ *litrpc.ListAutopilotSessionsRequest) ( *litrpc.ListAutopilotSessionsResponse, error) { - sessions, err := s.cfg.db.ListSessions(func(s *session.Session) bool { - return s.Type == session.TypeAutopilot - }) + sessions, err := s.cfg.db.ListSessionsByType(session.TypeAutopilot) if err != nil { return nil, fmt.Errorf("error fetching sessions: %v", err) } From 01410f7950e8d90400485b2a741cb1a97aa995a8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 9 Feb 2025 12:29:40 +0200 Subject: [PATCH 4/6] session: remove the filter fn in ListSessions And replace with ListAllSessions since no callers of ListSessions currently make use of the filter function. --- session/interface.go | 4 ++-- session/kvdb_store.go | 8 +++++--- session/store_test.go | 22 +++++++++++----------- session_rpcserver.go | 4 ++-- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/session/interface.go b/session/interface.go index 08b3c0cb2..c53987fd3 100644 --- a/session/interface.go +++ b/session/interface.go @@ -161,8 +161,8 @@ type Store interface { // GetSession fetches the session with the given key. GetSession(key *btcec.PublicKey) (*Session, error) - // ListSessions returns all sessions currently known to the store. - ListSessions(filterFn func(s *Session) bool) ([]*Session, error) + // ListAllSessions returns all sessions currently known to the store. + ListAllSessions() ([]*Session, error) // ListSessionsByType returns all sessions of the given type. ListSessionsByType(t Type) ([]*Session, error) diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 2fc714c26..161d3b05d 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -364,11 +364,13 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) { return session, nil } -// ListSessions returns all sessions currently known to the store. +// ListAllSessions returns all sessions currently known to the store. // // NOTE: this is part of the Store interface. -func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) { - return db.listSessions(filterFn) +func (db *BoltStore) ListAllSessions() ([]*Session, error) { + return db.listSessions(func(s *Session) bool { + return true + }) } // ListSessionsByType returns all sessions currently known to the store that diff --git a/session/store_test.go b/session/store_test.go index 6057e9fed..99c1fde4b 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -56,16 +56,8 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, db.CreateSession(s2)) require.NoError(t, db.CreateSession(s3)) - // Check that all sessions are returned in ListSessions. - sessions, err := db.ListSessions(nil) - require.NoError(t, err) - require.Equal(t, 3, len(sessions)) - assertEqualSessions(t, s1, sessions[0]) - assertEqualSessions(t, s2, sessions[1]) - assertEqualSessions(t, s3, sessions[2]) - // Test the ListSessionsByType method. - sessions, err = db.ListSessionsByType(TypeMacaroonAdmin) + sessions, err := db.ListSessionsByType(TypeMacaroonAdmin) require.NoError(t, err) require.Equal(t, 2, len(sessions)) assertEqualSessions(t, s1, sessions[0]) @@ -115,9 +107,17 @@ func TestBasicSessionStore(t *testing.T) { // Now revoke the session and assert that the state is revoked. require.NoError(t, db.RevokeSession(s1.LocalPublicKey)) - session1, err = db.GetSession(s1.LocalPublicKey) + s1, err = db.GetSession(s1.LocalPublicKey) require.NoError(t, err) - require.Equal(t, session1.State, StateRevoked) + require.Equal(t, s1.State, StateRevoked) + + // Test that ListAllSessions works. + sessions, err = db.ListAllSessions() + require.NoError(t, err) + require.Equal(t, 3, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + assertEqualSessions(t, s2, sessions[1]) + assertEqualSessions(t, s3, sessions[2]) } // TestLinkingSessions tests that session linking works as expected. diff --git a/session_rpcserver.go b/session_rpcserver.go index d3a8b18a9..92beafb26 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -101,7 +101,7 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, // requests. This includes resuming all non-revoked sessions. func (s *sessionRpcServer) start(ctx context.Context) error { // Start up all previously created sessions. - sessions, err := s.cfg.db.ListSessions(nil) + sessions, err := s.cfg.db.ListAllSessions() if err != nil { return fmt.Errorf("error listing sessions: %v", err) } @@ -536,7 +536,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context, func (s *sessionRpcServer) ListSessions(_ context.Context, _ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) { - sessions, err := s.cfg.db.ListSessions(nil) + sessions, err := s.cfg.db.ListAllSessions() if err != nil { return nil, fmt.Errorf("error fetching sessions: %v", err) } From d2b077b9d5d2fb04c142632de247eb936afdbaca Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Feb 2025 15:28:23 +0200 Subject: [PATCH 5/6] session: add ListSessionsByState method --- session/interface.go | 4 ++++ session/kvdb_store.go | 16 ++++++++++++++++ session/store_test.go | 27 +++++++++++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/session/interface.go b/session/interface.go index c53987fd3..6b27a1f38 100644 --- a/session/interface.go +++ b/session/interface.go @@ -167,6 +167,10 @@ type Store interface { // ListSessionsByType returns all sessions of the given type. ListSessionsByType(t Type) ([]*Session, error) + // ListSessionsByState returns all sessions currently known to the store + // that are in the given states. + ListSessionsByState(...State) ([]*Session, error) + // RevokeSession updates the state of the session with the given local // public key to be revoked. RevokeSession(*btcec.PublicKey) error diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 161d3b05d..2f5f252d5 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -383,6 +383,22 @@ func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) { }) } +// ListSessionsByState returns all sessions currently known to the store that +// are in the given states. +// +// NOTE: this is part of the Store interface. +func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) { + return db.listSessions(func(s *Session) bool { + for _, state := range states { + if s.State == state { + return true + } + } + + return false + }) +} + // listSessions returns all sessions currently known to the store that pass the // given filter function. func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session, diff --git a/session/store_test.go b/session/store_test.go index 99c1fde4b..e5530fdda 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -118,6 +118,33 @@ func TestBasicSessionStore(t *testing.T) { assertEqualSessions(t, s1, sessions[0]) assertEqualSessions(t, s2, sessions[1]) assertEqualSessions(t, s3, sessions[2]) + + // Test that ListSessionsByState works. + sessions, err = db.ListSessionsByState(StateRevoked) + require.NoError(t, err) + require.Equal(t, 1, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + + sessions, err = db.ListSessionsByState(StateCreated) + require.NoError(t, err) + require.Equal(t, 2, len(sessions)) + assertEqualSessions(t, s2, sessions[0]) + assertEqualSessions(t, s3, sessions[1]) + + sessions, err = db.ListSessionsByState(StateCreated, StateRevoked) + require.NoError(t, err) + require.Equal(t, 3, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + assertEqualSessions(t, s2, sessions[1]) + assertEqualSessions(t, s3, sessions[2]) + + sessions, err = db.ListSessionsByState() + require.NoError(t, err) + require.Empty(t, sessions) + + sessions, err = db.ListSessionsByState(StateInUse) + require.NoError(t, err) + require.Empty(t, sessions) } // TestLinkingSessions tests that session linking works as expected. From 14cb0be3bc4bae3039a57d61798eedd4348f70ba Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 9 Feb 2025 12:31:04 +0200 Subject: [PATCH 6/6] lit: only fetch active sessions on startup Using the new ListSessions by type method, we no longer need to fetch and iterate through all our sessions on start up to figure out which ones to spin up. --- session_rpcserver.go | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/session_rpcserver.go b/session_rpcserver.go index 92beafb26..e85d3578f 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -101,7 +101,10 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, // requests. This includes resuming all non-revoked sessions. func (s *sessionRpcServer) start(ctx context.Context) error { // Start up all previously created sessions. - sessions, err := s.cfg.db.ListAllSessions() + sessions, err := s.cfg.db.ListSessionsByState( + session.StateCreated, + session.StateInUse, + ) if err != nil { return fmt.Errorf("error listing sessions: %v", err) } @@ -126,12 +129,6 @@ func (s *sessionRpcServer) start(ctx context.Context) error { continue } - if sess.State != session.StateInUse && - sess.State != session.StateCreated { - - continue - } - if sess.Expiry.Before(time.Now()) { continue } @@ -345,24 +342,13 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, }, nil } -// resumeSession tries to start an existing session if it is not expired, not -// revoked and a LiT session. +// resumeSession tries to start the given session if it is not expired. func (s *sessionRpcServer) resumeSession(ctx context.Context, sess *session.Session) error { pubKey := sess.LocalPublicKey pubKeyBytes := pubKey.SerializeCompressed() - // We only start non-revoked, non-expired LiT sessions. Everything else - // we just skip. - if sess.State != session.StateInUse && - sess.State != session.StateCreated { - - log.Debugf("Not resuming session %x with state %d", pubKeyBytes, - sess.State) - return nil - } - // Don't resume an expired session. if sess.Expiry.Before(time.Now()) { log.Debugf("Not resuming session %x with expiry %s",