Skip to content

Commit bc53495

Browse files
committed
session: remove Session Group Predicate method
This was used to check that all linked sessions are no longer active before attempting to register an autopilot session. But this is no longer needed since this is done within NewSession.
1 parent ab75080 commit bc53495

File tree

4 files changed

+2
-175
lines changed

4 files changed

+2
-175
lines changed

session/interface.go

-6
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,6 @@ type Store interface {
218218
// GetSessionByID fetches the session with the given ID.
219219
GetSessionByID(id ID) (*Session, error)
220220

221-
// CheckSessionGroupPredicate iterates over all the sessions in a group
222-
// and checks if each one passes the given predicate function. True is
223-
// returned if each session passes.
224-
CheckSessionGroupPredicate(groupID ID,
225-
fn func(s *Session) bool) (bool, error)
226-
227221
// DeleteReservedSessions deletes all sessions that are in the
228222
// StateReserved state.
229223
DeleteReservedSessions() error

session/kvdb_store.go

+2-60
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
250250

251251
// Ensure that the session is no longer active.
252252
if sess.State == StateCreated ||
253-
sess.State == StateInUse {
253+
sess.State == StateInUse ||
254+
sess.State == StateReserved {
254255

255256
return fmt.Errorf("session (id=%x) "+
256257
"in group %x is still active",
@@ -679,65 +680,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
679680
return sessionIDs, nil
680681
}
681682

682-
// CheckSessionGroupPredicate iterates over all the sessions in a group and
683-
// checks if each one passes the given predicate function. True is returned if
684-
// each session passes.
685-
//
686-
// NOTE: this is part of the Store interface.
687-
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
688-
fn func(s *Session) bool) (bool, error) {
689-
690-
var (
691-
pass bool
692-
errFailedPred = errors.New("session failed predicate")
693-
)
694-
err := db.View(func(tx *bbolt.Tx) error {
695-
sessionBkt, err := getBucket(tx, sessionBucketKey)
696-
if err != nil {
697-
return err
698-
}
699-
700-
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
701-
if err != nil {
702-
return err
703-
}
704-
705-
// Iterate over all the sessions.
706-
for _, id := range sessionIDs {
707-
key, err := getKeyForID(sessionBkt, id)
708-
if err != nil {
709-
return err
710-
}
711-
712-
v := sessionBkt.Get(key)
713-
if len(v) == 0 {
714-
return ErrSessionNotFound
715-
}
716-
717-
session, err := DeserializeSession(bytes.NewReader(v))
718-
if err != nil {
719-
return err
720-
}
721-
722-
if !fn(session) {
723-
return errFailedPred
724-
}
725-
}
726-
727-
pass = true
728-
729-
return nil
730-
})
731-
if errors.Is(err, errFailedPred) {
732-
return pass, nil
733-
}
734-
if err != nil {
735-
return pass, err
736-
}
737-
738-
return pass, nil
739-
}
740-
741683
// getSessionIDs returns all the session IDs associated with the given group ID.
742684
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
743685
var sessionIDs []ID

session/store_test.go

-92
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package session
22

33
import (
4-
"strings"
54
"testing"
65
"time"
76

@@ -291,97 +290,6 @@ func TestLinkedSessions(t *testing.T) {
291290
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
292291
}
293292

294-
// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
295-
// method correctly checks if each session in a group passes a predicate.
296-
func TestCheckSessionGroupPredicate(t *testing.T) {
297-
t.Parallel()
298-
299-
// Set up a new DB.
300-
clock := clock.NewTestClock(testTime)
301-
db, err := NewDB(t.TempDir(), "test.db", clock)
302-
require.NoError(t, err)
303-
t.Cleanup(func() {
304-
_ = db.Close()
305-
})
306-
307-
// We will use the Label of the Session to test that the predicate
308-
// function is checked correctly.
309-
310-
// Add a new session to the DB.
311-
s1 := createSession(t, db, "label 1")
312-
313-
// Check that the group passes against an appropriate predicate.
314-
ok, err := db.CheckSessionGroupPredicate(
315-
s1.GroupID, func(s *Session) bool {
316-
return strings.Contains(s.Label, "label 1")
317-
},
318-
)
319-
require.NoError(t, err)
320-
require.True(t, ok)
321-
322-
// Check that the group fails against an appropriate predicate.
323-
ok, err = db.CheckSessionGroupPredicate(
324-
s1.GroupID, func(s *Session) bool {
325-
return strings.Contains(s.Label, "label 2")
326-
},
327-
)
328-
require.NoError(t, err)
329-
require.False(t, ok)
330-
331-
// Revoke the first session.
332-
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
333-
334-
// Add a new session to the same group as the first one.
335-
_ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID))
336-
337-
// Check that the group passes against an appropriate predicate.
338-
ok, err = db.CheckSessionGroupPredicate(
339-
s1.GroupID, func(s *Session) bool {
340-
return strings.Contains(s.Label, "label")
341-
},
342-
)
343-
require.NoError(t, err)
344-
require.True(t, ok)
345-
346-
// Check that the group fails against an appropriate predicate.
347-
ok, err = db.CheckSessionGroupPredicate(
348-
s1.GroupID, func(s *Session) bool {
349-
return strings.Contains(s.Label, "label 1")
350-
},
351-
)
352-
require.NoError(t, err)
353-
require.False(t, ok)
354-
355-
// Add a new session that is not linked to the first one.
356-
s3 := createSession(t, db, "completely different")
357-
358-
// Ensure that the first group is unaffected.
359-
ok, err = db.CheckSessionGroupPredicate(
360-
s1.GroupID, func(s *Session) bool {
361-
return strings.Contains(s.Label, "label")
362-
},
363-
)
364-
require.NoError(t, err)
365-
require.True(t, ok)
366-
367-
// And that the new session is evaluated separately.
368-
ok, err = db.CheckSessionGroupPredicate(
369-
s3.GroupID, func(s *Session) bool {
370-
return strings.Contains(s.Label, "label")
371-
},
372-
)
373-
require.NoError(t, err)
374-
require.False(t, ok)
375-
376-
ok, err = db.CheckSessionGroupPredicate(
377-
s3.GroupID, func(s *Session) bool {
378-
return strings.Contains(s.Label, "different")
379-
},
380-
)
381-
require.NoError(t, err)
382-
require.True(t, ok)
383-
}
384-
385293
// TestStateShift tests that the ShiftState method works as expected.
386294
func TestStateShift(t *testing.T) {
387295
// Set up a new DB.

session_rpcserver.go

-17
Original file line numberDiff line numberDiff line change
@@ -874,23 +874,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
874874
"group %x", groupSess.ID, groupSess.GroupID)
875875
}
876876

877-
// Now we need to check that all the sessions in the group are
878-
// no longer active.
879-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
880-
groupID, func(s *session.Session) bool {
881-
return s.State == session.StateRevoked ||
882-
s.State == session.StateExpired
883-
},
884-
)
885-
if err != nil {
886-
return nil, err
887-
}
888-
889-
if !ok {
890-
return nil, fmt.Errorf("a linked session in group "+
891-
"%x is still active", groupID)
892-
}
893-
894877
linkedGroupID = &groupID
895878
linkedGroupSession = groupSess
896879

0 commit comments

Comments
 (0)