Skip to content

Commit aa8080d

Browse files
authored
Merge pull request #980 from ellemouton/sql14Sessions6
[sql-14] sessions: atomic session creation
2 parents 88a2bf0 + 32a34d1 commit aa8080d

File tree

4 files changed

+201
-399
lines changed

4 files changed

+201
-399
lines changed

session/interface.go

+17-36
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ const (
2727
type State uint8
2828

2929
/*
30-
/---> StateExpired (terminal)
31-
StateCreated ---
32-
\---> StateRevoked (terminal)
30+
/---> StateExpired (terminal)
31+
StateReserved ---> StateCreated ---
32+
\---> StateRevoked (terminal)
3333
*/
3434

3535
const (
3636
// StateCreated is the state of a session once it has been fully
37-
// committed to the Store and is ready to be used. This is the first
38-
// state of a session.
37+
// committed to the BoltStore and is ready to be used. This is the
38+
// first state after StateReserved.
3939
StateCreated State = 0
4040

4141
// StateInUse is the state of a session that is currently being used.
@@ -52,10 +52,10 @@ const (
5252
// date.
5353
StateExpired State = 3
5454

55-
// StateReserved is a temporary initial state of a session. On start-up,
56-
// any sessions in this state should be cleaned up.
57-
//
58-
// NOTE: this isn't used yet.
55+
// StateReserved is a temporary initial state of a session. This is used
56+
// to reserve a unique ID and private key pair for a session before it
57+
// is fully created. On start-up, any sessions in this state should be
58+
// cleaned up.
5959
StateReserved State = 4
6060
)
6161

@@ -67,6 +67,9 @@ func (s State) Terminal() bool {
6767
// legalStateShifts is a map that defines the legal State transitions that a
6868
// Session can be put through.
6969
var legalStateShifts = map[State]map[State]bool{
70+
StateReserved: {
71+
StateCreated: true,
72+
},
7073
StateCreated: {
7174
StateExpired: true,
7275
StateRevoked: true,
@@ -141,7 +144,7 @@ func buildSession(id ID, localPrivKey *btcec.PrivateKey, label string, typ Type,
141144
sess := &Session{
142145
ID: id,
143146
Label: label,
144-
State: StateCreated,
147+
State: StateReserved,
145148
Type: typ,
146149
Expiry: expiry.UTC(),
147150
CreatedAt: created.UTC(),
@@ -185,23 +188,13 @@ type IDToGroupIndex interface {
185188
// retrieving Terminal Connect sessions.
186189
type Store interface {
187190
// NewSession creates a new session with the given user-defined
188-
// parameters.
189-
//
190-
// NOTE: currently this purely a constructor of the Session type and
191-
// does not make any database calls. This will be changed in a future
192-
// commit.
193-
NewSession(id ID, localPrivKey *btcec.PrivateKey, label string,
194-
typ Type, expiry time.Time, serverAddr string, devServer bool,
195-
perms []bakery.Op, caveats []macaroon.Caveat,
191+
// parameters. The session will remain in the StateReserved state until
192+
// ShiftState is called to update the state.
193+
NewSession(label string, typ Type, expiry time.Time, serverAddr string,
194+
devServer bool, perms []bakery.Op, caveats []macaroon.Caveat,
196195
featureConfig FeaturesConfig, privacy bool, linkedGroupID *ID,
197196
flags PrivacyFlags) (*Session, error)
198197

199-
// CreateSession adds a new session to the store. If a session with the
200-
// same local public key already exists an error is returned. This
201-
// can only be called with a Session with an ID that the Store has
202-
// reserved.
203-
CreateSession(*Session) error
204-
205198
// GetSession fetches the session with the given key.
206199
GetSession(key *btcec.PublicKey) (*Session, error)
207200

@@ -220,21 +213,9 @@ type Store interface {
220213
UpdateSessionRemotePubKey(localPubKey,
221214
remotePubKey *btcec.PublicKey) error
222215

223-
// GetUnusedIDAndKeyPair can be used to generate a new, unused, local
224-
// private key and session ID pair. Care must be taken to ensure that no
225-
// other thread calls this before the returned ID and key pair from this
226-
// method are either used or discarded.
227-
GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error)
228-
229216
// GetSessionByID fetches the session with the given ID.
230217
GetSessionByID(id ID) (*Session, error)
231218

232-
// CheckSessionGroupPredicate iterates over all the sessions in a group
233-
// and checks if each one passes the given predicate function. True is
234-
// returned if each session passes.
235-
CheckSessionGroupPredicate(groupID ID,
236-
fn func(s *Session) bool) (bool, error)
237-
238219
// DeleteReservedSessions deletes all sessions that are in the
239220
// StateReserved state.
240221
DeleteReservedSessions() error

session/kvdb_store.go

+53-126
Original file line numberDiff line numberDiff line change
@@ -182,41 +182,42 @@ func getSessionKey(session *Session) []byte {
182182
return session.LocalPublicKey.SerializeCompressed()
183183
}
184184

185-
// NewSession creates a new session with the given user-defined parameters.
186-
//
187-
// NOTE: currently this purely a constructor of the Session type and does not
188-
// make any database calls. This will be changed in a future commit.
189-
//
190-
// NOTE: this is part of the Store interface.
191-
func (db *BoltStore) NewSession(id ID, localPrivKey *btcec.PrivateKey,
192-
label string, typ Type, expiry time.Time, serverAddr string,
193-
devServer bool, perms []bakery.Op, caveats []macaroon.Caveat,
194-
featureConfig FeaturesConfig, privacy bool, linkedGroupID *ID,
195-
flags PrivacyFlags) (*Session, error) {
196-
197-
return buildSession(
198-
id, localPrivKey, label, typ, db.clock.Now(), expiry,
199-
serverAddr, devServer, perms, caveats, featureConfig, privacy,
200-
linkedGroupID, flags,
201-
)
202-
}
203-
204-
// CreateSession adds a new session to the store. If a session with the same
205-
// local public key already exists an error is returned.
185+
// NewSession creates and persists a new session with the given user-defined
186+
// parameters. The initial state of the session will be Reserved until
187+
// ShiftState is called with StateCreated.
206188
//
207189
// NOTE: this is part of the Store interface.
208-
func (db *BoltStore) CreateSession(session *Session) error {
209-
sessionKey := getSessionKey(session)
190+
func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
191+
serverAddr string, devServer bool, perms []bakery.Op,
192+
caveats []macaroon.Caveat, featureConfig FeaturesConfig, privacy bool,
193+
linkedGroupID *ID, flags PrivacyFlags) (*Session, error) {
210194

211-
return db.Update(func(tx *bbolt.Tx) error {
195+
var session *Session
196+
err := db.Update(func(tx *bbolt.Tx) error {
212197
sessionBucket, err := getBucket(tx, sessionBucketKey)
213198
if err != nil {
214199
return err
215200
}
216201

202+
id, localPrivKey, err := getUnusedIDAndKeyPair(sessionBucket)
203+
if err != nil {
204+
return err
205+
}
206+
207+
session, err = buildSession(
208+
id, localPrivKey, label, typ, db.clock.Now(), expiry,
209+
serverAddr, devServer, perms, caveats, featureConfig,
210+
privacy, linkedGroupID, flags,
211+
)
212+
if err != nil {
213+
return err
214+
}
215+
216+
sessionKey := getSessionKey(session)
217+
217218
if len(sessionBucket.Get(sessionKey)) != 0 {
218-
return fmt.Errorf("session with local public "+
219-
"key(%x) already exists",
219+
return fmt.Errorf("session with local public key(%x) "+
220+
"already exists",
220221
session.LocalPublicKey.SerializeCompressed())
221222
}
222223

@@ -248,9 +249,7 @@ func (db *BoltStore) CreateSession(session *Session) error {
248249
}
249250

250251
// Ensure that the session is no longer active.
251-
if sess.State == StateCreated ||
252-
sess.State == StateInUse {
253-
252+
if !sess.State.Terminal() {
254253
return fmt.Errorf("session (id=%x) "+
255254
"in group %x is still active",
256255
sess.ID, sess.GroupID)
@@ -275,6 +274,11 @@ func (db *BoltStore) CreateSession(session *Session) error {
275274

276275
return putSession(sessionBucket, session)
277276
})
277+
if err != nil {
278+
return nil, err
279+
}
280+
281+
return session, nil
278282
}
279283

280284
// UpdateSessionRemotePubKey can be used to add the given remote pub key
@@ -577,53 +581,35 @@ func (db *BoltStore) GetSessionByID(id ID) (*Session, error) {
577581
return session, nil
578582
}
579583

580-
// GetUnusedIDAndKeyPair can be used to generate a new, unused, local private
584+
// getUnusedIDAndKeyPair can be used to generate a new, unused, local private
581585
// key and session ID pair. Care must be taken to ensure that no other thread
582586
// calls this before the returned ID and key pair from this method are either
583587
// used or discarded.
584-
//
585-
// NOTE: this is part of the Store interface.
586-
func (db *BoltStore) GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error) {
587-
var (
588-
id ID
589-
privKey *btcec.PrivateKey
590-
)
591-
err := db.Update(func(tx *bbolt.Tx) error {
592-
sessionBucket, err := getBucket(tx, sessionBucketKey)
593-
if err != nil {
594-
return err
595-
}
596-
597-
idIndexBkt := sessionBucket.Bucket(idIndexKey)
598-
if idIndexBkt == nil {
599-
return ErrDBInitErr
600-
}
588+
func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey,
589+
error) {
601590

602-
// Spin until we find a key with an ID that does not collide
603-
// with any of our existing IDs.
604-
for {
605-
// Generate a new private key and ID pair.
606-
privKey, id, err = NewSessionPrivKeyAndID()
607-
if err != nil {
608-
return err
609-
}
591+
idIndexBkt := bucket.Bucket(idIndexKey)
592+
if idIndexBkt == nil {
593+
return ID{}, nil, ErrDBInitErr
594+
}
610595

611-
// Check that no such ID exits in our id-to-key index.
612-
idBkt := idIndexBkt.Bucket(id[:])
613-
if idBkt != nil {
614-
continue
615-
}
596+
// Spin until we find a key with an ID that does not collide with any of
597+
// our existing IDs.
598+
for {
599+
// Generate a new private key and ID pair.
600+
privKey, id, err := NewSessionPrivKeyAndID()
601+
if err != nil {
602+
return ID{}, nil, err
603+
}
616604

617-
break
605+
// Check that no such ID exits in our id-to-key index.
606+
idBkt := idIndexBkt.Bucket(id[:])
607+
if idBkt != nil {
608+
continue
618609
}
619610

620-
return nil
621-
})
622-
if err != nil {
623-
return id, nil, err
611+
return id, privKey, nil
624612
}
625-
626-
return id, privKey, nil
627613
}
628614

629615
// GetGroupID will return the group ID for the given session ID.
@@ -691,65 +677,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
691677
return sessionIDs, nil
692678
}
693679

694-
// CheckSessionGroupPredicate iterates over all the sessions in a group and
695-
// checks if each one passes the given predicate function. True is returned if
696-
// each session passes.
697-
//
698-
// NOTE: this is part of the Store interface.
699-
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
700-
fn func(s *Session) bool) (bool, error) {
701-
702-
var (
703-
pass bool
704-
errFailedPred = errors.New("session failed predicate")
705-
)
706-
err := db.View(func(tx *bbolt.Tx) error {
707-
sessionBkt, err := getBucket(tx, sessionBucketKey)
708-
if err != nil {
709-
return err
710-
}
711-
712-
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
713-
if err != nil {
714-
return err
715-
}
716-
717-
// Iterate over all the sessions.
718-
for _, id := range sessionIDs {
719-
key, err := getKeyForID(sessionBkt, id)
720-
if err != nil {
721-
return err
722-
}
723-
724-
v := sessionBkt.Get(key)
725-
if len(v) == 0 {
726-
return ErrSessionNotFound
727-
}
728-
729-
session, err := DeserializeSession(bytes.NewReader(v))
730-
if err != nil {
731-
return err
732-
}
733-
734-
if !fn(session) {
735-
return errFailedPred
736-
}
737-
}
738-
739-
pass = true
740-
741-
return nil
742-
})
743-
if errors.Is(err, errFailedPred) {
744-
return pass, nil
745-
}
746-
if err != nil {
747-
return pass, err
748-
}
749-
750-
return pass, nil
751-
}
752-
753680
// getSessionIDs returns all the session IDs associated with the given group ID.
754681
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
755682
var sessionIDs []ID

0 commit comments

Comments
 (0)