Skip to content

[sql-12] sessions: make ListSession methods SQL ready #970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions session/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,15 @@ type Store interface {
// GetSession fetches the session with the given key.
GetSession(key *btcec.PublicKey) (*Session, error)

// ListSessions returns all sessions currently known to the store.
ListSessions(filterFn func(s *Session) bool) ([]*Session, error)
// ListAllSessions returns all sessions currently known to the store.
ListAllSessions() ([]*Session, error)

// ListSessionsByType returns all sessions of the given type.
ListSessionsByType(t Type) ([]*Session, error)

// ListSessionsByState returns all sessions currently known to the store
// that are in the given states.
ListSessionsByState(...State) ([]*Session, error)

// RevokeSession updates the state of the session with the given local
// public key to be revoked.
Expand Down
46 changes: 44 additions & 2 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"time"

"github.com/btcsuite/btcd/btcec/v2"
Expand Down Expand Up @@ -363,10 +364,46 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
return session, nil
}

// ListSessions returns all sessions currently known to the store.
// ListAllSessions returns all sessions currently known to the store.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) {
func (db *BoltStore) ListAllSessions() ([]*Session, error) {
return db.listSessions(func(s *Session) bool {
return true
})
}

// ListSessionsByType returns all sessions currently known to the store that
// have the given type.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
return db.listSessions(func(s *Session) bool {
return s.Type == t
})
}

// ListSessionsByState returns all sessions currently known to the store that
// are in the given states.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) {
return db.listSessions(func(s *Session) bool {
for _, state := range states {
if s.State == state {
return true
}
}

return false
})
}

// listSessions returns all sessions currently known to the store that pass the
// given filter function.
func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,
error) {

var sessions []*Session
err := db.View(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
Expand Down Expand Up @@ -399,6 +436,11 @@ func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, e
return nil, err
}

// Make sure to sort the sessions by creation time.
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].CreatedAt.Before(sessions[j].CreatedAt)
})

return sessions, nil
}

Expand Down
115 changes: 96 additions & 19 deletions session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@ func TestBasicSessionStore(t *testing.T) {
_ = db.Close()
})

// Create a few sessions.
s1 := newSession(t, db, clock, "session 1", nil)
s2 := newSession(t, db, clock, "session 2", nil)
s3 := newSession(t, db, clock, "session 3", nil)
s4 := newSession(t, db, clock, "session 4", nil)
// Create a few sessions. We increment the time by one second between
// each session to ensure that the created at time is unique and hence
// that the ListSessions method returns the sessions in a deterministic
// order.
s1 := newSession(t, db, clock, "session 1")
clock.SetTime(testTime.Add(time.Second))
s2 := newSession(t, db, clock, "session 2")
clock.SetTime(testTime.Add(2 * time.Second))
s3 := newSession(t, db, clock, "session 3", withType(TypeAutopilot))
Comment on lines +30 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this really tests that the sessions are sorted by creation time, as one could argue they're just sorted by the order they were added here. So I would ensure that one of the latter sessions we add, say s3, has a creation time prior to s2.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the main reason for incrementing the timestamps is to have deterministic results from ListSessions. We are just testing ListSessions as a whole, not that the results are ordered.

clock.SetTime(testTime.Add(3 * time.Second))
s4 := newSession(t, db, clock, "session 4")

// Persist session 1. This should now succeed.
require.NoError(t, db.CreateSession(s1))
Expand All @@ -50,6 +56,22 @@ func TestBasicSessionStore(t *testing.T) {
require.NoError(t, db.CreateSession(s2))
require.NoError(t, db.CreateSession(s3))

// Test the ListSessionsByType method.
sessions, err := db.ListSessionsByType(TypeMacaroonAdmin)
require.NoError(t, err)
require.Equal(t, 2, len(sessions))
assertEqualSessions(t, s1, sessions[0])
assertEqualSessions(t, s2, sessions[1])

sessions, err = db.ListSessionsByType(TypeAutopilot)
require.NoError(t, err)
require.Equal(t, 1, len(sessions))
assertEqualSessions(t, s3, sessions[0])

sessions, err = db.ListSessionsByType(TypeMacaroonReadonly)
require.NoError(t, err)
require.Empty(t, sessions)

// Ensure that we can retrieve each session by both its local pub key
// and by its ID.
for _, s := range []*Session{s1, s2, s3} {
Expand Down Expand Up @@ -85,9 +107,44 @@ func TestBasicSessionStore(t *testing.T) {

// Now revoke the session and assert that the state is revoked.
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
session1, err = db.GetSession(s1.LocalPublicKey)
s1, err = db.GetSession(s1.LocalPublicKey)
require.NoError(t, err)
require.Equal(t, s1.State, StateRevoked)

// Test that ListAllSessions works.
sessions, err = db.ListAllSessions()
require.NoError(t, err)
require.Equal(t, 3, len(sessions))
assertEqualSessions(t, s1, sessions[0])
assertEqualSessions(t, s2, sessions[1])
assertEqualSessions(t, s3, sessions[2])

// Test that ListSessionsByState works.
sessions, err = db.ListSessionsByState(StateRevoked)
require.NoError(t, err)
require.Equal(t, 1, len(sessions))
assertEqualSessions(t, s1, sessions[0])

sessions, err = db.ListSessionsByState(StateCreated)
require.NoError(t, err)
require.Equal(t, 2, len(sessions))
assertEqualSessions(t, s2, sessions[0])
assertEqualSessions(t, s3, sessions[1])

sessions, err = db.ListSessionsByState(StateCreated, StateRevoked)
require.NoError(t, err)
require.Equal(t, 3, len(sessions))
assertEqualSessions(t, s1, sessions[0])
assertEqualSessions(t, s2, sessions[1])
assertEqualSessions(t, s3, sessions[2])

sessions, err = db.ListSessionsByState()
require.NoError(t, err)
require.Equal(t, session1.State, StateRevoked)
require.Empty(t, sessions)

sessions, err = db.ListSessionsByState(StateInUse)
require.NoError(t, err)
require.Empty(t, sessions)
}

// TestLinkingSessions tests that session linking works as expected.
Expand All @@ -101,10 +158,10 @@ func TestLinkingSessions(t *testing.T) {
})

// Create a new session with no previous link.
s1 := newSession(t, db, clock, "session 1", nil)
s1 := newSession(t, db, clock, "session 1")

// Create another session and link it to the first.
s2 := newSession(t, db, clock, "session 2", &s1.GroupID)
s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID))

// Try to persist the second session and assert that it fails due to the
// linked session not existing in the DB yet.
Expand Down Expand Up @@ -141,9 +198,9 @@ func TestLinkedSessions(t *testing.T) {
// after are all linked to the prior one. All these sessions belong to
// the same group. The group ID is equivalent to the session ID of the
// first session.
s1 := newSession(t, db, clock, "session 1", nil)
s2 := newSession(t, db, clock, "session 2", &s1.GroupID)
s3 := newSession(t, db, clock, "session 3", &s2.GroupID)
s1 := newSession(t, db, clock, "session 1")
s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID))
s3 := newSession(t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID))

// Persist the sessions.
require.NoError(t, db.CreateSession(s1))
Expand All @@ -169,8 +226,8 @@ func TestLinkedSessions(t *testing.T) {

// To ensure that different groups don't interfere with each other,
// let's add another set of linked sessions not linked to the first.
s4 := newSession(t, db, clock, "session 4", nil)
s5 := newSession(t, db, clock, "session 5", &s4.GroupID)
s4 := newSession(t, db, clock, "session 4")
s5 := newSession(t, db, clock, "session 5", withLinkedGroupID(&s4.GroupID))

require.NotEqual(t, s4.GroupID, s1.GroupID)

Expand Down Expand Up @@ -209,7 +266,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
// function is checked correctly.

// Add a new session to the DB.
s1 := newSession(t, db, clock, "label 1", nil)
s1 := newSession(t, db, clock, "label 1")
require.NoError(t, db.CreateSession(s1))

// Check that the group passes against an appropriate predicate.
Expand All @@ -234,7 +291,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))

// Add a new session to the same group as the first one.
s2 := newSession(t, db, clock, "label 2", &s1.GroupID)
s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID))
require.NoError(t, db.CreateSession(s2))

// Check that the group passes against an appropriate predicate.
Expand All @@ -256,7 +313,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
require.False(t, ok)

// Add a new session that is not linked to the first one.
s3 := newSession(t, db, clock, "completely different", nil)
s3 := newSession(t, db, clock, "completely different")
require.NoError(t, db.CreateSession(s3))

// Ensure that the first group is unaffected.
Expand Down Expand Up @@ -286,8 +343,24 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
require.True(t, ok)
}

// testSessionModifier is a functional option that can be used to modify the
// default test session created by newSession.
type testSessionModifier func(*Session)

func withLinkedGroupID(groupID *ID) testSessionModifier {
return func(s *Session) {
s.GroupID = *groupID
}
}

func withType(t Type) testSessionModifier {
return func(s *Session) {
s.Type = t
}
}

func newSession(t *testing.T, db Store, clock clock.Clock, label string,
linkedGroupID *ID) *Session {
mods ...testSessionModifier) *Session {

id, priv, err := db.GetUnusedIDAndKeyPair()
require.NoError(t, err)
Expand All @@ -296,11 +369,15 @@ func newSession(t *testing.T, db Store, clock clock.Clock, label string,
id, priv, label, TypeMacaroonAdmin,
clock.Now(),
time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC),
"foo.bar.baz:1234", true, nil, nil, nil, true, linkedGroupID,
"foo.bar.baz:1234", true, nil, nil, nil, true, nil,
[]PrivacyFlag{ClearPubkeys},
)
require.NoError(t, err)

for _, mod := range mods {
mod(session)
}

return session
}

Expand Down
30 changes: 7 additions & 23 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
// requests. This includes resuming all non-revoked sessions.
func (s *sessionRpcServer) start(ctx context.Context) error {
// Start up all previously created sessions.
sessions, err := s.cfg.db.ListSessions(nil)
sessions, err := s.cfg.db.ListSessionsByState(
session.StateCreated,
session.StateInUse,
)
if err != nil {
return fmt.Errorf("error listing sessions: %v", err)
}
Expand All @@ -126,12 +129,6 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
continue
}

if sess.State != session.StateInUse &&
sess.State != session.StateCreated {

continue
}

if sess.Expiry.Before(time.Now()) {
continue
}
Expand Down Expand Up @@ -345,24 +342,13 @@ func (s *sessionRpcServer) AddSession(ctx context.Context,
}, nil
}

// resumeSession tries to start an existing session if it is not expired, not
// revoked and a LiT session.
// resumeSession tries to start the given session if it is not expired.
func (s *sessionRpcServer) resumeSession(ctx context.Context,
sess *session.Session) error {

pubKey := sess.LocalPublicKey
pubKeyBytes := pubKey.SerializeCompressed()

// We only start non-revoked, non-expired LiT sessions. Everything else
// we just skip.
if sess.State != session.StateInUse &&
sess.State != session.StateCreated {

log.Debugf("Not resuming session %x with state %d", pubKeyBytes,
sess.State)
return nil
}
Comment on lines -356 to -364
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to assume that the passed in session will not have an revoked/expired state? I see that all current calls to resumeSession ensure that the sesion is in a valid state, but I'm thinking about future code that may not have these safeguards in place.

If you think it's best to still remove this, then the function comment should be updated.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool yeah currently all the session passed to this method will already be in the Created/InUse state since it is either called from start which now only fetches sessions in these states or it is called on session creation which will defs have the CreatedState.

GOod call - will update the comment 👍


// Don't resume an expired session.
if sess.Expiry.Before(time.Now()) {
log.Debugf("Not resuming session %x with expiry %s",
Expand Down Expand Up @@ -536,7 +522,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
func (s *sessionRpcServer) ListSessions(_ context.Context,
_ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) {

sessions, err := s.cfg.db.ListSessions(nil)
sessions, err := s.cfg.db.ListAllSessions()
if err != nil {
return nil, fmt.Errorf("error fetching sessions: %v", err)
}
Expand Down Expand Up @@ -1259,9 +1245,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context,
_ *litrpc.ListAutopilotSessionsRequest) (
*litrpc.ListAutopilotSessionsResponse, error) {

sessions, err := s.cfg.db.ListSessions(func(s *session.Session) bool {
return s.Type == session.TypeAutopilot
})
sessions, err := s.cfg.db.ListSessionsByType(session.TypeAutopilot)
if err != nil {
return nil, fmt.Errorf("error fetching sessions: %v", err)
}
Expand Down
Loading