diff --git a/session/interface.go b/session/interface.go index 4898b882c..ad94a1f1c 100644 --- a/session/interface.go +++ b/session/interface.go @@ -27,15 +27,15 @@ const ( type State uint8 /* - /---> StateExpired (terminal) -StateCreated --- - \---> StateRevoked (terminal) + /---> StateExpired (terminal) +StateReserved ---> StateCreated --- + \---> StateRevoked (terminal) */ const ( // StateCreated is the state of a session once it has been fully - // committed to the Store and is ready to be used. This is the first - // state of a session. + // committed to the BoltStore and is ready to be used. This is the + // first state after StateReserved. StateCreated State = 0 // StateInUse is the state of a session that is currently being used. @@ -52,10 +52,10 @@ const ( // date. StateExpired State = 3 - // StateReserved is a temporary initial state of a session. On start-up, - // any sessions in this state should be cleaned up. - // - // NOTE: this isn't used yet. + // StateReserved is a temporary initial state of a session. This is used + // to reserve a unique ID and private key pair for a session before it + // is fully created. On start-up, any sessions in this state should be + // cleaned up. StateReserved State = 4 ) @@ -67,6 +67,9 @@ func (s State) Terminal() bool { // legalStateShifts is a map that defines the legal State transitions that a // Session can be put through. var legalStateShifts = map[State]map[State]bool{ + StateReserved: { + StateCreated: true, + }, StateCreated: { StateExpired: true, StateRevoked: true, @@ -141,7 +144,7 @@ func buildSession(id ID, localPrivKey *btcec.PrivateKey, label string, typ Type, sess := &Session{ ID: id, Label: label, - State: StateCreated, + State: StateReserved, Type: typ, Expiry: expiry.UTC(), CreatedAt: created.UTC(), @@ -185,23 +188,13 @@ type IDToGroupIndex interface { // retrieving Terminal Connect sessions. type Store interface { // NewSession creates a new session with the given user-defined - // parameters. - // - // NOTE: currently this purely a constructor of the Session type and - // does not make any database calls. This will be changed in a future - // commit. - NewSession(id ID, localPrivKey *btcec.PrivateKey, label string, - typ Type, expiry time.Time, serverAddr string, devServer bool, - perms []bakery.Op, caveats []macaroon.Caveat, + // parameters. The session will remain in the StateReserved state until + // ShiftState is called to update the state. + NewSession(label string, typ Type, expiry time.Time, serverAddr string, + devServer bool, perms []bakery.Op, caveats []macaroon.Caveat, featureConfig FeaturesConfig, privacy bool, linkedGroupID *ID, flags PrivacyFlags) (*Session, error) - // CreateSession adds a new session to the store. If a session with the - // same local public key already exists an error is returned. This - // can only be called with a Session with an ID that the Store has - // reserved. - CreateSession(*Session) error - // GetSession fetches the session with the given key. GetSession(key *btcec.PublicKey) (*Session, error) @@ -220,21 +213,9 @@ type Store interface { UpdateSessionRemotePubKey(localPubKey, remotePubKey *btcec.PublicKey) error - // GetUnusedIDAndKeyPair can be used to generate a new, unused, local - // private key and session ID pair. Care must be taken to ensure that no - // other thread calls this before the returned ID and key pair from this - // method are either used or discarded. - GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error) - // GetSessionByID fetches the session with the given ID. GetSessionByID(id ID) (*Session, error) - // CheckSessionGroupPredicate iterates over all the sessions in a group - // and checks if each one passes the given predicate function. True is - // returned if each session passes. - CheckSessionGroupPredicate(groupID ID, - fn func(s *Session) bool) (bool, error) - // DeleteReservedSessions deletes all sessions that are in the // StateReserved state. DeleteReservedSessions() error diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 60bc62a04..84f4dce06 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -182,41 +182,42 @@ func getSessionKey(session *Session) []byte { return session.LocalPublicKey.SerializeCompressed() } -// NewSession creates a new session with the given user-defined parameters. -// -// NOTE: currently this purely a constructor of the Session type and does not -// make any database calls. This will be changed in a future commit. -// -// NOTE: this is part of the Store interface. -func (db *BoltStore) NewSession(id ID, localPrivKey *btcec.PrivateKey, - label string, typ Type, expiry time.Time, serverAddr string, - devServer bool, perms []bakery.Op, caveats []macaroon.Caveat, - featureConfig FeaturesConfig, privacy bool, linkedGroupID *ID, - flags PrivacyFlags) (*Session, error) { - - return buildSession( - id, localPrivKey, label, typ, db.clock.Now(), expiry, - serverAddr, devServer, perms, caveats, featureConfig, privacy, - linkedGroupID, flags, - ) -} - -// CreateSession adds a new session to the store. If a session with the same -// local public key already exists an error is returned. +// NewSession creates and persists a new session with the given user-defined +// parameters. The initial state of the session will be Reserved until +// ShiftState is called with StateCreated. // // NOTE: this is part of the Store interface. -func (db *BoltStore) CreateSession(session *Session) error { - sessionKey := getSessionKey(session) +func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, + serverAddr string, devServer bool, perms []bakery.Op, + caveats []macaroon.Caveat, featureConfig FeaturesConfig, privacy bool, + linkedGroupID *ID, flags PrivacyFlags) (*Session, error) { - return db.Update(func(tx *bbolt.Tx) error { + var session *Session + err := db.Update(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) if err != nil { return err } + id, localPrivKey, err := getUnusedIDAndKeyPair(sessionBucket) + if err != nil { + return err + } + + session, err = buildSession( + id, localPrivKey, label, typ, db.clock.Now(), expiry, + serverAddr, devServer, perms, caveats, featureConfig, + privacy, linkedGroupID, flags, + ) + if err != nil { + return err + } + + sessionKey := getSessionKey(session) + if len(sessionBucket.Get(sessionKey)) != 0 { - return fmt.Errorf("session with local public "+ - "key(%x) already exists", + return fmt.Errorf("session with local public key(%x) "+ + "already exists", session.LocalPublicKey.SerializeCompressed()) } @@ -248,9 +249,7 @@ func (db *BoltStore) CreateSession(session *Session) error { } // Ensure that the session is no longer active. - if sess.State == StateCreated || - sess.State == StateInUse { - + if !sess.State.Terminal() { return fmt.Errorf("session (id=%x) "+ "in group %x is still active", sess.ID, sess.GroupID) @@ -275,6 +274,11 @@ func (db *BoltStore) CreateSession(session *Session) error { return putSession(sessionBucket, session) }) + if err != nil { + return nil, err + } + + return session, nil } // UpdateSessionRemotePubKey can be used to add the given remote pub key @@ -577,53 +581,35 @@ func (db *BoltStore) GetSessionByID(id ID) (*Session, error) { return session, nil } -// GetUnusedIDAndKeyPair can be used to generate a new, unused, local private +// getUnusedIDAndKeyPair can be used to generate a new, unused, local private // key and session ID pair. Care must be taken to ensure that no other thread // calls this before the returned ID and key pair from this method are either // used or discarded. -// -// NOTE: this is part of the Store interface. -func (db *BoltStore) GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error) { - var ( - id ID - privKey *btcec.PrivateKey - ) - err := db.Update(func(tx *bbolt.Tx) error { - sessionBucket, err := getBucket(tx, sessionBucketKey) - if err != nil { - return err - } - - idIndexBkt := sessionBucket.Bucket(idIndexKey) - if idIndexBkt == nil { - return ErrDBInitErr - } +func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey, + error) { - // Spin until we find a key with an ID that does not collide - // with any of our existing IDs. - for { - // Generate a new private key and ID pair. - privKey, id, err = NewSessionPrivKeyAndID() - if err != nil { - return err - } + idIndexBkt := bucket.Bucket(idIndexKey) + if idIndexBkt == nil { + return ID{}, nil, ErrDBInitErr + } - // Check that no such ID exits in our id-to-key index. - idBkt := idIndexBkt.Bucket(id[:]) - if idBkt != nil { - continue - } + // Spin until we find a key with an ID that does not collide with any of + // our existing IDs. + for { + // Generate a new private key and ID pair. + privKey, id, err := NewSessionPrivKeyAndID() + if err != nil { + return ID{}, nil, err + } - break + // Check that no such ID exits in our id-to-key index. + idBkt := idIndexBkt.Bucket(id[:]) + if idBkt != nil { + continue } - return nil - }) - if err != nil { - return id, nil, err + return id, privKey, nil } - - return id, privKey, nil } // GetGroupID will return the group ID for the given session ID. @@ -691,65 +677,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) { return sessionIDs, nil } -// CheckSessionGroupPredicate iterates over all the sessions in a group and -// checks if each one passes the given predicate function. True is returned if -// each session passes. -// -// NOTE: this is part of the Store interface. -func (db *BoltStore) CheckSessionGroupPredicate(groupID ID, - fn func(s *Session) bool) (bool, error) { - - var ( - pass bool - errFailedPred = errors.New("session failed predicate") - ) - err := db.View(func(tx *bbolt.Tx) error { - sessionBkt, err := getBucket(tx, sessionBucketKey) - if err != nil { - return err - } - - sessionIDs, err := getSessionIDs(sessionBkt, groupID) - if err != nil { - return err - } - - // Iterate over all the sessions. - for _, id := range sessionIDs { - key, err := getKeyForID(sessionBkt, id) - if err != nil { - return err - } - - v := sessionBkt.Get(key) - if len(v) == 0 { - return ErrSessionNotFound - } - - session, err := DeserializeSession(bytes.NewReader(v)) - if err != nil { - return err - } - - if !fn(session) { - return errFailedPred - } - } - - pass = true - - return nil - }) - if errors.Is(err, errFailedPred) { - return pass, nil - } - if err != nil { - return pass, err - } - - return pass, nil -} - // getSessionIDs returns all the session IDs associated with the given group ID. func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) { var sessionIDs []ID diff --git a/session/store_test.go b/session/store_test.go index db6368450..aa7afe20f 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -1,7 +1,6 @@ package session import ( - "strings" "testing" "time" @@ -23,38 +22,36 @@ func TestBasicSessionStore(t *testing.T) { _ = db.Close() }) - // Create a few sessions. We increment the time by one second between - // each session to ensure that the created at time is unique and hence - // that the ListSessions method returns the sessions in a deterministic - // order. - s1 := newSession(t, db, clock, "session 1") - clock.SetTime(testTime.Add(time.Second)) - s2 := newSession(t, db, clock, "session 2") - clock.SetTime(testTime.Add(2 * time.Second)) - s3 := newSession(t, db, clock, "session 3", withType(TypeAutopilot)) - clock.SetTime(testTime.Add(3 * time.Second)) - s4 := newSession(t, db, clock, "session 4") + // Reserve a session. This should succeed. + s1, err := reserveSession(db, "session 1") + require.NoError(t, err) - // Persist session 1. This should now succeed. - require.NoError(t, db.CreateSession(s1)) + // Show that the session starts in the reserved state. + s1, err = db.GetSessionByID(s1.ID) + require.NoError(t, err) + require.Equal(t, StateReserved, s1.State) - // Trying to persist session 1 again should fail due to a session with - // the given pub key already existing. - require.ErrorContains(t, db.CreateSession(s1), "already exists") + // Move session 1 to the created state. This should succeed. + err = db.ShiftState(s1.ID, StateCreated) + require.NoError(t, err) - // Change the local pub key of session 4 such that it has the same - // ID as session 1. - s4.ID = s1.ID - s4.GroupID = s1.GroupID + // Show that the session is now in the created state. + s1, err = db.GetSessionByID(s1.ID) + require.NoError(t, err) + require.Equal(t, StateCreated, s1.State) - // Now try to insert session 4. This should fail due to an entry for - // the ID already existing. - require.ErrorContains(t, db.CreateSession(s4), "a session with the "+ - "given ID already exists") + // Trying to move session 1 again should have no effect since it is + // already in the created state. + require.NoError(t, db.ShiftState(s1.ID, StateCreated)) - // Persist a few more sessions. - require.NoError(t, db.CreateSession(s2)) - require.NoError(t, db.CreateSession(s3)) + // Reserve and create a few more sessions. We increment the time by one + // second between each session to ensure that the created at time is + // unique and hence that the ListSessions method returns the sessions in + // a deterministic order. + clock.SetTime(testTime.Add(time.Second)) + s2 := createSession(t, db, "session 2") + clock.SetTime(testTime.Add(2 * time.Second)) + s3 := createSession(t, db, "session 3", withType(TypeAutopilot)) // Test the ListSessionsByType method. sessions, err := db.ListSessionsByType(TypeMacaroonAdmin) @@ -156,28 +153,26 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) require.Empty(t, sessions) - // Add a session and put it in the StateReserved state. We'll also - // link it to session 1. - s5 := newSession( - t, db, clock, "session 5", withState(StateReserved), - withLinkedGroupID(&session1.GroupID), + // Reserve a new session and link it to session 1. + s4, err := reserveSession( + db, "session 4", withLinkedGroupID(&session1.GroupID), ) - require.NoError(t, db.CreateSession(s5)) + require.NoError(t, err) sessions, err = db.ListSessionsByState(StateReserved) require.NoError(t, err) require.Equal(t, 1, len(sessions)) - assertEqualSessions(t, s5, sessions[0]) + assertEqualSessions(t, s4, sessions[0]) // Show that the group ID/session ID index has also been populated with // this session. - groupID, err := db.GetGroupID(s5.ID) + groupID, err := db.GetGroupID(s4.ID) require.NoError(t, err) require.Equal(t, s1.ID, groupID) - sessIDs, err := db.GetSessionIDs(s5.GroupID) + sessIDs, err := db.GetSessionIDs(s4.GroupID) require.NoError(t, err) - require.ElementsMatch(t, []ID{s5.ID, s1.ID}, sessIDs) + require.ElementsMatch(t, []ID{s4.ID, s1.ID}, sessIDs) // Now delete the reserved session and show that it is no longer in the // database and no longer in the group ID/session ID index. @@ -187,17 +182,19 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) require.Empty(t, sessions) - _, err = db.GetGroupID(s5.ID) + _, err = db.GetGroupID(s4.ID) require.ErrorContains(t, err, "no index entry") // Only session 1 should remain in this group. - sessIDs, err = db.GetSessionIDs(s5.GroupID) + sessIDs, err = db.GetSessionIDs(s4.GroupID) require.NoError(t, err) require.ElementsMatch(t, []ID{s1.ID}, sessIDs) } // TestLinkingSessions tests that session linking works as expected. func TestLinkingSessions(t *testing.T) { + t.Parallel() + // Set up a new DB. clock := clock.NewTestClock(testTime) db, err := NewDB(t.TempDir(), "test.db", clock) @@ -206,35 +203,39 @@ func TestLinkingSessions(t *testing.T) { _ = db.Close() }) - // Create a new session with no previous link. - s1 := newSession(t, db, clock, "session 1") - - // Create another session and link it to the first. - s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID)) + groupID, err := IDFromBytes([]byte{1, 2, 3, 4}) + require.NoError(t, err) - // Try to persist the second session and assert that it fails due to the - // linked session not existing in the DB yet. - require.ErrorContains(t, db.CreateSession(s2), "unknown linked session") + // Try to reserve a session that links to another and assert that it + // fails due to the linked session not existing in the BoltStore yet. + _, err = reserveSession( + db, "session 2", withLinkedGroupID(&groupID), + ) + require.ErrorContains(t, err, "unknown linked session") - // Now persist the first session and retry persisting the second one - // and assert that this now works. - require.NoError(t, db.CreateSession(s1)) + // Create a new session with no previous link. + s1 := createSession(t, db, "session 1") - // Persisting the second session immediately should fail due to the - // first session still being active. - require.ErrorContains(t, db.CreateSession(s2), "is still active") + // Once again try to reserve a session that links to the now existing + // session. This should fail due to the first session still being + // active. + _, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID)) + require.ErrorContains(t, err, "is still active") // Revoke the first session. require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) // Persisting the second linked session should now work. - require.NoError(t, db.CreateSession(s2)) + _, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID)) + require.NoError(t, err) } // TestIDToGroupIDIndex tests that the session-ID-to-group-ID and // group-ID-to-session-ID indexes work as expected by asserting the behaviour // of the GetGroupID and GetSessionIDs methods. func TestLinkedSessions(t *testing.T) { + t.Parallel() + // Set up a new DB. clock := clock.NewTestClock(testTime) db, err := NewDB(t.TempDir(), "test.db", clock) @@ -247,22 +248,13 @@ func TestLinkedSessions(t *testing.T) { // after are all linked to the prior one. All these sessions belong to // the same group. The group ID is equivalent to the session ID of the // first session. - s1 := newSession(t, db, clock, "session 1") - s2 := newSession( - t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID), - ) - s3 := newSession( - t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID), - ) - - // Persist the sessions. - require.NoError(t, db.CreateSession(s1)) + s1 := createSession(t, db, "session 1") require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) - require.NoError(t, db.CreateSession(s2)) + s2 := createSession(t, db, "session 2", withLinkedGroupID(&s1.GroupID)) require.NoError(t, db.ShiftState(s2.ID, StateRevoked)) - require.NoError(t, db.CreateSession(s3)) + s3 := createSession(t, db, "session 3", withLinkedGroupID(&s2.GroupID)) // Assert that the session ID to group ID index works as expected. for _, s := range []*Session{s1, s2, s3} { @@ -279,16 +271,10 @@ func TestLinkedSessions(t *testing.T) { // To ensure that different groups don't interfere with each other, // let's add another set of linked sessions not linked to the first. - s4 := newSession(t, db, clock, "session 4") - s5 := newSession(t, db, clock, "session 5", withLinkedGroupID(&s4.GroupID)) - - require.NotEqual(t, s4.GroupID, s1.GroupID) - - // Persist the sessions. - require.NoError(t, db.CreateSession(s4)) + s4 := createSession(t, db, "session 4") require.NoError(t, db.ShiftState(s4.ID, StateRevoked)) - - require.NoError(t, db.CreateSession(s5)) + s5 := createSession(t, db, "session 5", withLinkedGroupID(&s4.GroupID)) + require.NotEqual(t, s4.GroupID, s1.GroupID) // Assert that the session ID to group ID index works as expected. for _, s := range []*Session{s4, s5} { @@ -304,98 +290,6 @@ func TestLinkedSessions(t *testing.T) { require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs) } -// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate -// method correctly checks if each session in a group passes a predicate. -func TestCheckSessionGroupPredicate(t *testing.T) { - // Set up a new DB. - clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) - - // We will use the Label of the Session to test that the predicate - // function is checked correctly. - - // Add a new session to the DB. - s1 := newSession(t, db, clock, "label 1") - require.NoError(t, db.CreateSession(s1)) - - // Check that the group passes against an appropriate predicate. - ok, err := db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label 1") - }, - ) - require.NoError(t, err) - require.True(t, ok) - - // Check that the group fails against an appropriate predicate. - ok, err = db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label 2") - }, - ) - require.NoError(t, err) - require.False(t, ok) - - // Revoke the first session. - require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) - - // Add a new session to the same group as the first one. - s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID)) - require.NoError(t, db.CreateSession(s2)) - - // Check that the group passes against an appropriate predicate. - ok, err = db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label") - }, - ) - require.NoError(t, err) - require.True(t, ok) - - // Check that the group fails against an appropriate predicate. - ok, err = db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label 1") - }, - ) - require.NoError(t, err) - require.False(t, ok) - - // Add a new session that is not linked to the first one. - s3 := newSession(t, db, clock, "completely different") - require.NoError(t, db.CreateSession(s3)) - - // Ensure that the first group is unaffected. - ok, err = db.CheckSessionGroupPredicate( - s1.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label") - }, - ) - require.NoError(t, err) - require.True(t, ok) - - // And that the new session is evaluated separately. - ok, err = db.CheckSessionGroupPredicate( - s3.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "label") - }, - ) - require.NoError(t, err) - require.False(t, ok) - - ok, err = db.CheckSessionGroupPredicate( - s3.GroupID, func(s *Session) bool { - return strings.Contains(s.Label, "different") - }, - ) - require.NoError(t, err) - require.True(t, ok) -} - // TestStateShift tests that the ShiftState method works as expected. func TestStateShift(t *testing.T) { // Set up a new DB. @@ -407,8 +301,7 @@ func TestStateShift(t *testing.T) { }) // Add a new session to the DB. - s1 := newSession(t, db, clock, "label 1") - require.NoError(t, db.CreateSession(s1)) + s1 := createSession(t, db, "label 1") // Check that the session is in the StateCreated state. Also check that // the "RevokedAt" time has not yet been set. @@ -443,48 +336,62 @@ func TestStateShift(t *testing.T) { require.ErrorContains(t, err, "illegal session state transition") } +type testSessionOpts struct { + groupID *ID + sessType Type +} + +func defaultTestSessOpts() *testSessionOpts { + return &testSessionOpts{ + groupID: nil, + sessType: TypeMacaroonAdmin, + } +} + // testSessionModifier is a functional option that can be used to modify the // default test session created by newSession. -type testSessionModifier func(*Session) +type testSessionModifier func(*testSessionOpts) func withLinkedGroupID(groupID *ID) testSessionModifier { - return func(s *Session) { - s.GroupID = *groupID + return func(s *testSessionOpts) { + s.groupID = groupID } } func withType(t Type) testSessionModifier { - return func(s *Session) { - s.Type = t + return func(s *testSessionOpts) { + s.sessType = t } } -func withState(state State) testSessionModifier { - return func(s *Session) { - s.State = state +func reserveSession(db Store, label string, + mods ...testSessionModifier) (*Session, error) { + + opts := defaultTestSessOpts() + for _, mod := range mods { + mod(opts) } + + return db.NewSession(label, opts.sessType, + time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC), + "foo.bar.baz:1234", true, nil, nil, nil, true, opts.groupID, + []PrivacyFlag{ClearPubkeys}, + ) } -func newSession(t *testing.T, db Store, clock clock.Clock, label string, +func createSession(t *testing.T, db Store, label string, mods ...testSessionModifier) *Session { - id, priv, err := db.GetUnusedIDAndKeyPair() + s, err := reserveSession(db, label, mods...) require.NoError(t, err) - session, err := buildSession( - id, priv, label, TypeMacaroonAdmin, - clock.Now(), - time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC), - "foo.bar.baz:1234", true, nil, nil, nil, true, nil, - []PrivacyFlag{ClearPubkeys}, - ) + err = db.ShiftState(s.ID, StateCreated) require.NoError(t, err) - for _, mod := range mods { - mod(session) - } + s, err = db.GetSessionByID(s.ID) + require.NoError(t, err) - return session + return s } func assertEqualSessions(t *testing.T, expected, actual *Session) { diff --git a/session_rpcserver.go b/session_rpcserver.go index e543b8bb2..778b593cc 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -46,11 +46,6 @@ type sessionRpcServer struct { cfg *sessionRpcServerConfig sessionServer *session.Server - // sessRegMu is a mutex that should be held between acquiring an unused - // session ID and key pair from the session store and persisting that - // new session. - sessRegMu sync.Mutex - quit chan struct{} wg sync.WaitGroup stopOnce sync.Once @@ -313,16 +308,8 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, } } - s.sessRegMu.Lock() - defer s.sessRegMu.Unlock() - - id, localPrivKey, err := s.cfg.db.GetUnusedIDAndKeyPair() - if err != nil { - return nil, err - } - sess, err := s.cfg.db.NewSession( - id, localPrivKey, req.Label, typ, expiry, req.MailboxServerAddr, + req.Label, typ, expiry, req.MailboxServerAddr, req.DevServer, uniquePermissions, caveats, nil, false, nil, session.PrivacyFlags{}, ) @@ -330,14 +317,23 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, return nil, fmt.Errorf("error creating new session: %v", err) } - if err := s.cfg.db.CreateSession(sess); err != nil { - return nil, fmt.Errorf("error storing session: %v", err) + err = s.cfg.db.ShiftState(sess.ID, session.StateCreated) + if err != nil { + return nil, fmt.Errorf("error shifting session state to "+ + "Created: %v", err) } if err := s.resumeSession(ctx, sess); err != nil { return nil, fmt.Errorf("error starting session: %v", err) } + // Re-fetch the session to get the latest state of it before marshaling + // it. + sess, err = s.cfg.db.GetSessionByID(sess.ID) + if err != nil { + return nil, fmt.Errorf("error fetching session: %v", err) + } + rpcSession, err := s.marshalRPCSession(sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) @@ -878,23 +874,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, "group %x", groupSess.ID, groupSess.GroupID) } - // Now we need to check that all the sessions in the group are - // no longer active. - ok, err := s.cfg.db.CheckSessionGroupPredicate( - groupID, func(s *session.Session) bool { - return s.State == session.StateRevoked || - s.State == session.StateExpired - }, - ) - if err != nil { - return nil, err - } - - if !ok { - return nil, fmt.Errorf("a linked session in group "+ - "%x is still active", groupID) - } - linkedGroupID = &groupID linkedGroupSession = groupSess @@ -1141,16 +1120,8 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, caveats = append(caveats, firewall.MetaPrivacyCaveat) } - s.sessRegMu.Lock() - defer s.sessRegMu.Unlock() - - id, localPrivKey, err := s.cfg.db.GetUnusedIDAndKeyPair() - if err != nil { - return nil, err - } - sess, err := s.cfg.db.NewSession( - id, localPrivKey, req.Label, session.TypeAutopilot, expiry, + req.Label, session.TypeAutopilot, expiry, req.MailboxServerAddr, req.DevServer, perms, caveats, clientConfig, privacy, linkedGroupID, privacyFlags, ) @@ -1233,17 +1204,33 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, "autopilot server: %v", err) } - // We only persist this session if we successfully retrieved the - // autopilot's static key. + err = s.cfg.db.UpdateSessionRemotePubKey(sess.LocalPublicKey, remoteKey) + if err != nil { + return nil, fmt.Errorf("error setting remote pubkey: %v", err) + } + + // Update our in-memory session with the remote key. sess.RemotePublicKey = remoteKey - if err := s.cfg.db.CreateSession(sess); err != nil { - return nil, fmt.Errorf("error storing session: %v", err) + + // We only activate the session if the Autopilot server registration + // was successful. + err = s.cfg.db.ShiftState(sess.ID, session.StateCreated) + if err != nil { + return nil, fmt.Errorf("error shifting session state to "+ + "Created: %v", err) } if err := s.resumeSession(ctx, sess); err != nil { return nil, fmt.Errorf("error starting session: %v", err) } + // Re-fetch the session to get the latest state of it before marshaling + // it. + sess, err = s.cfg.db.GetSessionByID(sess.ID) + if err != nil { + return nil, fmt.Errorf("error fetching session: %v", err) + } + rpcSession, err := s.marshalRPCSession(sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err)