Skip to content

Commit 72d32b0

Browse files
committed
session: remove Session Group Predicate method
This was used to check that all linked sessions are no longer active before attempting to register an autopilot session. But this is no longer needed since this is done within NewSession.
1 parent 1980990 commit 72d32b0

File tree

4 files changed

+2
-175
lines changed

4 files changed

+2
-175
lines changed

session/interface.go

-6
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,6 @@ type Store interface {
203203
// GetSessionByID fetches the session with the given ID.
204204
GetSessionByID(id ID) (*Session, error)
205205

206-
// CheckSessionGroupPredicate iterates over all the sessions in a group
207-
// and checks if each one passes the given predicate function. True is
208-
// returned if each session passes.
209-
CheckSessionGroupPredicate(groupID ID,
210-
fn func(s *Session) bool) (bool, error)
211-
212206
// DeleteReservedSessions deletes all sessions that are in the
213207
// StateReserved state.
214208
DeleteReservedSessions() error

session/kvdb_store.go

+2-60
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
250250

251251
// Ensure that the session is no longer active.
252252
if sess.State == StateCreated ||
253-
sess.State == StateInUse {
253+
sess.State == StateInUse ||
254+
sess.State == StateReserved {
254255

255256
return fmt.Errorf("session (id=%x) "+
256257
"in group %x is still active",
@@ -702,65 +703,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
702703
return sessionIDs, nil
703704
}
704705

705-
// CheckSessionGroupPredicate iterates over all the sessions in a group and
706-
// checks if each one passes the given predicate function. True is returned if
707-
// each session passes.
708-
//
709-
// NOTE: this is part of the Store interface.
710-
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
711-
fn func(s *Session) bool) (bool, error) {
712-
713-
var (
714-
pass bool
715-
errFailedPred = errors.New("session failed predicate")
716-
)
717-
err := db.View(func(tx *bbolt.Tx) error {
718-
sessionBkt, err := getBucket(tx, sessionBucketKey)
719-
if err != nil {
720-
return err
721-
}
722-
723-
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
724-
if err != nil {
725-
return err
726-
}
727-
728-
// Iterate over all the sessions.
729-
for _, id := range sessionIDs {
730-
key, err := getKeyForID(sessionBkt, id)
731-
if err != nil {
732-
return err
733-
}
734-
735-
v := sessionBkt.Get(key)
736-
if len(v) == 0 {
737-
return ErrSessionNotFound
738-
}
739-
740-
session, err := DeserializeSession(bytes.NewReader(v))
741-
if err != nil {
742-
return err
743-
}
744-
745-
if !fn(session) {
746-
return errFailedPred
747-
}
748-
}
749-
750-
pass = true
751-
752-
return nil
753-
})
754-
if errors.Is(err, errFailedPred) {
755-
return pass, nil
756-
}
757-
if err != nil {
758-
return pass, err
759-
}
760-
761-
return pass, nil
762-
}
763-
764706
// getSessionIDs returns all the session IDs associated with the given group ID.
765707
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
766708
var sessionIDs []ID

session/store_test.go

-92
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package session
22

33
import (
4-
"strings"
54
"testing"
65
"time"
76

@@ -288,97 +287,6 @@ func TestLinkedSessions(t *testing.T) {
288287
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
289288
}
290289

291-
// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
292-
// method correctly checks if each session in a group passes a predicate.
293-
func TestCheckSessionGroupPredicate(t *testing.T) {
294-
t.Parallel()
295-
296-
// Set up a new DB.
297-
clock := clock.NewTestClock(testTime)
298-
db, err := NewDB(t.TempDir(), "test.db", clock)
299-
require.NoError(t, err)
300-
t.Cleanup(func() {
301-
_ = db.Close()
302-
})
303-
304-
// We will use the Label of the Session to test that the predicate
305-
// function is checked correctly.
306-
307-
// Add a new session to the DB.
308-
s1 := createSession(t, db, "label 1")
309-
310-
// Check that the group passes against an appropriate predicate.
311-
ok, err := db.CheckSessionGroupPredicate(
312-
s1.GroupID, func(s *Session) bool {
313-
return strings.Contains(s.Label, "label 1")
314-
},
315-
)
316-
require.NoError(t, err)
317-
require.True(t, ok)
318-
319-
// Check that the group fails against an appropriate predicate.
320-
ok, err = db.CheckSessionGroupPredicate(
321-
s1.GroupID, func(s *Session) bool {
322-
return strings.Contains(s.Label, "label 2")
323-
},
324-
)
325-
require.NoError(t, err)
326-
require.False(t, ok)
327-
328-
// Revoke the first session.
329-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
330-
331-
// Add a new session to the same group as the first one.
332-
_ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID))
333-
334-
// Check that the group passes against an appropriate predicate.
335-
ok, err = db.CheckSessionGroupPredicate(
336-
s1.GroupID, func(s *Session) bool {
337-
return strings.Contains(s.Label, "label")
338-
},
339-
)
340-
require.NoError(t, err)
341-
require.True(t, ok)
342-
343-
// Check that the group fails against an appropriate predicate.
344-
ok, err = db.CheckSessionGroupPredicate(
345-
s1.GroupID, func(s *Session) bool {
346-
return strings.Contains(s.Label, "label 1")
347-
},
348-
)
349-
require.NoError(t, err)
350-
require.False(t, ok)
351-
352-
// Add a new session that is not linked to the first one.
353-
s3 := createSession(t, db, "completely different")
354-
355-
// Ensure that the first group is unaffected.
356-
ok, err = db.CheckSessionGroupPredicate(
357-
s1.GroupID, func(s *Session) bool {
358-
return strings.Contains(s.Label, "label")
359-
},
360-
)
361-
require.NoError(t, err)
362-
require.True(t, ok)
363-
364-
// And that the new session is evaluated separately.
365-
ok, err = db.CheckSessionGroupPredicate(
366-
s3.GroupID, func(s *Session) bool {
367-
return strings.Contains(s.Label, "label")
368-
},
369-
)
370-
require.NoError(t, err)
371-
require.False(t, ok)
372-
373-
ok, err = db.CheckSessionGroupPredicate(
374-
s3.GroupID, func(s *Session) bool {
375-
return strings.Contains(s.Label, "different")
376-
},
377-
)
378-
require.NoError(t, err)
379-
require.True(t, ok)
380-
}
381-
382290
type testSessionOpts struct {
383291
groupID *ID
384292
sessType Type

session_rpcserver.go

-17
Original file line numberDiff line numberDiff line change
@@ -857,23 +857,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
857857
"group %x", groupSess.ID, groupSess.GroupID)
858858
}
859859

860-
// Now we need to check that all the sessions in the group are
861-
// no longer active.
862-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
863-
groupID, func(s *session.Session) bool {
864-
return s.State == session.StateRevoked ||
865-
s.State == session.StateExpired
866-
},
867-
)
868-
if err != nil {
869-
return nil, err
870-
}
871-
872-
if !ok {
873-
return nil, fmt.Errorf("a linked session in group "+
874-
"%x is still active", groupID)
875-
}
876-
877860
linkedGroupID = &groupID
878861
linkedGroupSession = groupSess
879862

0 commit comments

Comments
 (0)