diff --git a/itest/litd_firewall_test.go b/itest/litd_firewall_test.go index 98b583132..5c82bd4b3 100644 --- a/itest/litd_firewall_test.go +++ b/itest/litd_firewall_test.go @@ -866,7 +866,9 @@ func testSessionLinking(net *NetworkHarness, t *harnessTest) { LinkedGroupId: sessResp.Session.GroupId, }, ) - require.ErrorContains(t.t, err, "is still active") + require.ErrorContains( + t.t, err, session.ErrSessionsInGroupStillActive.Error(), + ) // Revoke the previous one and repeat. _, err = litAutopilotClient.RevokeAutopilotSession( diff --git a/session/errors.go b/session/errors.go index 560a6c2bc..1cb97a8f2 100644 --- a/session/errors.go +++ b/session/errors.go @@ -6,4 +6,16 @@ var ( // ErrSessionNotFound is an error returned when we attempt to retrieve // information about a session but it is not found. ErrSessionNotFound = errors.New("session not found") + + // ErrUnknownGroup is returned when an attempt is made to insert a + // session and link it to an existing group where the group is not + // known. + ErrUnknownGroup = errors.New("unknown group") + + // ErrSessionsInGroupStillActive is returned when an attempt is made to + // insert a session and link it to a group that still has other active + // sessions. + ErrSessionsInGroupStillActive = errors.New( + "group has active sessions", + ) ) diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 84f4dce06..96c98d29c 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -229,8 +229,9 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, if session.ID != session.GroupID { _, err = getKeyForID(sessionBucket, session.GroupID) if err != nil { - return fmt.Errorf("unknown linked session "+ - "%x: %w", session.GroupID, err) + return fmt.Errorf("%w: unknown linked "+ + "session %x: %w", ErrUnknownGroup, + session.GroupID, err) } // Fetch all the session IDs for this group. This will @@ -242,18 +243,22 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, return err } + // Ensure that the all the linked sessions are no longer + // active. for _, id := range sessionIDs { sess, err := getSessionByID(sessionBucket, id) if err != nil { return err } - // Ensure that the session is no longer active. - if !sess.State.Terminal() { - return fmt.Errorf("session (id=%x) "+ - "in group %x is still active", - sess.ID, sess.GroupID) + if sess.State.Terminal() { + continue } + + return fmt.Errorf("%w: session (id=%x) in "+ + "group %x is still active", + ErrSessionsInGroupStillActive, sess.ID, + sess.GroupID) } } @@ -630,14 +635,14 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) { sessionIDBkt := idIndex.Bucket(sessionID[:]) if sessionIDBkt == nil { - return fmt.Errorf("no index entry for session ID: %x", - sessionID) + return fmt.Errorf("%w: no index entry for session "+ + "ID: %x", ErrUnknownGroup, sessionID) } groupIDBytes := sessionIDBkt.Get(groupIDKey) if len(groupIDBytes) == 0 { - return fmt.Errorf("group ID not found for session "+ - "ID %x", sessionID) + return fmt.Errorf("%w: group ID not found for "+ + "session ID %x", ErrUnknownGroup, sessionID) } copy(groupID[:], groupIDBytes) @@ -806,7 +811,7 @@ func addIDToGroupIDPair(sessionBkt *bbolt.Bucket, id, groupID ID) error { func getSessionByID(bucket *bbolt.Bucket, id ID) (*Session, error) { keyBytes, err := getKeyForID(bucket, id) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %w", ErrSessionNotFound, err) } v := bucket.Get(keyBytes) diff --git a/session/store_test.go b/session/store_test.go index aa7afe20f..966b03962 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -16,11 +16,11 @@ var testTime = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) func TestBasicSessionStore(t *testing.T) { // Set up a new DB. clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t, clock) + + // Try fetch a session that doesn't exist yet. + _, err := db.GetSessionByID(ID{1, 3, 4, 4}) + require.ErrorIs(t, err, ErrSessionNotFound) // Reserve a session. This should succeed. s1, err := reserveSession(db, "session 1") @@ -183,7 +183,7 @@ func TestBasicSessionStore(t *testing.T) { require.Empty(t, sessions) _, err = db.GetGroupID(s4.ID) - require.ErrorContains(t, err, "no index entry") + require.ErrorIs(t, err, ErrUnknownGroup) // Only session 1 should remain in this group. sessIDs, err = db.GetSessionIDs(s4.GroupID) @@ -197,11 +197,7 @@ func TestLinkingSessions(t *testing.T) { // Set up a new DB. clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t, clock) groupID, err := IDFromBytes([]byte{1, 2, 3, 4}) require.NoError(t, err) @@ -211,7 +207,7 @@ func TestLinkingSessions(t *testing.T) { _, err = reserveSession( db, "session 2", withLinkedGroupID(&groupID), ) - require.ErrorContains(t, err, "unknown linked session") + require.ErrorIs(t, err, ErrUnknownGroup) // Create a new session with no previous link. s1 := createSession(t, db, "session 1") @@ -220,7 +216,7 @@ func TestLinkingSessions(t *testing.T) { // session. This should fail due to the first session still being // active. _, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID)) - require.ErrorContains(t, err, "is still active") + require.ErrorIs(t, err, ErrSessionsInGroupStillActive) // Revoke the first session. require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) @@ -238,11 +234,7 @@ func TestLinkedSessions(t *testing.T) { // Set up a new DB. clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t, clock) // Create a few sessions. The first one is a new session and the two // after are all linked to the prior one. All these sessions belong to @@ -294,18 +286,14 @@ func TestLinkedSessions(t *testing.T) { func TestStateShift(t *testing.T) { // Set up a new DB. clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t, clock) // Add a new session to the DB. s1 := createSession(t, db, "label 1") // Check that the session is in the StateCreated state. Also check that // the "RevokedAt" time has not yet been set. - s1, err = db.GetSession(s1.LocalPublicKey) + s1, err := db.GetSession(s1.LocalPublicKey) require.NoError(t, err) require.Equal(t, StateCreated, s1.State) require.Equal(t, time.Time{}, s1.RevokedAt) diff --git a/session/test_kvdb.go b/session/test_kvdb.go new file mode 100644 index 000000000..6f270d617 --- /dev/null +++ b/session/test_kvdb.go @@ -0,0 +1,28 @@ +package session + +import ( + "testing" + + "github.com/lightningnetwork/lnd/clock" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T, clock clock.Clock) *BoltStore { + return NewTestDBFromPath(t, t.TempDir(), clock) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) *BoltStore { + + store, err := NewDB(dbPath, DBFilename, clock) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, store.DB.Close()) + }) + + return store +}