diff --git a/config_dev.go b/config_dev.go index 5244b3fdc..2dd937b9b 100644 --- a/config_dev.go +++ b/config_dev.go @@ -108,9 +108,11 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) + firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB) stores.accounts = acctStore stores.sessions = sessStore + stores.firewall = firewalldb.NewDB(firewallStore) stores.closeFns["sqlite"] = sqlStore.BaseDB.Close case DatabaseBackendPostgres: @@ -121,9 +123,11 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) + firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB) stores.accounts = acctStore stores.sessions = sessStore + stores.firewall = firewalldb.NewDB(firewallStore) stores.closeFns["postgres"] = sqlStore.BaseDB.Close default: @@ -157,7 +161,10 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { err) } - stores.firewall = firewalldb.NewDB(firewallBoltDB) + if stores.firewall == nil { + stores.firewall = firewalldb.NewDB(firewallBoltDB) + } + stores.firewallBolt = firewallBoltDB stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go new file mode 100644 index 000000000..e7e1e7da5 --- /dev/null +++ b/firewalldb/kvstores_sql.go @@ -0,0 +1,451 @@ +package firewalldb + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/fn" +) + +// SQLSessionQueries is a subset of the sqlc.Queries interface that can be used +// to interact with the session table. +type SQLSessionQueries interface { + GetSessionIDByAlias(ctx context.Context, legacyID []byte) (int64, error) +} + +// SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be +// used to interact with the kvstore tables. +// +//nolint:lll +type SQLKVStoreQueries interface { + SQLSessionQueries + + DeleteFeatureKVStoreRecord(ctx context.Context, arg sqlc.DeleteFeatureKVStoreRecordParams) error + DeleteGlobalKVStoreRecord(ctx context.Context, arg sqlc.DeleteGlobalKVStoreRecordParams) error + DeleteSessionKVStoreRecord(ctx context.Context, arg sqlc.DeleteSessionKVStoreRecordParams) error + GetFeatureKVStoreRecord(ctx context.Context, arg sqlc.GetFeatureKVStoreRecordParams) ([]byte, error) + GetGlobalKVStoreRecord(ctx context.Context, arg sqlc.GetGlobalKVStoreRecordParams) ([]byte, error) + GetSessionKVStoreRecord(ctx context.Context, arg sqlc.GetSessionKVStoreRecordParams) ([]byte, error) + UpdateFeatureKVStoreRecord(ctx context.Context, arg sqlc.UpdateFeatureKVStoreRecordParams) error + UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlc.UpdateGlobalKVStoreRecordParams) error + UpdateSessionKVStoreRecord(ctx context.Context, arg sqlc.UpdateSessionKVStoreRecordParams) error + InsertKVStoreRecord(ctx context.Context, arg sqlc.InsertKVStoreRecordParams) error + DeleteAllTempKVStores(ctx context.Context) error + GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) + GetOrInsertRuleID(ctx context.Context, name string) (int64, error) + GetFeatureID(ctx context.Context, name string) (int64, error) + GetRuleID(ctx context.Context, name string) (int64, error) +} + +// DeleteTempKVStores deletes all temporary kv stores. +// +// NOTE: part of the RulesDB interface. +func (s *SQLDB) DeleteTempKVStores(ctx context.Context) error { + var writeTxOpts db.QueriesTxOptions + + return s.db.ExecTx(ctx, &writeTxOpts, func(tx SQLQueries) error { + return tx.DeleteAllTempKVStores(ctx) + }) +} + +// GetKVStores constructs a new rules.KVStores in a namespace defined by the +// rule name, group ID and feature name. +// +// NOTE: part of the RulesDB interface. +func (s *SQLDB) GetKVStores(rule string, groupAlias session.ID, + feature string) KVStores { + + return &sqlExecutor[KVStoreTx]{ + db: s.db, + wrapTx: func(queries SQLQueries) KVStoreTx { + return &sqlKVStoresTx{ + queries: queries, + groupAlias: groupAlias, + rule: rule, + feature: feature, + } + }, + } +} + +// sqlKVStoresTx is a SQL implementation of the KVStoreTx interface. +type sqlKVStoresTx struct { + queries SQLKVStoreQueries + groupAlias session.ID + rule string + feature string +} + +// Global returns a persisted global, rule-name indexed, kv store. A rule with a +// given name will have access to this store independent of group ID or feature. +// +// NOTE: part of the KVStoreTx interface. +func (s *sqlKVStoresTx) Global() KVStore { + return &sqlKVStore{ + sqlKVStoresTx: s, + params: &sqlKVStoreParams{ + perm: true, + ruleName: s.rule, + }, + } +} + +// Local returns a persisted local kv store for the rule. Depending on how the +// implementation is initialised, this will either be under the group ID +// namespace or the group ID _and_ feature name namespace. +// +// NOTE: part of the KVStoreTx interface. +func (s *sqlKVStoresTx) Local() KVStore { + var featureName fn.Option[string] + if s.feature != "" { + featureName = fn.Some(s.feature) + } + + return &sqlKVStore{ + sqlKVStoresTx: s, + params: &sqlKVStoreParams{ + perm: true, + ruleName: s.rule, + groupID: fn.Some(s.groupAlias), + featureName: featureName, + }, + } +} + +// GlobalTemp is similar to the Global store except that its contents is cleared +// upon restart of the database. The reason persisting the temporary store +// changes instead of just keeping an in-memory store is that we can then +// guarantee atomicity if changes are made to both the permanent and temporary +// stores. +// +// NOTE: part of the KVStoreTx interface. +func (s *sqlKVStoresTx) GlobalTemp() KVStore { + return &sqlKVStore{ + sqlKVStoresTx: s, + params: &sqlKVStoreParams{ + perm: false, + ruleName: s.rule, + }, + } +} + +// LocalTemp is similar to the Local store except that its contents is cleared +// upon restart of the database. The reason persisting the temporary store +// changes instead of just keeping an in-memory store is that we can then +// guarantee atomicity if changes are made to both the permanent and temporary +// stores. +// +// NOTE: part of the KVStoreTx interface. +func (s *sqlKVStoresTx) LocalTemp() KVStore { + var featureName fn.Option[string] + if s.feature != "" { + featureName = fn.Some(s.feature) + } + + return &sqlKVStore{ + sqlKVStoresTx: s, + params: &sqlKVStoreParams{ + perm: false, + ruleName: s.rule, + groupID: fn.Some(s.groupAlias), + featureName: featureName, + }, + } +} + +// A compile-time assertion to ensure that sqlKVStoresTx implements the +// KVStoreTx interface. +var _ KVStoreTx = (*sqlKVStoresTx)(nil) + +// sqlKVStoreParams holds the various parameters that determine the namespace +// that a query is accessing. +type sqlKVStoreParams struct { + perm bool + ruleName string + groupID fn.Option[session.ID] + featureName fn.Option[string] +} + +// sqlKVStore is a SQL store backed KVStore. +type sqlKVStore struct { + *sqlKVStoresTx + + params *sqlKVStoreParams +} + +// A compile-time assertion to ensure that sqlKVStore implements the KVStore +// interface. +var _ KVStore = (*sqlKVStore)(nil) + +// Get fetches the value under the given key from the underlying kv store. If no +// value is found, nil is returned. +// +// NOTE: part of the KVStore interface. +func (s *sqlKVStore) Get(ctx context.Context, key string) ([]byte, error) { + value, err := s.get(ctx, key) + if errors.Is(err, sql.ErrNoRows) || + errors.Is(err, session.ErrUnknownGroup) { + + return nil, nil + } else if err != nil { + return nil, err + } + + return value, nil +} + +// Set sets the given key-value pair in the underlying kv store. +// +// NOTE: part of the KVStore interface. +func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { + ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, false) + if err != nil { + return err + } + + // We first need to figure out if we are inserting a new record or + // updating an existing one. So first do a GET with the same set of + // params. + oldValue, err := s.get(ctx, key) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + // No such entry. Add new record. + if errors.Is(err, sql.ErrNoRows) { + return s.queries.InsertKVStoreRecord( + ctx, sqlc.InsertKVStoreRecordParams{ + EntryKey: key, + Value: value, + Perm: s.params.perm, + RuleID: ruleID, + SessionID: sessionID, + FeatureID: featureID, + }, + ) + } + + // If an entry exists but the value has not changed, there is nothing + // left to do. + if bytes.Equal(oldValue, value) { + return nil + } + + // Otherwise, the key exists but the value needs to be updated. + switch { + case sessionID.Valid && featureID.Valid: + return s.queries.UpdateFeatureKVStoreRecord( + ctx, sqlc.UpdateFeatureKVStoreRecordParams{ + Key: key, + Value: value, + Perm: s.params.perm, + SessionID: sessionID, + RuleID: ruleID, + FeatureID: featureID, + }, + ) + + case sessionID.Valid: + return s.queries.UpdateSessionKVStoreRecord( + ctx, sqlc.UpdateSessionKVStoreRecordParams{ + Key: key, + Value: value, + Perm: s.params.perm, + SessionID: sessionID, + RuleID: ruleID, + }, + ) + + case featureID.Valid: + return fmt.Errorf("a global feature kv store is " + + "not currently supported") + default: + return s.queries.UpdateGlobalKVStoreRecord( + ctx, sqlc.UpdateGlobalKVStoreRecordParams{ + Key: key, + Value: value, + Perm: s.params.perm, + RuleID: ruleID, + }, + ) + } +} + +// Del deletes the value under the given key in the underlying kv store. +// +// NOTE: part of the KVStore interface. +func (s *sqlKVStore) Del(ctx context.Context, key string) error { + // Note: we pass in true here for "read-only" since because this is a + // Delete, if the record does not exist, we don't need to create one. + // But no need to error out if it doesn't exist. + ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, true) + if errors.Is(err, sql.ErrNoRows) || + errors.Is(err, session.ErrUnknownGroup) { + + return nil + } else if err != nil { + return err + } + + switch { + case sessionID.Valid && featureID.Valid: + return s.queries.DeleteFeatureKVStoreRecord( + ctx, sqlc.DeleteFeatureKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + SessionID: sessionID, + RuleID: ruleID, + FeatureID: featureID, + }, + ) + + case sessionID.Valid: + return s.queries.DeleteSessionKVStoreRecord( + ctx, sqlc.DeleteSessionKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + SessionID: sessionID, + RuleID: ruleID, + }, + ) + + case featureID.Valid: + return fmt.Errorf("a global feature kv store is " + + "not currently supported") + default: + return s.queries.DeleteGlobalKVStoreRecord( + ctx, sqlc.DeleteGlobalKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + RuleID: ruleID, + }, + ) + } +} + +// get fetches the value under the given key from the underlying kv store given +// the namespace fields. +func (s *sqlKVStore) get(ctx context.Context, key string) ([]byte, error) { + ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, true) + if err != nil { + return nil, err + } + + switch { + case sessionID.Valid && featureID.Valid: + return s.queries.GetFeatureKVStoreRecord( + ctx, sqlc.GetFeatureKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + SessionID: sessionID, + RuleID: ruleID, + FeatureID: featureID, + }, + ) + + case sessionID.Valid: + return s.queries.GetSessionKVStoreRecord( + ctx, sqlc.GetSessionKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + SessionID: sessionID, + RuleID: ruleID, + }, + ) + + case featureID.Valid: + return nil, fmt.Errorf("a global feature kv store is " + + "not currently supported") + default: + return s.queries.GetGlobalKVStoreRecord( + ctx, sqlc.GetGlobalKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + RuleID: ruleID, + }, + ) + } +} + +// genNamespaceFields generates the various SQL query parameters that are +// required to access the kvstore namespace determined by the sqlKVStore params. +func (s *sqlKVStore) genNamespaceFields(ctx context.Context, + readOnly bool) (int64, sql.NullInt64, sql.NullInt64, error) { + + var ( + sessionID sql.NullInt64 + featureID sql.NullInt64 + ruleID int64 + err error + ) + + // If a group ID is specified, then we first check that this group ID + // is a known session alias. + s.params.groupID.WhenSome(func(id session.ID) { + var groupID int64 + groupID, err = s.queries.GetSessionIDByAlias(ctx, id[:]) + if errors.Is(err, sql.ErrNoRows) { + err = session.ErrUnknownGroup + + return + } else if err != nil { + return + } + + sessionID = sql.NullInt64{ + Int64: groupID, + Valid: true, + } + }) + if err != nil { + return ruleID, sessionID, featureID, err + } + + // We only insert a new rule name into the DB if this is a write call. + if readOnly { + ruleID, err = s.queries.GetRuleID(ctx, s.params.ruleName) + if err != nil { + return 0, sessionID, featureID, + fmt.Errorf("unable to get rule ID: %w", err) + } + } else { + ruleID, err = s.queries.GetOrInsertRuleID( + ctx, s.params.ruleName, + ) + if err != nil { + return 0, sessionID, featureID, + fmt.Errorf("unable to get or insert rule "+ + "ID: %w", err) + } + } + + s.params.featureName.WhenSome(func(feature string) { + // We only insert a new feature name into the DB if this is a + // write call. + var id int64 + if readOnly { + id, err = s.queries.GetFeatureID(ctx, feature) + if err != nil { + return + } + } else { + id, err = s.queries.GetOrInsertFeatureID(ctx, feature) + if err != nil { + return + } + } + + featureID = sql.NullInt64{ + Int64: id, + Valid: true, + } + }) + + return ruleID, sessionID, featureID, err +} diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go new file mode 100644 index 000000000..01f1d5b9c --- /dev/null +++ b/firewalldb/sql_store.go @@ -0,0 +1,91 @@ +package firewalldb + +import ( + "context" + "database/sql" + + "github.com/lightninglabs/lightning-terminal/db" +) + +// SQLQueries is a subset of the sqlc.Queries interface that can be used to +// interact with various firewalldb tables. +type SQLQueries interface { + SQLKVStoreQueries +} + +// BatchedSQLQueries is a version of the SQLQueries that's capable of batched +// database operations. +type BatchedSQLQueries interface { + SQLQueries + + db.BatchedTx[SQLQueries] +} + +// SQLDB represents a storage backend. +type SQLDB struct { + // db is all the higher level queries that the SQLStore has access to + // in order to implement all its CRUD logic. + db BatchedSQLQueries + + // BaseDB represents the underlying database connection. + *db.BaseDB +} + +// A compile-time assertion to ensure that SQLDB implements the RulesDB +// interface. +var _ RulesDB = (*SQLDB)(nil) + +// NewSQLDB creates a new SQLStore instance given an open SQLQueries +// storage backend. +func NewSQLDB(sqlDB *db.BaseDB) *SQLDB { + executor := db.NewTransactionExecutor( + sqlDB, func(tx *sql.Tx) SQLQueries { + return sqlDB.WithTx(tx) + }, + ) + + return &SQLDB{ + db: executor, + BaseDB: sqlDB, + } +} + +// sqlExecutor is a concrete implementation of the DBExecutor interface that +// uses a SQL database as its backing store. +type sqlExecutor[T any] struct { + db BatchedSQLQueries + wrapTx func(queries SQLQueries) T +} + +// Update opens a database read/write transaction and executes the function f +// with the transaction passed as a parameter. After f exits, if f did not +// error, the transaction is committed. Otherwise, if f did error, the +// transaction is rolled back. If the rollback fails, the original error +// returned by f is still returned. If the commit fails, the commit error is +// returned. +// +// NOTE: this is part of the DBExecutor interface. +func (e *sqlExecutor[T]) Update(ctx context.Context, + fn func(ctx context.Context, tx T) error) error { + + var txOpts db.QueriesTxOptions + return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { + return fn(ctx, e.wrapTx(queries)) + }) +} + +// View opens a database read transaction and executes the function f with the +// transaction passed as a parameter. After f exits, the transaction is rolled +// back. If f errors, its error is returned, not a rollback error (if any +// occur). +// +// NOTE: this is part of the DBExecutor interface. +func (e *sqlExecutor[T]) View(ctx context.Context, + fn func(ctx context.Context, tx T) error) error { + + txOpts := db.NewQueryReadTx() + + return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { + return fn(ctx, e.wrapTx(queries)) + }) +} diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 4e2e8a063..91ea130b1 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -1,3 +1,5 @@ +//go:build !test_db_postgres && !test_db_sqlite + package firewalldb import ( diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go new file mode 100644 index 000000000..aeb012351 --- /dev/null +++ b/firewalldb/test_postgres.go @@ -0,0 +1,20 @@ +//go:build test_db_postgres && !test_db_sqlite + +package firewalldb + +import ( + "testing" + + "github.com/lightninglabs/lightning-terminal/db" +) + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T) *SQLDB { + return NewSQLDB(db.NewTestPostgresDB(t).BaseDB) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, _ string) *SQLDB { + return NewSQLDB(db.NewTestPostgresDB(t).BaseDB) +} diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go new file mode 100644 index 000000000..d256480f9 --- /dev/null +++ b/firewalldb/test_sql.go @@ -0,0 +1,19 @@ +//go:build test_db_postgres || test_db_sqlite + +package firewalldb + +import ( + "testing" + + "github.com/lightninglabs/lightning-terminal/session" + "github.com/stretchr/testify/require" +) + +// NewTestDBWithSessions creates a new test SQLDB Store with access to an +// existing sessions DB. +func NewTestDBWithSessions(t *testing.T, sessionStore session.Store) *SQLDB { + sessions, ok := sessionStore.(*session.SQLStore) + require.True(t, ok) + + return NewSQLDB(sessions.BaseDB) +} diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go new file mode 100644 index 000000000..2497584d9 --- /dev/null +++ b/firewalldb/test_sqlite.go @@ -0,0 +1,20 @@ +//go:build test_db_sqlite && !test_db_postgres + +package firewalldb + +import ( + "testing" + + "github.com/lightninglabs/lightning-terminal/db" +) + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T) *SQLDB { + return NewSQLDB(db.NewTestSqliteDB(t).BaseDB) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, dbPath string) *SQLDB { + return NewSQLDB(db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB) +} diff --git a/session/sql_store.go b/session/sql_store.go index 8bd0b51e8..b1d366fe7 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -65,8 +65,8 @@ type SQLStore struct { // in order to implement all its CRUD logic. db BatchedSQLQueries - // DB represents the underlying database connection. - *sql.DB + // BaseDB represents the underlying database connection. + *db.BaseDB clock clock.Clock } @@ -81,9 +81,9 @@ func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { ) return &SQLStore{ - db: executor, - DB: sqlDB.DB, - clock: clock, + db: executor, + BaseDB: sqlDB, + clock: clock, } }