diff --git a/firewalldb/kvstores_kvdb.go b/firewalldb/kvstores_kvdb.go index a7b9d4765..51721d475 100644 --- a/firewalldb/kvstores_kvdb.go +++ b/firewalldb/kvstores_kvdb.go @@ -60,7 +60,8 @@ func (db *BoltDB) GetKVStores(rule string, groupID session.ID, db: db.DB, wrapTx: func(tx *bbolt.Tx) KVStoreTx { return &kvStoreTx{ - boltTx: tx, + boltTx: tx, + sessions: db.sessionIDIndex, kvStores: &kvStores{ ruleName: rule, groupID: groupID, @@ -109,6 +110,7 @@ type getBucketFunc func(tx *bbolt.Tx, create bool) (*bbolt.Bucket, error) type kvStoreTx struct { boltTx *bbolt.Tx getBucket getBucketFunc + sessions session.IDToGroupIndex *kvStores } @@ -116,11 +118,12 @@ type kvStoreTx struct { // Global gives the caller access to the global kv store of the rule. // // NOTE: this is part of the rules.KVStoreTx interface. -func (tx *kvStoreTx) Global() KVStore { +func (s *kvStoreTx) Global() KVStore { return &kvStoreTx{ - kvStores: tx.kvStores, - boltTx: tx.boltTx, - getBucket: getGlobalRuleBucket(true, tx.ruleName), + kvStores: s.kvStores, + boltTx: s.boltTx, + sessions: s.sessions, + getBucket: getGlobalRuleBucket(true, s.ruleName), } } @@ -129,17 +132,16 @@ func (tx *kvStoreTx) Global() KVStore { // how the kv store was initialised. // // NOTE: this is part of the KVStoreTx interface. -func (tx *kvStoreTx) Local() KVStore { - fn := getSessionRuleBucket(true, tx.ruleName, tx.groupID) - if tx.featureName != "" { - fn = getSessionFeatureRuleBucket( - true, tx.ruleName, tx.groupID, tx.featureName, - ) +func (s *kvStoreTx) Local() KVStore { + fn := s.getSessionRuleBucket(true) + if s.featureName != "" { + fn = s.getSessionFeatureRuleBucket(true) } return &kvStoreTx{ - kvStores: tx.kvStores, - boltTx: tx.boltTx, + kvStores: s.kvStores, + boltTx: s.boltTx, + sessions: s.sessions, getBucket: fn, } } @@ -148,11 +150,12 @@ func (tx *kvStoreTx) Local() KVStore { // rule. // // NOTE: this is part of the KVStoreTx interface. -func (tx *kvStoreTx) GlobalTemp() KVStore { +func (s *kvStoreTx) GlobalTemp() KVStore { return &kvStoreTx{ - kvStores: tx.kvStores, - boltTx: tx.boltTx, - getBucket: getGlobalRuleBucket(false, tx.ruleName), + kvStores: s.kvStores, + boltTx: s.boltTx, + sessions: s.sessions, + getBucket: getGlobalRuleBucket(false, s.ruleName), } } @@ -160,17 +163,16 @@ func (tx *kvStoreTx) GlobalTemp() KVStore { // rule. // // NOTE: this is part of the KVStoreTx interface. -func (tx *kvStoreTx) LocalTemp() KVStore { - fn := getSessionRuleBucket(false, tx.ruleName, tx.groupID) - if tx.featureName != "" { - fn = getSessionFeatureRuleBucket( - false, tx.ruleName, tx.groupID, tx.featureName, - ) +func (s *kvStoreTx) LocalTemp() KVStore { + fn := s.getSessionRuleBucket(false) + if s.featureName != "" { + fn = s.getSessionFeatureRuleBucket(false) } return &kvStoreTx{ - kvStores: tx.kvStores, - boltTx: tx.boltTx, + kvStores: s.kvStores, + boltTx: s.boltTx, + sessions: s.sessions, getBucket: fn, } } @@ -179,8 +181,8 @@ func (tx *kvStoreTx) LocalTemp() KVStore { // If no value is found, nil is returned. // // NOTE: this is part of the KVStore interface. -func (tx *kvStoreTx) Get(_ context.Context, key string) ([]byte, error) { - bucket, err := tx.getBucket(tx.boltTx, false) +func (s *kvStoreTx) Get(_ context.Context, key string) ([]byte, error) { + bucket, err := s.getBucket(s.boltTx, false) if err != nil { return nil, err } @@ -194,8 +196,8 @@ func (tx *kvStoreTx) Get(_ context.Context, key string) ([]byte, error) { // Set sets the given key-value pair in the underlying kv store. // // NOTE: this is part of the KVStore interface. -func (tx *kvStoreTx) Set(_ context.Context, key string, value []byte) error { - bucket, err := tx.getBucket(tx.boltTx, true) +func (s *kvStoreTx) Set(_ context.Context, key string, value []byte) error { + bucket, err := s.getBucket(s.boltTx, true) if err != nil { return err } @@ -206,8 +208,8 @@ func (tx *kvStoreTx) Set(_ context.Context, key string, value []byte) error { // Del deletes the value under the given key in the underlying kv store. // // NOTE: this is part of the .KVStore interface. -func (tx *kvStoreTx) Del(_ context.Context, key string) error { - bucket, err := tx.getBucket(tx.boltTx, false) +func (s *kvStoreTx) Del(_ context.Context, key string) error { + bucket, err := s.getBucket(s.boltTx, false) if err != nil { return err } @@ -286,11 +288,9 @@ func getGlobalRuleBucket(perm bool, ruleName string) getBucketFunc { // bucket under which a kv store for a specific rule-name and group ID is // stored. The `perm` param determines if the temporary or permanent store is // used. -func getSessionRuleBucket(perm bool, ruleName string, - groupID session.ID) getBucketFunc { - +func (s *kvStoreTx) getSessionRuleBucket(perm bool) getBucketFunc { return func(tx *bbolt.Tx, create bool) (*bbolt.Bucket, error) { - ruleBucket, err := getRuleBucket(perm, ruleName)(tx, create) + ruleBucket, err := getRuleBucket(perm, s.ruleName)(tx, create) if err != nil { return nil, err } @@ -300,6 +300,19 @@ func getSessionRuleBucket(perm bool, ruleName string, } if create { + // NOTE: for a bbolt backend, the context is in any case + // dropped behind the GetSessionIDs call. So passing in + // a new context here is not a problem. + ctx := context.Background() + + // If create is true, we expect this to be an existing + // session. So we check that now and return an error + // accordingly if the session does not exist. + _, err := s.sessions.GetSessionIDs(ctx, s.groupID) + if err != nil { + return nil, err + } + sessBucket, err := ruleBucket.CreateBucketIfNotExists( sessKVStoreBucketKey, ) @@ -307,14 +320,14 @@ func getSessionRuleBucket(perm bool, ruleName string, return nil, err } - return sessBucket.CreateBucketIfNotExists(groupID[:]) + return sessBucket.CreateBucketIfNotExists(s.groupID[:]) } sessBucket := ruleBucket.Bucket(sessKVStoreBucketKey) if sessBucket == nil { return nil, nil } - return sessBucket.Bucket(groupID[:]), nil + return sessBucket.Bucket(s.groupID[:]), nil } } @@ -322,13 +335,9 @@ func getSessionRuleBucket(perm bool, ruleName string, // bucket under which a kv store for a specific rule-name, group ID and // feature name is stored. The `perm` param determines if the temporary or // permanent store is used. -func getSessionFeatureRuleBucket(perm bool, ruleName string, - groupID session.ID, featureName string) getBucketFunc { - +func (s *kvStoreTx) getSessionFeatureRuleBucket(perm bool) getBucketFunc { return func(tx *bbolt.Tx, create bool) (*bbolt.Bucket, error) { - sessBucket, err := getSessionRuleBucket( - perm, ruleName, groupID, - )(tx, create) + sessBucket, err := s.getSessionRuleBucket(perm)(tx, create) if err != nil { return nil, err } @@ -346,7 +355,7 @@ func getSessionFeatureRuleBucket(perm bool, ruleName string, } return featureBucket.CreateBucketIfNotExists( - []byte(featureName), + []byte(s.featureName), ) } @@ -354,6 +363,6 @@ func getSessionFeatureRuleBucket(perm bool, ruleName string, if featureBucket == nil { return nil, nil } - return featureBucket.Bucket([]byte(featureName)), nil + return featureBucket.Bucket([]byte(s.featureName)), nil } } diff --git a/firewalldb/kvstores_test.go b/firewalldb/kvstores_test.go index 0742a40a7..592188c77 100644 --- a/firewalldb/kvstores_test.go +++ b/firewalldb/kvstores_test.go @@ -5,8 +5,10 @@ import ( "context" "fmt" "testing" + "time" "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/clock" "github.com/stretchr/testify/require" ) @@ -83,15 +85,21 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { featureName = "auto-fees" } - store := NewTestDB(t) + sessions := session.NewTestDB(t, clock.NewDefaultClock()) + store := NewTestDBWithSessions(t, sessions) db := NewDB(store) require.NoError(t, db.Start(ctx)) - kvstores := db.GetKVStores( - "test-rule", [4]byte{1, 1, 1, 1}, featureName, + // Create a session that we can reference. + sess, err := sessions.NewSession( + ctx, "test", session.TypeAutopilot, time.Unix(1000, 0), + "something", ) + require.NoError(t, err) + + kvstores := db.GetKVStores("test-rule", sess.GroupID, featureName) - err := kvstores.Update(ctx, func(ctx context.Context, + err = kvstores.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { // Set an item in the temp store. @@ -137,7 +145,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { require.NoError(t, db.Stop()) }) - kvstores = db.GetKVStores("test-rule", [4]byte{1, 1, 1, 1}, featureName) + kvstores = db.GetKVStores("test-rule", sess.GroupID, featureName) // The temp store should no longer have the stored value but the perm // store should . @@ -164,23 +172,31 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { func TestKVStoreNameSpaces(t *testing.T) { t.Parallel() ctx := context.Background() - db := NewTestDB(t) - var ( - groupID1 = intToSessionID(1) - groupID2 = intToSessionID(2) + sessions := session.NewTestDB(t, clock.NewDefaultClock()) + db := NewTestDBWithSessions(t, sessions) + + // Create 2 sessions that we can reference. + sess1, err := sessions.NewSession( + ctx, "test", session.TypeAutopilot, time.Unix(1000, 0), "", ) + require.NoError(t, err) + + sess2, err := sessions.NewSession( + ctx, "test1", session.TypeAutopilot, time.Unix(1000, 0), "", + ) + require.NoError(t, err) // Two DBs for same group but different features. - rulesDB1 := db.GetKVStores("test-rule", groupID1, "auto-fees") - rulesDB2 := db.GetKVStores("test-rule", groupID1, "re-balance") + rulesDB1 := db.GetKVStores("test-rule", sess1.GroupID, "auto-fees") + rulesDB2 := db.GetKVStores("test-rule", sess1.GroupID, "re-balance") // The third DB is for the same rule but a different group. It is // for the same feature as db 2. - rulesDB3 := db.GetKVStores("test-rule", groupID2, "re-balance") + rulesDB3 := db.GetKVStores("test-rule", sess2.GroupID, "re-balance") // Test that the three ruleDBs share the same global space. - err := rulesDB1.Update(ctx, func(ctx context.Context, + err = rulesDB1.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { return tx.Global().Set( @@ -311,9 +327,9 @@ func TestKVStoreNameSpaces(t *testing.T) { // Test that the group space is shared by the first two dbs but not // the third. To do this, we re-init the DB's but leave the feature // names out. This way, we will access the group storage. - rulesDB1 = db.GetKVStores("test-rule", groupID1, "") - rulesDB2 = db.GetKVStores("test-rule", groupID1, "") - rulesDB3 = db.GetKVStores("test-rule", groupID2, "") + rulesDB1 = db.GetKVStores("test-rule", sess1.GroupID, "") + rulesDB2 = db.GetKVStores("test-rule", sess1.GroupID, "") + rulesDB3 = db.GetKVStores("test-rule", sess2.GroupID, "") err = rulesDB1.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { @@ -376,6 +392,81 @@ func TestKVStoreNameSpaces(t *testing.T) { require.True(t, bytes.Equal(v, []byte("thing 3"))) } +// TestKVStoreSessionCoupling tests if we attempt to write to a kvstore that +// is namespaced by a session that does not exist, then we should get an error. +func TestKVStoreSessionCoupling(t *testing.T) { + t.Parallel() + ctx := context.Background() + + sessions := session.NewTestDB(t, clock.NewDefaultClock()) + db := NewTestDBWithSessions(t, sessions) + + // Get a kvstore namespaced by a session ID for a session that does + // not exist. + store := db.GetKVStores("AutoFees", [4]byte{1, 1, 1, 1}, "auto-fees") + + err := store.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + + // First, show that any call to the global namespace will not + // error since it is not namespaced by a session. + res, err := tx.Global().Get(ctx, "foo") + require.NoError(t, err) + require.Nil(t, res) + + err = tx.Global().Set(ctx, "foo", []byte("bar")) + require.NoError(t, err) + + res, err = tx.Global().Get(ctx, "foo") + require.NoError(t, err) + require.Equal(t, []byte("bar"), res) + + // Now we switch to the local store. We don't expect the Get + // call to error since it should just return a nil value for + // key that has not been set. + _, err = tx.Local().Get(ctx, "foo") + require.NoError(t, err) + + // For Set, we expect an error since the session does not exist. + err = tx.Local().Set(ctx, "foo", []byte("bar")) + require.ErrorIs(t, err, session.ErrUnknownGroup) + + // We again don't expect the error for delete since we just + // expect it to return nil if the key is not found. + err = tx.Local().Del(ctx, "foo") + require.NoError(t, err) + + return nil + }) + require.NoError(t, err) + + // Now, go and create a sessions in the session DB. + sess, err := sessions.NewSession( + ctx, "test", session.TypeAutopilot, time.Unix(1000, 0), + "something", + ) + require.NoError(t, err) + + // Get a kvstore namespaced by a session ID for a session that now + // does exist. + store = db.GetKVStores("AutoFees", sess.GroupID, "auto-fees") + + // Now, repeat the "Set" call for this session's kvstore to + // show that it no longer errors. + err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { + // For Set, we expect an error since the session does not exist. + err = tx.Local().Set(ctx, "foo", []byte("bar")) + require.NoError(t, err) + + res, err := tx.Local().Get(ctx, "foo") + require.NoError(t, err) + require.Equal(t, []byte("bar"), res) + + return nil + }) + require.NoError(t, err) +} + func intToSessionID(i uint32) session.ID { var id session.ID byteOrder.PutUint32(id[:], i) diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 0757786eb..4e2e8a063 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -3,6 +3,7 @@ package firewalldb import ( "testing" + "github.com/lightninglabs/lightning-terminal/session" "github.com/stretchr/testify/require" ) @@ -14,7 +15,19 @@ func NewTestDB(t *testing.T) *BoltDB { // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. func NewTestDBFromPath(t *testing.T, dbPath string) *BoltDB { - store, err := NewBoltDB(dbPath, DBFilename, nil) + return newDBFromPathWithSessions(t, dbPath, nil) +} + +// NewTestDBWithSessions creates a new test BoltDB Store with access to an +// existing sessions DB. +func NewTestDBWithSessions(t *testing.T, sessStore session.Store) *BoltDB { + return newDBFromPathWithSessions(t, t.TempDir(), sessStore) +} + +func newDBFromPathWithSessions(t *testing.T, dbPath string, + sessStore session.Store) *BoltDB { + + store, err := NewBoltDB(dbPath, DBFilename, sessStore) require.NoError(t, err) t.Cleanup(func() { diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 71099192a..d52897966 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -637,7 +637,7 @@ func (db *BoltStore) GetGroupID(_ context.Context, sessionID ID) (ID, error) { sessionIDBkt := idIndex.Bucket(sessionID[:]) if sessionIDBkt == nil { return fmt.Errorf("%w: no index entry for session "+ - "ID: %x", ErrUnknownGroup, sessionID) + "ID: %x", ErrSessionNotFound, sessionID) } groupIDBytes := sessionIDBkt.Get(groupIDKey) @@ -696,7 +696,7 @@ func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) { groupIDBkt := groupIndexBkt.Bucket(groupID[:]) if groupIDBkt == nil { - return nil, fmt.Errorf("no sessions for group ID %v", + return nil, fmt.Errorf("%w: group ID %v", ErrUnknownGroup, groupID) } diff --git a/session/sql_store.go b/session/sql_store.go index 76f594f13..8bd0b51e8 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -593,7 +593,7 @@ func (s *SQLStore) GetGroupID(ctx context.Context, sessionID ID) (ID, error) { // Get the session using the legacy Alias. sess, err := db.GetSessionByAlias(ctx, sessionID[:]) if errors.Is(err, sql.ErrNoRows) { - return ErrUnknownGroup + return ErrSessionNotFound } else if err != nil { return err } diff --git a/session/store_test.go b/session/store_test.go index 6c7ec7176..4b1b7d3bb 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -32,6 +32,14 @@ func TestBasicSessionStore(t *testing.T) { _, err := db.GetSession(ctx, ID{1, 3, 4, 4}) require.ErrorIs(t, err, ErrSessionNotFound) + // Do the same for other "Get" type methods to assert that the correct + // errors are returned. + _, err = db.GetGroupID(ctx, ID{1, 2, 3, 4}) + require.ErrorIs(t, err, ErrSessionNotFound) + + _, err = db.GetSessionIDs(ctx, ID{1, 2, 3, 4}) + require.ErrorIs(t, err, ErrUnknownGroup) + // Reserve a session. This should succeed. s1, err := reserveSession(db, "session 1") require.NoError(t, err) @@ -182,7 +190,7 @@ func TestBasicSessionStore(t *testing.T) { require.Empty(t, sessions) _, err = db.GetGroupID(ctx, s4.ID) - require.ErrorIs(t, err, ErrUnknownGroup) + require.ErrorIs(t, err, ErrSessionNotFound) // Only session 1 should remain in this group. sessIDs, err = db.GetSessionIDs(ctx, s4.GroupID)