Skip to content

[sql-17] sessions: test preparation #988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion itest/litd_firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,9 @@ func testSessionLinking(net *NetworkHarness, t *harnessTest) {
LinkedGroupId: sessResp.Session.GroupId,
},
)
require.ErrorContains(t.t, err, "is still active")
require.ErrorContains(
t.t, err, session.ErrSessionsInGroupStillActive.Error(),
)

// Revoke the previous one and repeat.
_, err = litAutopilotClient.RevokeAutopilotSession(
Expand Down
12 changes: 12 additions & 0 deletions session/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,16 @@ var (
// ErrSessionNotFound is an error returned when we attempt to retrieve
// information about a session but it is not found.
ErrSessionNotFound = errors.New("session not found")

// ErrUnknownGroup is returned when an attempt is made to insert a
// session and link it to an existing group where the group is not
// known.
ErrUnknownGroup = errors.New("unknown group")

// ErrSessionsInGroupStillActive is returned when an attempt is made to
// insert a session and link it to a group that still has other active
// sessions.
ErrSessionsInGroupStillActive = errors.New(
"group has active sessions",
)
)
29 changes: 17 additions & 12 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,9 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
if session.ID != session.GroupID {
_, err = getKeyForID(sessionBucket, session.GroupID)
if err != nil {
return fmt.Errorf("unknown linked session "+
"%x: %w", session.GroupID, err)
return fmt.Errorf("%w: unknown linked "+
"session %x: %w", ErrUnknownGroup,
session.GroupID, err)
}

// Fetch all the session IDs for this group. This will
Expand All @@ -242,18 +243,22 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
return err
}

// Ensure that the all the linked sessions are no longer
// active.
for _, id := range sessionIDs {
sess, err := getSessionByID(sessionBucket, id)
if err != nil {
return err
}

// Ensure that the session is no longer active.
if !sess.State.Terminal() {
return fmt.Errorf("session (id=%x) "+
"in group %x is still active",
sess.ID, sess.GroupID)
if sess.State.Terminal() {
continue
}

return fmt.Errorf("%w: session (id=%x) in "+
"group %x is still active",
ErrSessionsInGroupStillActive, sess.ID,
sess.GroupID)
}
}

Expand Down Expand Up @@ -630,14 +635,14 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {

sessionIDBkt := idIndex.Bucket(sessionID[:])
if sessionIDBkt == nil {
return fmt.Errorf("no index entry for session ID: %x",
sessionID)
return fmt.Errorf("%w: no index entry for session "+
"ID: %x", ErrUnknownGroup, sessionID)
}

groupIDBytes := sessionIDBkt.Get(groupIDKey)
if len(groupIDBytes) == 0 {
return fmt.Errorf("group ID not found for session "+
"ID %x", sessionID)
return fmt.Errorf("%w: group ID not found for "+
"session ID %x", ErrUnknownGroup, sessionID)
}

copy(groupID[:], groupIDBytes)
Expand Down Expand Up @@ -806,7 +811,7 @@ func addIDToGroupIDPair(sessionBkt *bbolt.Bucket, id, groupID ID) error {
func getSessionByID(bucket *bbolt.Bucket, id ID) (*Session, error) {
keyBytes, err := getKeyForID(bucket, id)
if err != nil {
return nil, err
return nil, fmt.Errorf("%w: %w", ErrSessionNotFound, err)
}

v := bucket.Get(keyBytes)
Expand Down
36 changes: 12 additions & 24 deletions session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ var testTime = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
func TestBasicSessionStore(t *testing.T) {
// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
db := NewTestDB(t, clock)

// Try fetch a session that doesn't exist yet.
_, err := db.GetSessionByID(ID{1, 3, 4, 4})
require.ErrorIs(t, err, ErrSessionNotFound)

// Reserve a session. This should succeed.
s1, err := reserveSession(db, "session 1")
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestBasicSessionStore(t *testing.T) {
require.Empty(t, sessions)

_, err = db.GetGroupID(s4.ID)
require.ErrorContains(t, err, "no index entry")
require.ErrorIs(t, err, ErrUnknownGroup)

// Only session 1 should remain in this group.
sessIDs, err = db.GetSessionIDs(s4.GroupID)
Expand All @@ -197,11 +197,7 @@ func TestLinkingSessions(t *testing.T) {

// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
db := NewTestDB(t, clock)

groupID, err := IDFromBytes([]byte{1, 2, 3, 4})
require.NoError(t, err)
Expand All @@ -211,7 +207,7 @@ func TestLinkingSessions(t *testing.T) {
_, err = reserveSession(
db, "session 2", withLinkedGroupID(&groupID),
)
require.ErrorContains(t, err, "unknown linked session")
require.ErrorIs(t, err, ErrUnknownGroup)

// Create a new session with no previous link.
s1 := createSession(t, db, "session 1")
Expand All @@ -220,7 +216,7 @@ func TestLinkingSessions(t *testing.T) {
// session. This should fail due to the first session still being
// active.
_, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID))
require.ErrorContains(t, err, "is still active")
require.ErrorIs(t, err, ErrSessionsInGroupStillActive)

// Revoke the first session.
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
Expand All @@ -238,11 +234,7 @@ func TestLinkedSessions(t *testing.T) {

// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
db := NewTestDB(t, clock)

// Create a few sessions. The first one is a new session and the two
// after are all linked to the prior one. All these sessions belong to
Expand Down Expand Up @@ -294,18 +286,14 @@ func TestLinkedSessions(t *testing.T) {
func TestStateShift(t *testing.T) {
// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
db := NewTestDB(t, clock)

// Add a new session to the DB.
s1 := createSession(t, db, "label 1")

// Check that the session is in the StateCreated state. Also check that
// the "RevokedAt" time has not yet been set.
s1, err = db.GetSession(s1.LocalPublicKey)
s1, err := db.GetSession(s1.LocalPublicKey)
require.NoError(t, err)
require.Equal(t, StateCreated, s1.State)
require.Equal(t, time.Time{}, s1.RevokedAt)
Expand Down
28 changes: 28 additions & 0 deletions session/test_kvdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package session

import (
"testing"

"github.com/lightningnetwork/lnd/clock"
"github.com/stretchr/testify/require"
)

// NewTestDB is a helper function that creates an BBolt database for testing.
func NewTestDB(t *testing.T, clock clock.Clock) *BoltStore {
return NewTestDBFromPath(t, t.TempDir(), clock)
}

// NewTestDBFromPath is a helper function that creates a new BoltStore with a
// connection to an existing BBolt database for testing.
func NewTestDBFromPath(t *testing.T, dbPath string,
clock clock.Clock) *BoltStore {

store, err := NewDB(dbPath, DBFilename, clock)
require.NoError(t, err)

t.Cleanup(func() {
require.NoError(t, store.DB.Close())
})

return store
}
Loading