Skip to content

[sql-14] sessions: atomic session creation #980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 17 additions & 36 deletions session/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ const (
type State uint8

/*
/---> StateExpired (terminal)
StateCreated ---
\---> StateRevoked (terminal)
/---> StateExpired (terminal)
StateReserved ---> StateCreated ---
\---> StateRevoked (terminal)
*/

const (
// StateCreated is the state of a session once it has been fully
// committed to the Store and is ready to be used. This is the first
// state of a session.
// committed to the BoltStore and is ready to be used. This is the
// first state after StateReserved.
StateCreated State = 0

// StateInUse is the state of a session that is currently being used.
Expand All @@ -52,10 +52,10 @@ const (
// date.
StateExpired State = 3

// StateReserved is a temporary initial state of a session. On start-up,
// any sessions in this state should be cleaned up.
//
// NOTE: this isn't used yet.
// StateReserved is a temporary initial state of a session. This is used
// to reserve a unique ID and private key pair for a session before it
// is fully created. On start-up, any sessions in this state should be
// cleaned up.
StateReserved State = 4
)

Expand All @@ -67,6 +67,9 @@ func (s State) Terminal() bool {
// legalStateShifts is a map that defines the legal State transitions that a
// Session can be put through.
var legalStateShifts = map[State]map[State]bool{
StateReserved: {
StateCreated: true,
},
StateCreated: {
StateExpired: true,
StateRevoked: true,
Expand Down Expand Up @@ -141,7 +144,7 @@ func buildSession(id ID, localPrivKey *btcec.PrivateKey, label string, typ Type,
sess := &Session{
ID: id,
Label: label,
State: StateCreated,
State: StateReserved,
Type: typ,
Expiry: expiry.UTC(),
CreatedAt: created.UTC(),
Expand Down Expand Up @@ -185,23 +188,13 @@ type IDToGroupIndex interface {
// retrieving Terminal Connect sessions.
type Store interface {
// NewSession creates a new session with the given user-defined
// parameters.
//
// NOTE: currently this purely a constructor of the Session type and
// does not make any database calls. This will be changed in a future
// commit.
NewSession(id ID, localPrivKey *btcec.PrivateKey, label string,
typ Type, expiry time.Time, serverAddr string, devServer bool,
perms []bakery.Op, caveats []macaroon.Caveat,
// parameters. The session will remain in the StateReserved state until
// ShiftState is called to update the state.
NewSession(label string, typ Type, expiry time.Time, serverAddr string,
devServer bool, perms []bakery.Op, caveats []macaroon.Caveat,
featureConfig FeaturesConfig, privacy bool, linkedGroupID *ID,
flags PrivacyFlags) (*Session, error)

// CreateSession adds a new session to the store. If a session with the
// same local public key already exists an error is returned. This
// can only be called with a Session with an ID that the Store has
// reserved.
CreateSession(*Session) error

// GetSession fetches the session with the given key.
GetSession(key *btcec.PublicKey) (*Session, error)

Expand All @@ -220,21 +213,9 @@ type Store interface {
UpdateSessionRemotePubKey(localPubKey,
remotePubKey *btcec.PublicKey) error

// GetUnusedIDAndKeyPair can be used to generate a new, unused, local
// private key and session ID pair. Care must be taken to ensure that no
// other thread calls this before the returned ID and key pair from this
// method are either used or discarded.
GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error)

// GetSessionByID fetches the session with the given ID.
GetSessionByID(id ID) (*Session, error)

// CheckSessionGroupPredicate iterates over all the sessions in a group
// and checks if each one passes the given predicate function. True is
// returned if each session passes.
CheckSessionGroupPredicate(groupID ID,
fn func(s *Session) bool) (bool, error)

// DeleteReservedSessions deletes all sessions that are in the
// StateReserved state.
DeleteReservedSessions() error
Expand Down
179 changes: 53 additions & 126 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,41 +182,42 @@ func getSessionKey(session *Session) []byte {
return session.LocalPublicKey.SerializeCompressed()
}

// NewSession creates a new session with the given user-defined parameters.
//
// NOTE: currently this purely a constructor of the Session type and does not
// make any database calls. This will be changed in a future commit.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) NewSession(id ID, localPrivKey *btcec.PrivateKey,
label string, typ Type, expiry time.Time, serverAddr string,
devServer bool, perms []bakery.Op, caveats []macaroon.Caveat,
featureConfig FeaturesConfig, privacy bool, linkedGroupID *ID,
flags PrivacyFlags) (*Session, error) {

return buildSession(
id, localPrivKey, label, typ, db.clock.Now(), expiry,
serverAddr, devServer, perms, caveats, featureConfig, privacy,
linkedGroupID, flags,
)
}

// CreateSession adds a new session to the store. If a session with the same
// local public key already exists an error is returned.
// NewSession creates and persists a new session with the given user-defined
// parameters. The initial state of the session will be Reserved until
// ShiftState is called with StateCreated.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) CreateSession(session *Session) error {
sessionKey := getSessionKey(session)
func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
serverAddr string, devServer bool, perms []bakery.Op,
caveats []macaroon.Caveat, featureConfig FeaturesConfig, privacy bool,
linkedGroupID *ID, flags PrivacyFlags) (*Session, error) {

return db.Update(func(tx *bbolt.Tx) error {
var session *Session
err := db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
return err
}

id, localPrivKey, err := getUnusedIDAndKeyPair(sessionBucket)
if err != nil {
return err
}

session, err = buildSession(
id, localPrivKey, label, typ, db.clock.Now(), expiry,
serverAddr, devServer, perms, caveats, featureConfig,
privacy, linkedGroupID, flags,
)
if err != nil {
return err
}

sessionKey := getSessionKey(session)

if len(sessionBucket.Get(sessionKey)) != 0 {
return fmt.Errorf("session with local public "+
"key(%x) already exists",
return fmt.Errorf("session with local public key(%x) "+
"already exists",
session.LocalPublicKey.SerializeCompressed())
}

Expand Down Expand Up @@ -248,9 +249,7 @@ func (db *BoltStore) CreateSession(session *Session) error {
}

// Ensure that the session is no longer active.
if sess.State == StateCreated ||
sess.State == StateInUse {

if !sess.State.Terminal() {
return fmt.Errorf("session (id=%x) "+
"in group %x is still active",
sess.ID, sess.GroupID)
Expand All @@ -275,6 +274,11 @@ func (db *BoltStore) CreateSession(session *Session) error {

return putSession(sessionBucket, session)
})
if err != nil {
return nil, err
}

return session, nil
}

// UpdateSessionRemotePubKey can be used to add the given remote pub key
Expand Down Expand Up @@ -577,53 +581,35 @@ func (db *BoltStore) GetSessionByID(id ID) (*Session, error) {
return session, nil
}

// GetUnusedIDAndKeyPair can be used to generate a new, unused, local private
// getUnusedIDAndKeyPair can be used to generate a new, unused, local private
// key and session ID pair. Care must be taken to ensure that no other thread
// calls this before the returned ID and key pair from this method are either
// used or discarded.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error) {
var (
id ID
privKey *btcec.PrivateKey
)
err := db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
return err
}

idIndexBkt := sessionBucket.Bucket(idIndexKey)
if idIndexBkt == nil {
return ErrDBInitErr
}
func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey,
error) {

// Spin until we find a key with an ID that does not collide
// with any of our existing IDs.
for {
// Generate a new private key and ID pair.
privKey, id, err = NewSessionPrivKeyAndID()
if err != nil {
return err
}
idIndexBkt := bucket.Bucket(idIndexKey)
if idIndexBkt == nil {
return ID{}, nil, ErrDBInitErr
}

// Check that no such ID exits in our id-to-key index.
idBkt := idIndexBkt.Bucket(id[:])
if idBkt != nil {
continue
}
// Spin until we find a key with an ID that does not collide with any of
// our existing IDs.
for {
// Generate a new private key and ID pair.
privKey, id, err := NewSessionPrivKeyAndID()
if err != nil {
return ID{}, nil, err
}

break
// Check that no such ID exits in our id-to-key index.
idBkt := idIndexBkt.Bucket(id[:])
if idBkt != nil {
continue
}

return nil
})
if err != nil {
return id, nil, err
return id, privKey, nil
}

return id, privKey, nil
}

// GetGroupID will return the group ID for the given session ID.
Expand Down Expand Up @@ -691,65 +677,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
return sessionIDs, nil
}

// CheckSessionGroupPredicate iterates over all the sessions in a group and
// checks if each one passes the given predicate function. True is returned if
// each session passes.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
fn func(s *Session) bool) (bool, error) {

var (
pass bool
errFailedPred = errors.New("session failed predicate")
)
err := db.View(func(tx *bbolt.Tx) error {
sessionBkt, err := getBucket(tx, sessionBucketKey)
if err != nil {
return err
}

sessionIDs, err := getSessionIDs(sessionBkt, groupID)
if err != nil {
return err
}

// Iterate over all the sessions.
for _, id := range sessionIDs {
key, err := getKeyForID(sessionBkt, id)
if err != nil {
return err
}

v := sessionBkt.Get(key)
if len(v) == 0 {
return ErrSessionNotFound
}

session, err := DeserializeSession(bytes.NewReader(v))
if err != nil {
return err
}

if !fn(session) {
return errFailedPred
}
}

pass = true

return nil
})
if errors.Is(err, errFailedPred) {
return pass, nil
}
if err != nil {
return pass, err
}

return pass, nil
}

// getSessionIDs returns all the session IDs associated with the given group ID.
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
var sessionIDs []ID
Expand Down
Loading
Loading