From 013e7c081cedf8773c25cd6600e7e7f925a60bc1 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 9 Feb 2025 13:41:18 +0200 Subject: [PATCH 1/2] session: introduce Reserve->Create pattern In this commit, we let StateReserved be the new initial state of a session for when NewSession is called. We then do predicate checks for linked sessions along with unique session alias (ID) and priv key derivations all under the same DB transaction in NewSession. ShiftState then moves a session to StateCreated. Only in StateCreated does a session become usable. With this change, we no longer need to ensure atomic session creation by acquiring the `sessRegMu` mutex in the session RPC server. --- session/interface.go | 47 ++++------ session/kvdb_store.go | 116 +++++++++++------------ session/store_test.go | 207 +++++++++++++++++++++--------------------- session_rpcserver.go | 62 +++++++------ 4 files changed, 205 insertions(+), 227 deletions(-) diff --git a/session/interface.go b/session/interface.go index 4898b882c..ab67f9207 100644 --- a/session/interface.go +++ b/session/interface.go @@ -27,15 +27,15 @@ const ( type State uint8 /* - /---> StateExpired (terminal) -StateCreated --- - \---> StateRevoked (terminal) + /---> StateExpired (terminal) +StateReserved ---> StateCreated --- + \---> StateRevoked (terminal) */ const ( // StateCreated is the state of a session once it has been fully - // committed to the Store and is ready to be used. This is the first - // state of a session. + // committed to the BoltStore and is ready to be used. This is the + // first state after StateReserved. StateCreated State = 0 // StateInUse is the state of a session that is currently being used. @@ -52,10 +52,10 @@ const ( // date. StateExpired State = 3 - // StateReserved is a temporary initial state of a session. On start-up, - // any sessions in this state should be cleaned up. - // - // NOTE: this isn't used yet. + // StateReserved is a temporary initial state of a session. This is used + // to reserve a unique ID and private key pair for a session before it + // is fully created. On start-up, any sessions in this state should be + // cleaned up. StateReserved State = 4 ) @@ -67,6 +67,9 @@ func (s State) Terminal() bool { // legalStateShifts is a map that defines the legal State transitions that a // Session can be put through. var legalStateShifts = map[State]map[State]bool{ + StateReserved: { + StateCreated: true, + }, StateCreated: { StateExpired: true, StateRevoked: true, @@ -141,7 +144,7 @@ func buildSession(id ID, localPrivKey *btcec.PrivateKey, label string, typ Type, sess := &Session{ ID: id, Label: label, - State: StateCreated, + State: StateReserved, Type: typ, Expiry: expiry.UTC(), CreatedAt: created.UTC(), @@ -185,23 +188,13 @@ type IDToGroupIndex interface { // retrieving Terminal Connect sessions. type Store interface { // NewSession creates a new session with the given user-defined - // parameters. - // - // NOTE: currently this purely a constructor of the Session type and - // does not make any database calls. This will be changed in a future - // commit. - NewSession(id ID, localPrivKey *btcec.PrivateKey, label string, - typ Type, expiry time.Time, serverAddr string, devServer bool, - perms []bakery.Op, caveats []macaroon.Caveat, + // parameters. The session will remain in the StateReserved state until + // ShiftState is called to update the state. + NewSession(label string, typ Type, expiry time.Time, serverAddr string, + devServer bool, perms []bakery.Op, caveats []macaroon.Caveat, featureConfig FeaturesConfig, privacy bool, linkedGroupID *ID, flags PrivacyFlags) (*Session, error) - // CreateSession adds a new session to the store. If a session with the - // same local public key already exists an error is returned. This - // can only be called with a Session with an ID that the Store has - // reserved. - CreateSession(*Session) error - // GetSession fetches the session with the given key. GetSession(key *btcec.PublicKey) (*Session, error) @@ -220,12 +213,6 @@ type Store interface { UpdateSessionRemotePubKey(localPubKey, remotePubKey *btcec.PublicKey) error - // GetUnusedIDAndKeyPair can be used to generate a new, unused, local - // private key and session ID pair. Care must be taken to ensure that no - // other thread calls this before the returned ID and key pair from this - // method are either used or discarded. - GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error) - // GetSessionByID fetches the session with the given ID. GetSessionByID(id ID) (*Session, error) diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 60bc62a04..f40c139ec 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -182,41 +182,42 @@ func getSessionKey(session *Session) []byte { return session.LocalPublicKey.SerializeCompressed() } -// NewSession creates a new session with the given user-defined parameters. -// -// NOTE: currently this purely a constructor of the Session type and does not -// make any database calls. This will be changed in a future commit. +// NewSession creates and persists a new session with the given user-defined +// parameters. The initial state of the session will be Reserved until +// ShiftState is called with StateCreated. // // NOTE: this is part of the Store interface. -func (db *BoltStore) NewSession(id ID, localPrivKey *btcec.PrivateKey, - label string, typ Type, expiry time.Time, serverAddr string, - devServer bool, perms []bakery.Op, caveats []macaroon.Caveat, - featureConfig FeaturesConfig, privacy bool, linkedGroupID *ID, - flags PrivacyFlags) (*Session, error) { - - return buildSession( - id, localPrivKey, label, typ, db.clock.Now(), expiry, - serverAddr, devServer, perms, caveats, featureConfig, privacy, - linkedGroupID, flags, - ) -} +func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, + serverAddr string, devServer bool, perms []bakery.Op, + caveats []macaroon.Caveat, featureConfig FeaturesConfig, privacy bool, + linkedGroupID *ID, flags PrivacyFlags) (*Session, error) { -// CreateSession adds a new session to the store. If a session with the same -// local public key already exists an error is returned. -// -// NOTE: this is part of the Store interface. -func (db *BoltStore) CreateSession(session *Session) error { - sessionKey := getSessionKey(session) - - return db.Update(func(tx *bbolt.Tx) error { + var session *Session + err := db.Update(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) if err != nil { return err } + id, localPrivKey, err := getUnusedIDAndKeyPair(sessionBucket) + if err != nil { + return err + } + + session, err = buildSession( + id, localPrivKey, label, typ, db.clock.Now(), expiry, + serverAddr, devServer, perms, caveats, featureConfig, + privacy, linkedGroupID, flags, + ) + if err != nil { + return err + } + + sessionKey := getSessionKey(session) + if len(sessionBucket.Get(sessionKey)) != 0 { - return fmt.Errorf("session with local public "+ - "key(%x) already exists", + return fmt.Errorf("session with local public key(%x) "+ + "already exists", session.LocalPublicKey.SerializeCompressed()) } @@ -275,6 +276,11 @@ func (db *BoltStore) CreateSession(session *Session) error { return putSession(sessionBucket, session) }) + if err != nil { + return nil, err + } + + return session, nil } // UpdateSessionRemotePubKey can be used to add the given remote pub key @@ -577,53 +583,35 @@ func (db *BoltStore) GetSessionByID(id ID) (*Session, error) { return session, nil } -// GetUnusedIDAndKeyPair can be used to generate a new, unused, local private +// getUnusedIDAndKeyPair can be used to generate a new, unused, local private // key and session ID pair. Care must be taken to ensure that no other thread // calls this before the returned ID and key pair from this method are either // used or discarded. -// -// NOTE: this is part of the Store interface. -func (db *BoltStore) GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error) { - var ( - id ID - privKey *btcec.PrivateKey - ) - err := db.Update(func(tx *bbolt.Tx) error { - sessionBucket, err := getBucket(tx, sessionBucketKey) - if err != nil { - return err - } - - idIndexBkt := sessionBucket.Bucket(idIndexKey) - if idIndexBkt == nil { - return ErrDBInitErr - } +func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey, + error) { - // Spin until we find a key with an ID that does not collide - // with any of our existing IDs. - for { - // Generate a new private key and ID pair. - privKey, id, err = NewSessionPrivKeyAndID() - if err != nil { - return err - } + idIndexBkt := bucket.Bucket(idIndexKey) + if idIndexBkt == nil { + return ID{}, nil, ErrDBInitErr + } - // Check that no such ID exits in our id-to-key index. - idBkt := idIndexBkt.Bucket(id[:]) - if idBkt != nil { - continue - } + // Spin until we find a key with an ID that does not collide with any of + // our existing IDs. + for { + // Generate a new private key and ID pair. + privKey, id, err := NewSessionPrivKeyAndID() + if err != nil { + return ID{}, nil, err + } - break + // Check that no such ID exits in our id-to-key index. + idBkt := idIndexBkt.Bucket(id[:]) + if idBkt != nil { + continue } - return nil - }) - if err != nil { - return id, nil, err + return id, privKey, nil } - - return id, privKey, nil } // GetGroupID will return the group ID for the given session ID. diff --git a/session/store_test.go b/session/store_test.go index db6368450..d4b0a5c54 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -23,38 +23,36 @@ func TestBasicSessionStore(t *testing.T) { _ = db.Close() }) - // Create a few sessions. We increment the time by one second between - // each session to ensure that the created at time is unique and hence - // that the ListSessions method returns the sessions in a deterministic - // order. - s1 := newSession(t, db, clock, "session 1") - clock.SetTime(testTime.Add(time.Second)) - s2 := newSession(t, db, clock, "session 2") - clock.SetTime(testTime.Add(2 * time.Second)) - s3 := newSession(t, db, clock, "session 3", withType(TypeAutopilot)) - clock.SetTime(testTime.Add(3 * time.Second)) - s4 := newSession(t, db, clock, "session 4") + // Reserve a session. This should succeed. + s1, err := reserveSession(db, "session 1") + require.NoError(t, err) - // Persist session 1. This should now succeed. - require.NoError(t, db.CreateSession(s1)) + // Show that the session starts in the reserved state. + s1, err = db.GetSessionByID(s1.ID) + require.NoError(t, err) + require.Equal(t, StateReserved, s1.State) - // Trying to persist session 1 again should fail due to a session with - // the given pub key already existing. - require.ErrorContains(t, db.CreateSession(s1), "already exists") + // Move session 1 to the created state. This should succeed. + err = db.ShiftState(s1.ID, StateCreated) + require.NoError(t, err) - // Change the local pub key of session 4 such that it has the same - // ID as session 1. - s4.ID = s1.ID - s4.GroupID = s1.GroupID + // Show that the session is now in the created state. + s1, err = db.GetSessionByID(s1.ID) + require.NoError(t, err) + require.Equal(t, StateCreated, s1.State) - // Now try to insert session 4. This should fail due to an entry for - // the ID already existing. - require.ErrorContains(t, db.CreateSession(s4), "a session with the "+ - "given ID already exists") + // Trying to move session 1 again should have no effect since it is + // already in the created state. + require.NoError(t, db.ShiftState(s1.ID, StateCreated)) - // Persist a few more sessions. - require.NoError(t, db.CreateSession(s2)) - require.NoError(t, db.CreateSession(s3)) + // Reserve and create a few more sessions. We increment the time by one + // second between each session to ensure that the created at time is + // unique and hence that the ListSessions method returns the sessions in + // a deterministic order. + clock.SetTime(testTime.Add(time.Second)) + s2 := createSession(t, db, "session 2") + clock.SetTime(testTime.Add(2 * time.Second)) + s3 := createSession(t, db, "session 3", withType(TypeAutopilot)) // Test the ListSessionsByType method. sessions, err := db.ListSessionsByType(TypeMacaroonAdmin) @@ -156,28 +154,26 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) require.Empty(t, sessions) - // Add a session and put it in the StateReserved state. We'll also - // link it to session 1. - s5 := newSession( - t, db, clock, "session 5", withState(StateReserved), - withLinkedGroupID(&session1.GroupID), + // Reserve a new session and link it to session 1. + s4, err := reserveSession( + db, "session 4", withLinkedGroupID(&session1.GroupID), ) - require.NoError(t, db.CreateSession(s5)) + require.NoError(t, err) sessions, err = db.ListSessionsByState(StateReserved) require.NoError(t, err) require.Equal(t, 1, len(sessions)) - assertEqualSessions(t, s5, sessions[0]) + assertEqualSessions(t, s4, sessions[0]) // Show that the group ID/session ID index has also been populated with // this session. - groupID, err := db.GetGroupID(s5.ID) + groupID, err := db.GetGroupID(s4.ID) require.NoError(t, err) require.Equal(t, s1.ID, groupID) - sessIDs, err := db.GetSessionIDs(s5.GroupID) + sessIDs, err := db.GetSessionIDs(s4.GroupID) require.NoError(t, err) - require.ElementsMatch(t, []ID{s5.ID, s1.ID}, sessIDs) + require.ElementsMatch(t, []ID{s4.ID, s1.ID}, sessIDs) // Now delete the reserved session and show that it is no longer in the // database and no longer in the group ID/session ID index. @@ -187,17 +183,19 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) require.Empty(t, sessions) - _, err = db.GetGroupID(s5.ID) + _, err = db.GetGroupID(s4.ID) require.ErrorContains(t, err, "no index entry") // Only session 1 should remain in this group. - sessIDs, err = db.GetSessionIDs(s5.GroupID) + sessIDs, err = db.GetSessionIDs(s4.GroupID) require.NoError(t, err) require.ElementsMatch(t, []ID{s1.ID}, sessIDs) } // TestLinkingSessions tests that session linking works as expected. func TestLinkingSessions(t *testing.T) { + t.Parallel() + // Set up a new DB. clock := clock.NewTestClock(testTime) db, err := NewDB(t.TempDir(), "test.db", clock) @@ -206,35 +204,39 @@ func TestLinkingSessions(t *testing.T) { _ = db.Close() }) - // Create a new session with no previous link. - s1 := newSession(t, db, clock, "session 1") - - // Create another session and link it to the first. - s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID)) + groupID, err := IDFromBytes([]byte{1, 2, 3, 4}) + require.NoError(t, err) - // Try to persist the second session and assert that it fails due to the - // linked session not existing in the DB yet. - require.ErrorContains(t, db.CreateSession(s2), "unknown linked session") + // Try to reserve a session that links to another and assert that it + // fails due to the linked session not existing in the BoltStore yet. + _, err = reserveSession( + db, "session 2", withLinkedGroupID(&groupID), + ) + require.ErrorContains(t, err, "unknown linked session") - // Now persist the first session and retry persisting the second one - // and assert that this now works. - require.NoError(t, db.CreateSession(s1)) + // Create a new session with no previous link. + s1 := createSession(t, db, "session 1") - // Persisting the second session immediately should fail due to the - // first session still being active. - require.ErrorContains(t, db.CreateSession(s2), "is still active") + // Once again try to reserve a session that links to the now existing + // session. This should fail due to the first session still being + // active. + _, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID)) + require.ErrorContains(t, err, "is still active") // Revoke the first session. require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) // Persisting the second linked session should now work. - require.NoError(t, db.CreateSession(s2)) + _, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID)) + require.NoError(t, err) } // TestIDToGroupIDIndex tests that the session-ID-to-group-ID and // group-ID-to-session-ID indexes work as expected by asserting the behaviour // of the GetGroupID and GetSessionIDs methods. func TestLinkedSessions(t *testing.T) { + t.Parallel() + // Set up a new DB. clock := clock.NewTestClock(testTime) db, err := NewDB(t.TempDir(), "test.db", clock) @@ -247,22 +249,13 @@ func TestLinkedSessions(t *testing.T) { // after are all linked to the prior one. All these sessions belong to // the same group. The group ID is equivalent to the session ID of the // first session. - s1 := newSession(t, db, clock, "session 1") - s2 := newSession( - t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID), - ) - s3 := newSession( - t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID), - ) - - // Persist the sessions. - require.NoError(t, db.CreateSession(s1)) + s1 := createSession(t, db, "session 1") require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) - require.NoError(t, db.CreateSession(s2)) + s2 := createSession(t, db, "session 2", withLinkedGroupID(&s1.GroupID)) require.NoError(t, db.ShiftState(s2.ID, StateRevoked)) - require.NoError(t, db.CreateSession(s3)) + s3 := createSession(t, db, "session 3", withLinkedGroupID(&s2.GroupID)) // Assert that the session ID to group ID index works as expected. for _, s := range []*Session{s1, s2, s3} { @@ -279,16 +272,10 @@ func TestLinkedSessions(t *testing.T) { // To ensure that different groups don't interfere with each other, // let's add another set of linked sessions not linked to the first. - s4 := newSession(t, db, clock, "session 4") - s5 := newSession(t, db, clock, "session 5", withLinkedGroupID(&s4.GroupID)) - - require.NotEqual(t, s4.GroupID, s1.GroupID) - - // Persist the sessions. - require.NoError(t, db.CreateSession(s4)) + s4 := createSession(t, db, "session 4") require.NoError(t, db.ShiftState(s4.ID, StateRevoked)) - - require.NoError(t, db.CreateSession(s5)) + s5 := createSession(t, db, "session 5", withLinkedGroupID(&s4.GroupID)) + require.NotEqual(t, s4.GroupID, s1.GroupID) // Assert that the session ID to group ID index works as expected. for _, s := range []*Session{s4, s5} { @@ -307,6 +294,8 @@ func TestLinkedSessions(t *testing.T) { // TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate // method correctly checks if each session in a group passes a predicate. func TestCheckSessionGroupPredicate(t *testing.T) { + t.Parallel() + // Set up a new DB. clock := clock.NewTestClock(testTime) db, err := NewDB(t.TempDir(), "test.db", clock) @@ -319,8 +308,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { // function is checked correctly. // Add a new session to the DB. - s1 := newSession(t, db, clock, "label 1") - require.NoError(t, db.CreateSession(s1)) + s1 := createSession(t, db, "label 1") // Check that the group passes against an appropriate predicate. ok, err := db.CheckSessionGroupPredicate( @@ -344,8 +332,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) // Add a new session to the same group as the first one. - s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID)) - require.NoError(t, db.CreateSession(s2)) + _ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID)) // Check that the group passes against an appropriate predicate. ok, err = db.CheckSessionGroupPredicate( @@ -366,8 +353,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { require.False(t, ok) // Add a new session that is not linked to the first one. - s3 := newSession(t, db, clock, "completely different") - require.NoError(t, db.CreateSession(s3)) + s3 := createSession(t, db, "completely different") // Ensure that the first group is unaffected. ok, err = db.CheckSessionGroupPredicate( @@ -407,8 +393,7 @@ func TestStateShift(t *testing.T) { }) // Add a new session to the DB. - s1 := newSession(t, db, clock, "label 1") - require.NoError(t, db.CreateSession(s1)) + s1 := createSession(t, db, "label 1") // Check that the session is in the StateCreated state. Also check that // the "RevokedAt" time has not yet been set. @@ -443,48 +428,62 @@ func TestStateShift(t *testing.T) { require.ErrorContains(t, err, "illegal session state transition") } +type testSessionOpts struct { + groupID *ID + sessType Type +} + +func defaultTestSessOpts() *testSessionOpts { + return &testSessionOpts{ + groupID: nil, + sessType: TypeMacaroonAdmin, + } +} + // testSessionModifier is a functional option that can be used to modify the // default test session created by newSession. -type testSessionModifier func(*Session) +type testSessionModifier func(*testSessionOpts) func withLinkedGroupID(groupID *ID) testSessionModifier { - return func(s *Session) { - s.GroupID = *groupID + return func(s *testSessionOpts) { + s.groupID = groupID } } func withType(t Type) testSessionModifier { - return func(s *Session) { - s.Type = t + return func(s *testSessionOpts) { + s.sessType = t } } -func withState(state State) testSessionModifier { - return func(s *Session) { - s.State = state +func reserveSession(db Store, label string, + mods ...testSessionModifier) (*Session, error) { + + opts := defaultTestSessOpts() + for _, mod := range mods { + mod(opts) } + + return db.NewSession(label, opts.sessType, + time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC), + "foo.bar.baz:1234", true, nil, nil, nil, true, opts.groupID, + []PrivacyFlag{ClearPubkeys}, + ) } -func newSession(t *testing.T, db Store, clock clock.Clock, label string, +func createSession(t *testing.T, db Store, label string, mods ...testSessionModifier) *Session { - id, priv, err := db.GetUnusedIDAndKeyPair() + s, err := reserveSession(db, label, mods...) require.NoError(t, err) - session, err := buildSession( - id, priv, label, TypeMacaroonAdmin, - clock.Now(), - time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC), - "foo.bar.baz:1234", true, nil, nil, nil, true, nil, - []PrivacyFlag{ClearPubkeys}, - ) + err = db.ShiftState(s.ID, StateCreated) require.NoError(t, err) - for _, mod := range mods { - mod(session) - } + s, err = db.GetSessionByID(s.ID) + require.NoError(t, err) - return session + return s } func assertEqualSessions(t *testing.T, expected, actual *Session) { diff --git a/session_rpcserver.go b/session_rpcserver.go index e543b8bb2..fcb4269f9 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -46,11 +46,6 @@ type sessionRpcServer struct { cfg *sessionRpcServerConfig sessionServer *session.Server - // sessRegMu is a mutex that should be held between acquiring an unused - // session ID and key pair from the session store and persisting that - // new session. - sessRegMu sync.Mutex - quit chan struct{} wg sync.WaitGroup stopOnce sync.Once @@ -313,16 +308,8 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, } } - s.sessRegMu.Lock() - defer s.sessRegMu.Unlock() - - id, localPrivKey, err := s.cfg.db.GetUnusedIDAndKeyPair() - if err != nil { - return nil, err - } - sess, err := s.cfg.db.NewSession( - id, localPrivKey, req.Label, typ, expiry, req.MailboxServerAddr, + req.Label, typ, expiry, req.MailboxServerAddr, req.DevServer, uniquePermissions, caveats, nil, false, nil, session.PrivacyFlags{}, ) @@ -330,14 +317,23 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, return nil, fmt.Errorf("error creating new session: %v", err) } - if err := s.cfg.db.CreateSession(sess); err != nil { - return nil, fmt.Errorf("error storing session: %v", err) + err = s.cfg.db.ShiftState(sess.ID, session.StateCreated) + if err != nil { + return nil, fmt.Errorf("error shifting session state to "+ + "Created: %v", err) } if err := s.resumeSession(ctx, sess); err != nil { return nil, fmt.Errorf("error starting session: %v", err) } + // Re-fetch the session to get the latest state of it before marshaling + // it. + sess, err = s.cfg.db.GetSessionByID(sess.ID) + if err != nil { + return nil, fmt.Errorf("error fetching session: %v", err) + } + rpcSession, err := s.marshalRPCSession(sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) @@ -1141,16 +1137,8 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, caveats = append(caveats, firewall.MetaPrivacyCaveat) } - s.sessRegMu.Lock() - defer s.sessRegMu.Unlock() - - id, localPrivKey, err := s.cfg.db.GetUnusedIDAndKeyPair() - if err != nil { - return nil, err - } - sess, err := s.cfg.db.NewSession( - id, localPrivKey, req.Label, session.TypeAutopilot, expiry, + req.Label, session.TypeAutopilot, expiry, req.MailboxServerAddr, req.DevServer, perms, caveats, clientConfig, privacy, linkedGroupID, privacyFlags, ) @@ -1233,17 +1221,33 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, "autopilot server: %v", err) } - // We only persist this session if we successfully retrieved the - // autopilot's static key. + err = s.cfg.db.UpdateSessionRemotePubKey(sess.LocalPublicKey, remoteKey) + if err != nil { + return nil, fmt.Errorf("error setting remote pubkey: %v", err) + } + + // Update our in-memory session with the remote key. sess.RemotePublicKey = remoteKey - if err := s.cfg.db.CreateSession(sess); err != nil { - return nil, fmt.Errorf("error storing session: %v", err) + + // We only activate the session if the Autopilot server registration + // was successful. + err = s.cfg.db.ShiftState(sess.ID, session.StateCreated) + if err != nil { + return nil, fmt.Errorf("error shifting session state to "+ + "Created: %v", err) } if err := s.resumeSession(ctx, sess); err != nil { return nil, fmt.Errorf("error starting session: %v", err) } + // Re-fetch the session to get the latest state of it before marshaling + // it. + sess, err = s.cfg.db.GetSessionByID(sess.ID) + if err != nil { + return nil, fmt.Errorf("error fetching session: %v", err) + } + rpcSession, err := s.marshalRPCSession(sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) From 32a34d124598ffe406c16f673b8e84b66fbde6cb Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 9 Feb 2025 13:47:19 +0200 Subject: [PATCH 2/2] session: remove Session Group Predicate method This was used to check that all linked sessions are no longer active before attempting to register an autopilot session. But this is no longer needed since this is done within NewSession. --- session/interface.go | 6 --- session/kvdb_store.go | 63 +---------------------------- session/store_test.go | 92 ------------------------------------------- session_rpcserver.go | 17 -------- 4 files changed, 1 insertion(+), 177 deletions(-) diff --git a/session/interface.go b/session/interface.go index ab67f9207..ad94a1f1c 100644 --- a/session/interface.go +++ b/session/interface.go @@ -216,12 +216,6 @@ type Store interface { // GetSessionByID fetches the session with the given ID. GetSessionByID(id ID) (*Session, error) - // CheckSessionGroupPredicate iterates over all the sessions in a group - // and checks if each one passes the given predicate function. True is - // returned if each session passes. - CheckSessionGroupPredicate(groupID ID, - fn func(s *Session) bool) (bool, error) - // DeleteReservedSessions deletes all sessions that are in the // StateReserved state. DeleteReservedSessions() error diff --git a/session/kvdb_store.go b/session/kvdb_store.go index f40c139ec..84f4dce06 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -249,9 +249,7 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, } // Ensure that the session is no longer active. - if sess.State == StateCreated || - sess.State == StateInUse { - + if !sess.State.Terminal() { return fmt.Errorf("session (id=%x) "+ "in group %x is still active", sess.ID, sess.GroupID) @@ -679,65 +677,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) { return sessionIDs, nil } -// CheckSessionGroupPredicate iterates over all the sessions in a group and -// checks if each one passes the given predicate function. True is returned if -// each session passes. -// -// NOTE: this is part of the Store interface. -func (db *BoltStore) CheckSessionGroupPredicate(groupID ID, - fn func(s *Session) bool) (bool, error) { - - var ( - pass bool - errFailedPred = errors.New("session failed predicate") - ) - err := db.View(func(tx *bbolt.Tx) error { - sessionBkt, err := getBucket(tx, sessionBucketKey) - if err != nil { - return err - } - - sessionIDs, err := getSessionIDs(sessionBkt, groupID) - if err != nil { - return err - } - - // Iterate over all the sessions. - for _, id := range sessionIDs { - key, err := getKeyForID(sessionBkt, id) - if err != nil { - return err - } - - v := sessionBkt.Get(key) - if len(v) == 0 { - return ErrSessionNotFound - } - - session, err := DeserializeSession(bytes.NewReader(v)) - if err != nil { - return err - } - - if !fn(session) { - return errFailedPred - } - } - - pass = true - - return nil - }) - if errors.Is(err, errFailedPred) { - return pass, nil - } - if err != nil { - return pass, err - } - - return pass, nil -} - // getSessionIDs returns all the session IDs associated with the given group ID. func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) { var sessionIDs []ID diff --git a/session/store_test.go b/session/store_test.go index d4b0a5c54..aa7afe20f 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -1,7 +1,6 @@ package session import ( - "strings" "testing" "time" @@ -291,97 +290,6 @@ func TestLinkedSessions(t *testing.T) { require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs) } -// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate -// method correctly checks if each session in a group passes a predicate. -func TestCheckSessionGroupPredicate(t *testing.T) { - t.Parallel() - - // Set up a new DB. - clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) - - // We will use the Label of the Session to test that the predicate - // function is checked correctly. - - // Add a new session to the DB. - s1 := createSession(t, db, "label 1") - - // Check that the group passes against an appropriate predicate. - ok, err := db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label 1") - }, - ) - require.NoError(t, err) - require.True(t, ok) - - // Check that the group fails against an appropriate predicate. - ok, err = db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label 2") - }, - ) - require.NoError(t, err) - require.False(t, ok) - - // Revoke the first session. - require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) - - // Add a new session to the same group as the first one. - _ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID)) - - // Check that the group passes against an appropriate predicate. - ok, err = db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label") - }, - ) - require.NoError(t, err) - require.True(t, ok) - - // Check that the group fails against an appropriate predicate. - ok, err = db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label 1") - }, - ) - require.NoError(t, err) - require.False(t, ok) - - // Add a new session that is not linked to the first one. - s3 := createSession(t, db, "completely different") - - // Ensure that the first group is unaffected. - ok, err = db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label") - }, - ) - require.NoError(t, err) - require.True(t, ok) - - // And that the new session is evaluated separately. - ok, err = db.CheckSessionGroupPredicate( - s3.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label") - }, - ) - require.NoError(t, err) - require.False(t, ok) - - ok, err = db.CheckSessionGroupPredicate( - s3.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "different") - }, - ) - require.NoError(t, err) - require.True(t, ok) -} - // TestStateShift tests that the ShiftState method works as expected. func TestStateShift(t *testing.T) { // Set up a new DB. diff --git a/session_rpcserver.go b/session_rpcserver.go index fcb4269f9..778b593cc 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -874,23 +874,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, "group %x", groupSess.ID, groupSess.GroupID) } - // Now we need to check that all the sessions in the group are - // no longer active. - ok, err := s.cfg.db.CheckSessionGroupPredicate( - groupID, func(s *session.Session) bool { - return s.State == session.StateRevoked || - s.State == session.StateExpired - }, - ) - if err != nil { - return nil, err - } - - if !ok { - return nil, fmt.Errorf("a linked session in group "+ - "%x is still active", groupID) - } - linkedGroupID = &groupID linkedGroupSession = groupSess