Skip to content

Commit 374c93f

Browse files
committed
session: remove Session Group Predicate method
This was used to check that all linked sessions are no longer active before attempting to register an autopilot session. But this is no longer needed since this is done within NewSession.
1 parent f9ece32 commit 374c93f

File tree

4 files changed

+2
-175
lines changed

4 files changed

+2
-175
lines changed

session/interface.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,6 @@ type Store interface {
222222
// GetSessionByID fetches the session with the given ID.
223223
GetSessionByID(id ID) (*Session, error)
224224

225-
// CheckSessionGroupPredicate iterates over all the sessions in a group
226-
// and checks if each one passes the given predicate function. True is
227-
// returned if each session passes.
228-
CheckSessionGroupPredicate(groupID ID,
229-
fn func(s *Session) bool) (bool, error)
230-
231225
// DeleteReservedSessions deletes all sessions that are in the
232226
// StateReserved state.
233227
DeleteReservedSessions() error

session/kvdb_store.go

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
250250

251251
// Ensure that the session is no longer active.
252252
if sess.State == StateCreated ||
253-
sess.State == StateInUse {
253+
sess.State == StateInUse ||
254+
sess.State == StateReserved {
254255

255256
return fmt.Errorf("session (id=%x) "+
256257
"in group %x is still active",
@@ -721,65 +722,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
721722
return sessionIDs, nil
722723
}
723724

724-
// CheckSessionGroupPredicate iterates over all the sessions in a group and
725-
// checks if each one passes the given predicate function. True is returned if
726-
// each session passes.
727-
//
728-
// NOTE: this is part of the Store interface.
729-
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
730-
fn func(s *Session) bool) (bool, error) {
731-
732-
var (
733-
pass bool
734-
errFailedPred = errors.New("session failed predicate")
735-
)
736-
err := db.View(func(tx *bbolt.Tx) error {
737-
sessionBkt, err := getBucket(tx, sessionBucketKey)
738-
if err != nil {
739-
return err
740-
}
741-
742-
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
743-
if err != nil {
744-
return err
745-
}
746-
747-
// Iterate over all the sessions.
748-
for _, id := range sessionIDs {
749-
key, err := getKeyForID(sessionBkt, id)
750-
if err != nil {
751-
return err
752-
}
753-
754-
v := sessionBkt.Get(key)
755-
if len(v) == 0 {
756-
return ErrSessionNotFound
757-
}
758-
759-
session, err := DeserializeSession(bytes.NewReader(v))
760-
if err != nil {
761-
return err
762-
}
763-
764-
if !fn(session) {
765-
return errFailedPred
766-
}
767-
}
768-
769-
pass = true
770-
771-
return nil
772-
})
773-
if errors.Is(err, errFailedPred) {
774-
return pass, nil
775-
}
776-
if err != nil {
777-
return pass, err
778-
}
779-
780-
return pass, nil
781-
}
782-
783725
// getSessionIDs returns all the session IDs associated with the given group ID.
784726
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
785727
var sessionIDs []ID

session/store_test.go

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package session
22

33
import (
4-
"strings"
54
"testing"
65
"time"
76

@@ -292,97 +291,6 @@ func TestLinkedSessions(t *testing.T) {
292291
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
293292
}
294293

295-
// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
296-
// method correctly checks if each session in a group passes a predicate.
297-
func TestCheckSessionGroupPredicate(t *testing.T) {
298-
t.Parallel()
299-
300-
// Set up a new DB.
301-
clock := clock.NewTestClock(testTime)
302-
db, err := NewDB(t.TempDir(), "test.db", clock)
303-
require.NoError(t, err)
304-
t.Cleanup(func() {
305-
_ = db.Close()
306-
})
307-
308-
// We will use the Label of the Session to test that the predicate
309-
// function is checked correctly.
310-
311-
// Add a new session to the DB.
312-
s1 := createSession(t, db, "label 1")
313-
314-
// Check that the group passes against an appropriate predicate.
315-
ok, err := db.CheckSessionGroupPredicate(
316-
s1.GroupID, func(s *Session) bool {
317-
return strings.Contains(s.Label, "label 1")
318-
},
319-
)
320-
require.NoError(t, err)
321-
require.True(t, ok)
322-
323-
// Check that the group fails against an appropriate predicate.
324-
ok, err = db.CheckSessionGroupPredicate(
325-
s1.GroupID, func(s *Session) bool {
326-
return strings.Contains(s.Label, "label 2")
327-
},
328-
)
329-
require.NoError(t, err)
330-
require.False(t, ok)
331-
332-
// Revoke the first session.
333-
require.NoError(t, db.ShiftState(s1.LocalPublicKey, StateRevoked))
334-
335-
// Add a new session to the same group as the first one.
336-
_ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID))
337-
338-
// Check that the group passes against an appropriate predicate.
339-
ok, err = db.CheckSessionGroupPredicate(
340-
s1.GroupID, func(s *Session) bool {
341-
return strings.Contains(s.Label, "label")
342-
},
343-
)
344-
require.NoError(t, err)
345-
require.True(t, ok)
346-
347-
// Check that the group fails against an appropriate predicate.
348-
ok, err = db.CheckSessionGroupPredicate(
349-
s1.GroupID, func(s *Session) bool {
350-
return strings.Contains(s.Label, "label 1")
351-
},
352-
)
353-
require.NoError(t, err)
354-
require.False(t, ok)
355-
356-
// Add a new session that is not linked to the first one.
357-
s3 := createSession(t, db, "completely different")
358-
359-
// Ensure that the first group is unaffected.
360-
ok, err = db.CheckSessionGroupPredicate(
361-
s1.GroupID, func(s *Session) bool {
362-
return strings.Contains(s.Label, "label")
363-
},
364-
)
365-
require.NoError(t, err)
366-
require.True(t, ok)
367-
368-
// And that the new session is evaluated separately.
369-
ok, err = db.CheckSessionGroupPredicate(
370-
s3.GroupID, func(s *Session) bool {
371-
return strings.Contains(s.Label, "label")
372-
},
373-
)
374-
require.NoError(t, err)
375-
require.False(t, ok)
376-
377-
ok, err = db.CheckSessionGroupPredicate(
378-
s3.GroupID, func(s *Session) bool {
379-
return strings.Contains(s.Label, "different")
380-
},
381-
)
382-
require.NoError(t, err)
383-
require.True(t, ok)
384-
}
385-
386294
// TestStateShift tests that the ShiftState method works as expected.
387295
func TestStateShift(t *testing.T) {
388296
// Set up a new DB.

session_rpcserver.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -859,23 +859,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
859859
"group %x", groupSess.ID, groupSess.GroupID)
860860
}
861861

862-
// Now we need to check that all the sessions in the group are
863-
// no longer active.
864-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
865-
groupID, func(s *session.Session) bool {
866-
return s.State == session.StateRevoked ||
867-
s.State == session.StateExpired
868-
},
869-
)
870-
if err != nil {
871-
return nil, err
872-
}
873-
874-
if !ok {
875-
return nil, fmt.Errorf("a linked session in group "+
876-
"%x is still active", groupID)
877-
}
878-
879862
linkedGroupID = &groupID
880863
linkedGroupSession = groupSess
881864

0 commit comments

Comments
 (0)