From 7ce36d7e7d58e8ece6021718722705737db1f164 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 11 Mar 2025 14:57:52 -0500 Subject: [PATCH 1/7] multi: thread contexts through privacy map interfaces Update the PrivacyMapDB interface methods to take contexts (both the methods themselves and the call-back params) and then ensure all implementations are updated and all call-sites pass contexts through correctly. --- firewall/privacy_mapper.go | 83 ++++++++++++++++++++---------- firewall/privacy_mapper_test.go | 16 +++--- firewall/rule_enforcer.go | 2 +- firewalldb/privacy_mapper.go | 17 +++--- firewalldb/privacy_mapper_test.go | 19 +++++-- rules/chan_policy_bounds.go | 9 ++-- rules/channel_constraints.go | 9 ++-- rules/channel_restrictions.go | 12 +++-- rules/channel_restrictions_test.go | 5 +- rules/history_limit.go | 9 ++-- rules/interfaces.go | 4 +- rules/onchain_budget.go | 9 ++-- rules/peer_restrictions.go | 12 +++-- rules/peer_restrictions_test.go | 5 +- rules/rate_limit.go | 9 ++-- session_rpcserver.go | 30 +++++++---- 16 files changed, 160 insertions(+), 90 deletions(-) diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index 26e053500..fd077c916 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -325,14 +325,16 @@ func handleGetInfoResponse(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.GetInfoResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.GetInfoResponse) ( + return func(ctx context.Context, r *lnrpc.GetInfoResponse) ( proto.Message, error) { // We hide the pubkey unless it is disabled. pseudoPubKey := r.IdentityPubkey if !flags.Contains(session.ClearPubkeys) { - err := db.Update( - func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, + func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error pseudoPubKey, err = firewalldb.HideString( tx, r.IdentityPubkey, @@ -377,14 +379,16 @@ func handleFwdHistoryResponse(db firewalldb.PrivacyMapDB, randIntn func(int) (int, error)) func(ctx context.Context, r *lnrpc.ForwardingHistoryResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ForwardingHistoryResponse) ( + return func(ctx context.Context, r *lnrpc.ForwardingHistoryResponse) ( proto.Message, error) { fwdEvents := make( []*lnrpc.ForwardingEvent, len(r.ForwardingEvents), ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, fe := range r.ForwardingEvents { var err error @@ -487,7 +491,9 @@ func handleFeeReportResponse(db firewalldb.PrivacyMapDB, chanFees := make([]*lnrpc.ChannelFeeReport, len(r.ChannelFees)) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error for i, c := range r.ChannelFees { @@ -550,7 +556,9 @@ func handleListChannelsRequest(db firewalldb.PrivacyMapDB, return r, nil } - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + peer, err := firewalldb.RevealBytes(tx, r.Peer) if err != nil { return err @@ -572,7 +580,7 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB, randIntn func(int) (int, error)) func(ctx context.Context, r *lnrpc.ListChannelsResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ListChannelsResponse) ( + return func(ctx context.Context, r *lnrpc.ListChannelsResponse) ( proto.Message, error) { hidePubkeys := !flags.Contains(session.ClearPubkeys) @@ -580,7 +588,9 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB, channels := make([]*lnrpc.Channel, len(r.Channels)) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, c := range r.Channels { var err error @@ -745,7 +755,7 @@ func handleUpdatePolicyRequest(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.PolicyUpdateRequest) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.PolicyUpdateRequest) ( + return func(ctx context.Context, r *lnrpc.PolicyUpdateRequest) ( proto.Message, error) { chanPoint := r.GetChanPoint() @@ -764,7 +774,9 @@ func handleUpdatePolicyRequest(db firewalldb.PrivacyMapDB, newTxid := txid.String() newIndex := chanPoint.GetOutputIndex() if !flags.Contains(session.ClearChanIDs) { - err = db.View(func(tx firewalldb.PrivacyMapTx) error { + err = db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error newTxid, newIndex, err = firewalldb.RevealChanPoint( tx, newTxid, newIndex, @@ -793,7 +805,7 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.PolicyUpdateResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.PolicyUpdateResponse) ( + return func(ctx context.Context, r *lnrpc.PolicyUpdateResponse) ( proto.Message, error) { if flags.Contains(session.ClearChanIDs) { @@ -804,7 +816,9 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB, []*lnrpc.FailedUpdate, len(r.FailedUpdates), ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, u := range r.FailedUpdates { failedUpdates[i] = &lnrpc.FailedUpdate{ Reason: u.Reason, @@ -926,7 +940,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, randIntn func(int) (int, error)) func(ctx context.Context, r *lnrpc.ClosedChannelsResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ClosedChannelsResponse) ( + return func(ctx context.Context, r *lnrpc.ClosedChannelsResponse) ( proto.Message, error) { closedChannels := make( @@ -934,7 +948,9 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, len(r.Channels), ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, c := range r.Channels { var err error @@ -1117,7 +1133,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, randIntn func(int) (int, error)) func(ctx context.Context, r *lnrpc.PendingChannelsResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.PendingChannelsResponse) ( + return func(ctx context.Context, r *lnrpc.PendingChannelsResponse) ( proto.Message, error) { pendingOpens := make( @@ -1140,7 +1156,9 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, len(r.WaitingCloseChannels), ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, c := range r.PendingOpenChannels { var err error @@ -1343,12 +1361,14 @@ func handleBatchOpenChannelRequest(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.BatchOpenChannelRequest) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.BatchOpenChannelRequest) ( + return func(ctx context.Context, r *lnrpc.BatchOpenChannelRequest) ( proto.Message, error) { var reqs = make([]*lnrpc.BatchOpenChannel, len(r.Channels)) - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, c := range r.Channels { var err error @@ -1414,12 +1434,14 @@ func handleBatchOpenChannelResponse(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.BatchOpenChannelResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.BatchOpenChannelResponse) ( + return func(ctx context.Context, r *lnrpc.BatchOpenChannelResponse) ( proto.Message, error) { resps := make([]*lnrpc.PendingUpdate, len(r.PendingChannels)) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, p := range r.PendingChannels { var ( txIdBytes = p.Txid @@ -1471,14 +1493,15 @@ func handleChannelOpenRequest(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.OpenChannelRequest) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.OpenChannelRequest) ( + return func(ctx context.Context, r *lnrpc.OpenChannelRequest) ( proto.Message, error) { var nodePubkey []byte - err := db.View(func(tx firewalldb.PrivacyMapTx) error { - var err error + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error // We use the byte slice representation of the // pubkey and fall back to the hex string if present. nodePubkey = r.NodePubkey @@ -1548,7 +1571,7 @@ func handleChannelOpenResponse(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.ChannelPoint) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ChannelPoint) ( + return func(ctx context.Context, r *lnrpc.ChannelPoint) ( proto.Message, error) { var ( @@ -1556,7 +1579,9 @@ func handleChannelOpenResponse(db firewalldb.PrivacyMapDB, index uint32 ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error txid = r.GetFundingTxidStr() @@ -1622,12 +1647,14 @@ func handleConnectPeerRequest(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.ConnectPeerRequest) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ConnectPeerRequest) ( + return func(ctx context.Context, r *lnrpc.ConnectPeerRequest) ( proto.Message, error) { var addr *lnrpc.LightningAddress - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error // Note, this only works if the pubkey alias was diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 1b67068e0..24582f8ce 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -1073,7 +1073,9 @@ func newMockDB(t *testing.T, preloadRealToPseudo map[string]string, db := mockDB{privDB: make(map[string]*mockPrivacyMapDB)} sessDB := db.NewSessionDB(sessID) - _ = sessDB.Update(func(tx firewalldb.PrivacyMapTx) error { + _ = sessDB.Update(context.Background(), func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for r, p := range preloadRealToPseudo { require.NoError(t, tx.NewPair(r, p)) } @@ -1107,16 +1109,16 @@ type mockPrivacyMapDB struct { p2r map[string]string } -func (m *mockPrivacyMapDB) Update( - f func(tx firewalldb.PrivacyMapTx) error) error { +func (m *mockPrivacyMapDB) Update(ctx context.Context, + f func(ctx context.Context, tx firewalldb.PrivacyMapTx) error) error { - return f(m) + return f(ctx, m) } -func (m *mockPrivacyMapDB) View( - f func(tx firewalldb.PrivacyMapTx) error) error { +func (m *mockPrivacyMapDB) View(ctx context.Context, + f func(ctx context.Context, tx firewalldb.PrivacyMapTx) error) error { - return f(m) + return f(ctx, m) } func (m *mockPrivacyMapDB) NewPair(real, pseudo string) error { diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index c99671cdf..7914965ed 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -395,7 +395,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string, privMap := r.newPrivMap(session.GroupID) ruleValues, err = ruleValues.PseudoToReal( - privMap, session.PrivacyFlags, + ctx, privMap, session.PrivacyFlags, ) if err != nil { return nil, fmt.Errorf("could not prepare rule "+ diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index e2f10f281..e4fee472a 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -1,6 +1,7 @@ package firewalldb import ( + "context" "crypto/rand" "encoding/binary" "encoding/hex" @@ -57,13 +58,13 @@ type PrivacyMapDB interface { // error, the transaction is rolled back. If the rollback fails, the // original error returned by f is still returned. If the commit fails, // the commit error is returned. - Update(f func(tx PrivacyMapTx) error) error + Update(context.Context, func(context.Context, PrivacyMapTx) error) error // View opens a database read transaction and executes the function f // with the transaction passed as a parameter. After f exits, the // transaction is rolled back. If f errors, its error is returned, not a // rollback error (if any occur). - View(f func(tx PrivacyMapTx) error) error + View(context.Context, func(context.Context, PrivacyMapTx) error) error } // PrivacyMapTx represents a db that can be used to create, store and fetch @@ -112,7 +113,9 @@ func (p *privacyMapDB) beginTx(writable bool) (*privacyMapTx, error) { // returned. // // NOTE: this is part of the PrivacyMapDB interface. -func (p *privacyMapDB) Update(f func(tx PrivacyMapTx) error) error { +func (p *privacyMapDB) Update(ctx context.Context, f func(ctx context.Context, + tx PrivacyMapTx) error) error { + tx, err := p.beginTx(true) if err != nil { return err @@ -125,7 +128,7 @@ func (p *privacyMapDB) Update(f func(tx PrivacyMapTx) error) error { } }() - err = f(tx) + err = f(ctx, tx) if err != nil { // Want to return the original error, not a rollback error if // any occur. @@ -142,7 +145,9 @@ func (p *privacyMapDB) Update(f func(tx PrivacyMapTx) error) error { // occur). // // NOTE: this is part of the PrivacyMapDB interface. -func (p *privacyMapDB) View(f func(tx PrivacyMapTx) error) error { +func (p *privacyMapDB) View(ctx context.Context, f func(ctx context.Context, + tx PrivacyMapTx) error) error { + tx, err := p.beginTx(false) if err != nil { return err @@ -155,7 +160,7 @@ func (p *privacyMapDB) View(f func(tx PrivacyMapTx) error) error { } }() - err = f(tx) + err = f(ctx, tx) rollbackErr := tx.boltTx.Rollback() if err != nil { return err diff --git a/firewalldb/privacy_mapper_test.go b/firewalldb/privacy_mapper_test.go index 5ba9d50fe..7a48881a3 100644 --- a/firewalldb/privacy_mapper_test.go +++ b/firewalldb/privacy_mapper_test.go @@ -1,6 +1,7 @@ package firewalldb import ( + "context" "fmt" "testing" @@ -9,6 +10,9 @@ import ( // TestPrivacyMapStorage tests the privacy mapper CRUD logic. func TestPrivacyMapStorage(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() db, err := NewDB(tmpDir, "test.db", nil) require.NoError(t, err) @@ -18,7 +22,7 @@ func TestPrivacyMapStorage(t *testing.T) { pdb1 := db.PrivacyDB([4]byte{1, 1, 1, 1}) - _ = pdb1.Update(func(tx PrivacyMapTx) error { + _ = pdb1.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { _, err = tx.RealToPseudo("real") require.ErrorIs(t, err, ErrNoSuchKeyFound) @@ -48,7 +52,7 @@ func TestPrivacyMapStorage(t *testing.T) { pdb2 := db.PrivacyDB([4]byte{2, 2, 2, 2}) - _ = pdb2.Update(func(tx PrivacyMapTx) error { + _ = pdb2.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { _, err = tx.RealToPseudo("real") require.ErrorIs(t, err, ErrNoSuchKeyFound) @@ -78,7 +82,7 @@ func TestPrivacyMapStorage(t *testing.T) { pdb3 := db.PrivacyDB([4]byte{3, 3, 3, 3}) - _ = pdb3.Update(func(tx PrivacyMapTx) error { + _ = pdb3.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { // Check that calling FetchAllPairs returns an empty map if // nothing exists in the DB yet. m, err := tx.FetchAllPairs() @@ -180,6 +184,9 @@ func TestPrivacyMapStorage(t *testing.T) { // provide atomic access to the db. If anything fails in the middle of an // `Update` function, then all the changes prior should be rolled back. func TestPrivacyMapTxs(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() db, err := NewDB(tmpDir, "test.db", nil) require.NoError(t, err) @@ -191,7 +198,9 @@ func TestPrivacyMapTxs(t *testing.T) { // Test that if an action fails midway through the transaction, then // it is rolled back. - err = pdb1.Update(func(tx PrivacyMapTx) error { + err = pdb1.Update(ctx, func(ctx context.Context, + tx PrivacyMapTx) error { + err := tx.NewPair("real", "pseudo") if err != nil { return err @@ -208,7 +217,7 @@ func TestPrivacyMapTxs(t *testing.T) { }) require.Error(t, err) - err = pdb1.View(func(tx PrivacyMapTx) error { + err = pdb1.View(ctx, func(ctx context.Context, tx PrivacyMapTx) error { _, err := tx.RealToPseudo("real") return err }) diff --git a/rules/chan_policy_bounds.go b/rules/chan_policy_bounds.go index 9ba90ded6..55b79598e 100644 --- a/rules/chan_policy_bounds.go +++ b/rules/chan_policy_bounds.go @@ -396,8 +396,8 @@ func (f *ChanPolicyBounds) RuleName() string { // find the real values. This is a no-op for the ChanPolicyBounds rule. // // NOTE: this is part of the Values interface. -func (f *ChanPolicyBounds) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (f *ChanPolicyBounds) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return f, nil } @@ -407,8 +407,9 @@ func (f *ChanPolicyBounds) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the ChanPolicyBounds rule. // // NOTE: this is part of the Values interface. -func (f *ChanPolicyBounds) RealToPseudo(_ firewalldb.PrivacyMapReader, - _ session.PrivacyFlags) (Values, map[string]string, error) { +func (f *ChanPolicyBounds) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, _ session.PrivacyFlags) (Values, + map[string]string, error) { return f, nil, nil } diff --git a/rules/channel_constraints.go b/rules/channel_constraints.go index e50e30df3..8e8524b20 100644 --- a/rules/channel_constraints.go +++ b/rules/channel_constraints.go @@ -333,8 +333,8 @@ func (v *ChannelConstraint) RuleName() string { // find the real values. This is a no-op for the ChannelConstraint rule. // // NOTE: this is part of the Values interface. -func (v *ChannelConstraint) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (v *ChannelConstraint) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return v, nil } @@ -344,8 +344,9 @@ func (v *ChannelConstraint) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the ChannelConstraint rule. // // NOTE: this is part of the Values interface. -func (v *ChannelConstraint) RealToPseudo(_ firewalldb.PrivacyMapReader, - _ session.PrivacyFlags) (Values, map[string]string, error) { +func (v *ChannelConstraint) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, _ session.PrivacyFlags) (Values, + map[string]string, error) { return v, nil, nil } diff --git a/rules/channel_restrictions.go b/rules/channel_restrictions.go index 745ed85be..7481ce3e9 100644 --- a/rules/channel_restrictions.go +++ b/rules/channel_restrictions.go @@ -336,8 +336,9 @@ func (c *ChannelRestrict) ToProto() *litrpc.RuleValue { // It constructs a new ChannelRestrict instance with these real channel IDs. // // NOTE: this is part of the Values interface. -func (c *ChannelRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, - flags session.PrivacyFlags) (Values, error) { +func (c *ChannelRestrict) PseudoToReal(ctx context.Context, + db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) (Values, + error) { restrictList := make([]uint64, len(c.DenyList)) @@ -348,7 +349,9 @@ func (c *ChannelRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, return &ChannelRestrict{DenyList: restrictList}, nil } - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, chanID := range c.DenyList { real, err := firewalldb.RevealUint64(tx, chanID) if err != nil { @@ -372,7 +375,8 @@ func (c *ChannelRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, // not find in the given PrivacyMapReader. // // NOTE: this is part of the Values interface. -func (c *ChannelRestrict) RealToPseudo(db firewalldb.PrivacyMapReader, +func (c *ChannelRestrict) RealToPseudo(_ context.Context, + db firewalldb.PrivacyMapReader, flags session.PrivacyFlags) (Values, map[string]string, error) { pseudoIDs := make([]uint64, len(c.DenyList)) diff --git a/rules/channel_restrictions_test.go b/rules/channel_restrictions_test.go index a12c80916..b2d6200c3 100644 --- a/rules/channel_restrictions_test.go +++ b/rules/channel_restrictions_test.go @@ -167,6 +167,9 @@ func (m *mockLndClient) ListChannels(_ context.Context, _, _ bool) ( // method correctly determines which real strings to generate pseudo pairs for // based on the privacy map db passed to it. func TestChannelRestrictRealToPseudo(t *testing.T) { + t.Parallel() + + ctx := context.Background() chanID1 := firewalldb.Uint64ToStr(1) chanID2 := firewalldb.Uint64ToStr(2) chanID3 := firewalldb.Uint64ToStr(3) @@ -249,7 +252,7 @@ func TestChannelRestrictRealToPseudo(t *testing.T) { // form along with any new privacy map pairs that should // be added to the DB. v, newPairs, err := cr.RealToPseudo( - privMapPairDB, test.privacyFlags, + ctx, privMapPairDB, test.privacyFlags, ) require.NoError(t, err) require.Len(t, newPairs, len(test.expectNewPairs)) diff --git a/rules/history_limit.go b/rules/history_limit.go index dccebef44..be2894f42 100644 --- a/rules/history_limit.go +++ b/rules/history_limit.go @@ -256,8 +256,8 @@ func (h *HistoryLimit) GetStartDate() time.Time { // find the real values. This is a no-op for the HistoryLimit rule. // // NOTE: this is part of the Values interface. -func (h *HistoryLimit) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (h *HistoryLimit) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return h, nil } @@ -267,8 +267,9 @@ func (h *HistoryLimit) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the HistoryLimit rule. // // NOTE: this is part of the Values interface. -func (h *HistoryLimit) RealToPseudo(_ firewalldb.PrivacyMapReader, - _ session.PrivacyFlags) (Values, map[string]string, error) { +func (h *HistoryLimit) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, _ session.PrivacyFlags) (Values, + map[string]string, error) { return h, nil, nil } diff --git a/rules/interfaces.go b/rules/interfaces.go index a1683c4c5..e657a5c03 100644 --- a/rules/interfaces.go +++ b/rules/interfaces.go @@ -64,13 +64,13 @@ type Values interface { // keys, channel IDs, channel points etc. It returns a map of any new // real to pseudo strings that should be persisted that it did not find // in the given PrivacyMapReader. - RealToPseudo(db firewalldb.PrivacyMapReader, + RealToPseudo(ctx context.Context, db firewalldb.PrivacyMapReader, flags session.PrivacyFlags) (Values, map[string]string, error) // PseudoToReal attempts to convert any appropriate pseudo fields in // the rule Values to their corresponding real values. It uses the // passed PrivacyMapDB to find the real values. - PseudoToReal(db firewalldb.PrivacyMapDB, + PseudoToReal(ctx context.Context, db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) (Values, error) } diff --git a/rules/onchain_budget.go b/rules/onchain_budget.go index 248b2f699..783e3a664 100644 --- a/rules/onchain_budget.go +++ b/rules/onchain_budget.go @@ -363,8 +363,8 @@ func (o *OnChainBudget) ToProto() *litrpc.RuleValue { // find the real values. This is a no-op for the OnChainBudget rule. // // NOTE: this is part of the Values interface. -func (o *OnChainBudget) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (o *OnChainBudget) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return o, nil } @@ -374,8 +374,9 @@ func (o *OnChainBudget) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the OnChainBudget rule. // // NOTE: this is part of the Values interface. -func (o *OnChainBudget) RealToPseudo(db firewalldb.PrivacyMapReader, - flags session.PrivacyFlags) (Values, map[string]string, error) { +func (o *OnChainBudget) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, _ session.PrivacyFlags) (Values, + map[string]string, error) { return o, nil, nil } diff --git a/rules/peer_restrictions.go b/rules/peer_restrictions.go index fbaefe94c..cb5e40f10 100644 --- a/rules/peer_restrictions.go +++ b/rules/peer_restrictions.go @@ -381,8 +381,9 @@ func (c *PeerRestrict) ToProto() *litrpc.RuleValue { // It constructs a new PeerRestrict instance with these real peer IDs. // // NOTE: this is part of the Values interface. -func (c *PeerRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, - flags session.PrivacyFlags) (Values, error) { +func (c *PeerRestrict) PseudoToReal(ctx context.Context, + db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) (Values, + error) { restrictList := make([]string, len(c.DenyList)) @@ -393,7 +394,9 @@ func (c *PeerRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, return &PeerRestrict{DenyList: restrictList}, nil } - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(_ context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, peerPubKey := range c.DenyList { real, err := firewalldb.RevealString(tx, peerPubKey) if err != nil { @@ -418,7 +421,8 @@ func (c *PeerRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, // find in the given PrivacyMapReader. // // NOTE: this is part of the Values interface. -func (c *PeerRestrict) RealToPseudo(db firewalldb.PrivacyMapReader, +func (c *PeerRestrict) RealToPseudo(_ context.Context, + db firewalldb.PrivacyMapReader, flags session.PrivacyFlags) (Values, map[string]string, error) { pseudoIDs := make([]string, len(c.DenyList)) diff --git a/rules/peer_restrictions_test.go b/rules/peer_restrictions_test.go index faa3c18d3..abfa30540 100644 --- a/rules/peer_restrictions_test.go +++ b/rules/peer_restrictions_test.go @@ -204,6 +204,9 @@ func TestPeerRestrictCheckRequest(t *testing.T) { // method correctly determines which real strings to generate pseudo pairs for // based on the privacy map db passed to it. func TestPeerRestrictRealToPseudo(t *testing.T) { + t.Parallel() + ctx := context.Background() + tests := []struct { name string privacyFlags session.PrivacyFlags @@ -276,7 +279,7 @@ func TestPeerRestrictRealToPseudo(t *testing.T) { // form along with any new privacy map pairs that should // be added to the DB. v, newPairs, err := pr.RealToPseudo( - privMapPairDB, test.privacyFlags, + ctx, privMapPairDB, test.privacyFlags, ) require.NoError(t, err) require.Len(t, newPairs, len(test.expectNewPairs)) diff --git a/rules/rate_limit.go b/rules/rate_limit.go index 4bff4bbe0..f324721a0 100644 --- a/rules/rate_limit.go +++ b/rules/rate_limit.go @@ -267,8 +267,8 @@ func (r *RateLimit) ToProto() *litrpc.RuleValue { // find the real values. This is a no-op for the RateLimit rule. // // NOTE: this is part of the Values interface. -func (r *RateLimit) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (r *RateLimit) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return r, nil } @@ -278,8 +278,9 @@ func (r *RateLimit) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the RateLimit rule. // // NOTE: this is part of the Values interface. -func (r *RateLimit) RealToPseudo(_ firewalldb.PrivacyMapReader, - flags session.PrivacyFlags) (Values, map[string]string, error) { +func (r *RateLimit) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, flags session.PrivacyFlags) (Values, + map[string]string, error) { return r, nil, nil } diff --git a/session_rpcserver.go b/session_rpcserver.go index 092ca2e7e..43055f657 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -355,7 +355,7 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, return nil, fmt.Errorf("error fetching session: %v", err) } - rpcSession, err := s.marshalRPCSession(sess) + rpcSession, err := s.marshalRPCSession(ctx, sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) } @@ -557,7 +557,7 @@ func (s *sessionRpcServer) ListSessions(ctx context.Context, Sessions: make([]*litrpc.Session, len(sessions)), } for idx, sess := range sessions { - response.Sessions[idx], err = s.marshalRPCSession(sess) + response.Sessions[idx], err = s.marshalRPCSession(ctx, sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) @@ -629,7 +629,9 @@ func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context, var res string privMap := s.cfg.privMap(groupID) - err = privMap.View(func(tx firewalldb.PrivacyMapTx) error { + err = privMap.View(ctx, func(_ context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error if req.RealToPseudo { res, err = tx.RealToPseudo(req.Input) @@ -899,7 +901,9 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, linkedGroupSession = groupSess privDB := s.cfg.privMap(groupID) - err = privDB.View(func(tx firewalldb.PrivacyMapTx) error { + err = privDB.View(ctx, func(_ context.Context, + tx firewalldb.PrivacyMapTx) error { + knownPrivMapPairs, err = tx.FetchAllPairs() return err @@ -1002,7 +1006,8 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, if privacy { var privMapPairs map[string]string v, privMapPairs, err = v.RealToPseudo( - knownPrivMapPairs, privacyFlags, + ctx, knownPrivMapPairs, + privacyFlags, ) if err != nil { return nil, err @@ -1221,7 +1226,9 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, // Register all the privacy map pairs for this session ID. privDB := s.cfg.privMap(sess.GroupID) - err = privDB.Update(func(tx firewalldb.PrivacyMapTx) error { + err = privDB.Update(ctx, func(_ context.Context, + tx firewalldb.PrivacyMapTx) error { + for r, p := range newPrivMapPairs { err := tx.NewPair(r, p) if err != nil { @@ -1272,7 +1279,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, return nil, fmt.Errorf("error fetching session: %v", err) } - rpcSession, err := s.marshalRPCSession(sess) + rpcSession, err := s.marshalRPCSession(ctx, sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) } @@ -1297,7 +1304,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(ctx context.Context, Sessions: make([]*litrpc.Session, len(sessions)), } for idx, sess := range sessions { - response.Sessions[idx], err = s.marshalRPCSession(sess) + response.Sessions[idx], err = s.marshalRPCSession(ctx, sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) @@ -1426,8 +1433,8 @@ func marshalPerms(perms map[string][]bakery.Op) []*litrpc.Permissions { } // marshalRPCSession converts a session into its RPC counterpart. -func (s *sessionRpcServer) marshalRPCSession(sess *session.Session) ( - *litrpc.Session, error) { +func (s *sessionRpcServer) marshalRPCSession(ctx context.Context, + sess *session.Session) (*litrpc.Session, error) { rpcState, err := marshalRPCState(sess.State) if err != nil { @@ -1484,7 +1491,8 @@ func (s *sessionRpcServer) marshalRPCSession(sess *session.Session) ( sess.GroupID, ) val, err = val.PseudoToReal( - db, sess.PrivacyFlags, + ctx, db, + sess.PrivacyFlags, ) if err != nil { return nil, err From 197ee3b5ba721908647ffeee9241aea6288b5701 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Mar 2025 16:22:39 -0500 Subject: [PATCH 2/7] firewalldb: thread context to PrivMap NewPair Update the NewPair method of the PrivacyMapTx interface to take a context. --- firewall/privacy_mapper.go | 60 ++++++++++++++++--------------- firewall/privacy_mapper_test.go | 6 ++-- firewalldb/privacy_mapper.go | 34 +++++++++++------- firewalldb/privacy_mapper_test.go | 18 +++++----- session_rpcserver.go | 4 +-- 5 files changed, 67 insertions(+), 55 deletions(-) diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index fd077c916..91848cb08 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -336,8 +336,8 @@ func handleGetInfoResponse(db firewalldb.PrivacyMapDB, tx firewalldb.PrivacyMapTx) error { var err error - pseudoPubKey, err = firewalldb.HideString( - tx, r.IdentityPubkey, + pseudoPubKey, err = firewalldb.HideString( //nolint:lll + ctx, tx, r.IdentityPubkey, ) return err @@ -397,14 +397,14 @@ func handleFwdHistoryResponse(db firewalldb.PrivacyMapDB, if !flags.Contains(session.ClearChanIDs) { // Deterministically hide channel ids. chanIn, err = firewalldb.HideUint64( - tx, chanIn, + ctx, tx, chanIn, ) if err != nil { return err } chanOut, err = firewalldb.HideUint64( - tx, chanOut, + ctx, tx, chanOut, ) if err != nil { return err @@ -500,7 +500,7 @@ func handleFeeReportResponse(db firewalldb.PrivacyMapDB, chanID := c.ChanId if !flags.Contains(session.ClearChanIDs) { chanID, err = firewalldb.HideUint64( - tx, chanID, + ctx, tx, chanID, ) if err != nil { return err @@ -510,7 +510,7 @@ func handleFeeReportResponse(db firewalldb.PrivacyMapDB, chanPoint := c.ChannelPoint if !flags.Contains(session.ClearChanIDs) { chanPoint, err = firewalldb.HideChanPointStr( - tx, chanPoint, + ctx, tx, chanPoint, ) if err != nil { return err @@ -599,7 +599,7 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB, remotePub := c.RemotePubkey if hidePubkeys { remotePub, err = firewalldb.HideString( - tx, c.RemotePubkey, + ctx, tx, c.RemotePubkey, ) if err != nil { return err @@ -610,14 +610,14 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB, chanID := c.ChanId if hideChanIds { chanPoint, err = firewalldb.HideChanPointStr( - tx, c.ChannelPoint, + ctx, tx, c.ChannelPoint, ) if err != nil { return err } chanID, err = firewalldb.HideUint64( - tx, c.ChanId, + ctx, tx, c.ChanId, ) if err != nil { return err @@ -830,7 +830,7 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB, } txid, index, err := firewalldb.HideChanPoint( - tx, u.Outpoint.TxidStr, + ctx, tx, u.Outpoint.TxidStr, u.Outpoint.OutputIndex, ) if err != nil { @@ -957,7 +957,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, remotePub := c.RemotePubkey if !flags.Contains(session.ClearPubkeys) { remotePub, err = firewalldb.HideString( - tx, remotePub, + ctx, tx, remotePub, ) if err != nil { return err @@ -985,7 +985,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, channelPoint := c.ChannelPoint if !flags.Contains(session.ClearChanIDs) { channelPoint, err = firewalldb.HideChanPointStr( - tx, c.ChannelPoint, + ctx, tx, c.ChannelPoint, ) if err != nil { return err @@ -995,7 +995,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, chanID := c.ChanId if !flags.Contains(session.ClearChanIDs) { chanID, err = firewalldb.HideUint64( - tx, c.ChanId, + ctx, tx, c.ChanId, ) if err != nil { return err @@ -1005,7 +1005,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, closingTxid := c.ClosingTxHash if !flags.Contains(session.ClearClosingTxIds) { closingTxid, err = firewalldb.HideString( - tx, c.ClosingTxHash, + ctx, tx, c.ClosingTxHash, ) if err != nil { return err @@ -1052,7 +1052,8 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, // obfuscatePendingChannel is a helper to obfuscate the fields of a pending // channel. -func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel, +func obfuscatePendingChannel(ctx context.Context, + c *lnrpc.PendingChannelsResponse_PendingChannel, tx firewalldb.PrivacyMapTx, randIntn func(int) (int, error), flags session.PrivacyFlags) ( *lnrpc.PendingChannelsResponse_PendingChannel, error) { @@ -1062,7 +1063,7 @@ func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel, remotePub := c.RemoteNodePub if !flags.Contains(session.ClearPubkeys) { remotePub, err = firewalldb.HideString( - tx, remotePub, + ctx, tx, remotePub, ) if err != nil { return nil, err @@ -1099,7 +1100,7 @@ func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel, chanPoint := c.ChannelPoint if !flags.Contains(session.ClearChanIDs) { chanPoint, err = firewalldb.HideChanPointStr( - tx, c.ChannelPoint, + ctx, tx, c.ChannelPoint, ) if err != nil { return nil, err @@ -1163,7 +1164,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, var err error pendingChannel, err := obfuscatePendingChannel( - c.Channel, tx, randIntn, flags, + ctx, c.Channel, tx, randIntn, flags, ) if err != nil { return err @@ -1187,7 +1188,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, var err error pendingChannel, err := obfuscatePendingChannel( - c.Channel, tx, randIntn, flags, + ctx, c.Channel, tx, randIntn, flags, ) if err != nil { return err @@ -1195,8 +1196,8 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, closingTxid := c.ClosingTxid if !flags.Contains(session.ClearClosingTxIds) { - closingTxid, err = firewalldb.HideString( - tx, c.ClosingTxid, + closingTxid, err = firewalldb.HideString( //nolint:lll + ctx, tx, c.ClosingTxid, ) if err != nil { return err @@ -1216,7 +1217,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, var err error pendingChannel, err := obfuscatePendingChannel( - c.Channel, tx, randIntn, flags, + ctx, c.Channel, tx, randIntn, flags, ) if err != nil { return err @@ -1225,7 +1226,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, closingTxid := c.ClosingTxid if !flags.Contains(session.ClearClosingTxIds) { closingTxid, err = firewalldb.HideString( - tx, c.ClosingTxid, + ctx, tx, c.ClosingTxid, ) if err != nil { return err @@ -1277,7 +1278,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, var err error pendingChannel, err := obfuscatePendingChannel( - c.Channel, tx, randIntn, flags, + ctx, c.Channel, tx, randIntn, flags, ) if err != nil { return err @@ -1297,7 +1298,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, closingTxid := c.ClosingTxid if !flags.Contains(session.ClearClosingTxIds) { closingTxid, err = firewalldb.HideString( - tx, closingTxid, + ctx, tx, closingTxid, ) if err != nil { return err @@ -1314,7 +1315,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, ) { closingTxHex, err = firewalldb.HideString( - tx, closingTxHex, + ctx, tx, closingTxHex, ) if err != nil { return err @@ -1454,8 +1455,9 @@ func handleBatchOpenChannelResponse(db firewalldb.PrivacyMapDB, return err } - txID, outIdx, err := firewalldb.HideChanPoint( - tx, txId.String(), p.OutputIndex, + txID, outIdx, err := firewalldb.HideChanPoint( //nolint:lll + ctx, tx, txId.String(), + p.OutputIndex, ) if err != nil { return err @@ -1600,7 +1602,7 @@ func handleChannelOpenResponse(db firewalldb.PrivacyMapDB, if !flags.Contains(session.ClearChanIDs) { txid, index, err = firewalldb.HideChanPoint( - tx, txid, index, + ctx, tx, txid, index, ) if err != nil { return err diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 24582f8ce..7c84054b6 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -1077,7 +1077,7 @@ func newMockDB(t *testing.T, preloadRealToPseudo map[string]string, tx firewalldb.PrivacyMapTx) error { for r, p := range preloadRealToPseudo { - require.NoError(t, tx.NewPair(r, p)) + require.NoError(t, tx.NewPair(ctx, r, p)) } return nil }) @@ -1121,7 +1121,9 @@ func (m *mockPrivacyMapDB) View(ctx context.Context, return f(ctx, m) } -func (m *mockPrivacyMapDB) NewPair(real, pseudo string) error { +func (m *mockPrivacyMapDB) NewPair(_ context.Context, real, + pseudo string) error { + m.r2p[real] = pseudo m.p2r[pseudo] = real return nil diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index e4fee472a..fb9524b46 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -71,7 +71,7 @@ type PrivacyMapDB interface { // real-pseudo pairs. type PrivacyMapTx interface { // NewPair persists a new real-pseudo pair. - NewPair(real, pseudo string) error + NewPair(ctx context.Context, real, pseudo string) error // PseudoToReal returns the real value associated with the given pseudo // value. If no such pair is found, then ErrNoSuchKeyFound is returned. @@ -181,7 +181,7 @@ type privacyMapTx struct { // NewPair inserts a new real-pseudo pair into the db. // // NOTE: this is part of the PrivacyMapTx interface. -func (p *privacyMapTx) NewPair(real, pseudo string) error { +func (p *privacyMapTx) NewPair(_ context.Context, real, pseudo string) error { privacyBucket, err := getBucket(p.boltTx, privacyBucketKey) if err != nil { return err @@ -314,7 +314,9 @@ func (p *privacyMapTx) FetchAllPairs() (*PrivacyMapPairs, error) { return NewPrivacyMapPairs(pairs), nil } -func HideString(tx PrivacyMapTx, real string) (string, error) { +func HideString(ctx context.Context, tx PrivacyMapTx, real string) (string, + error) { + pseudo, err := tx.RealToPseudo(real) if err != nil && err != ErrNoSuchKeyFound { return "", err @@ -328,7 +330,7 @@ func HideString(tx PrivacyMapTx, real string) (string, error) { return "", err } - if err = tx.NewPair(real, pseudo); err != nil { + if err = tx.NewPair(ctx, real, pseudo); err != nil { return "", err } @@ -360,7 +362,9 @@ func RevealString(tx PrivacyMapTx, pseudo string) (string, error) { return tx.PseudoToReal(pseudo) } -func HideUint64(tx PrivacyMapTx, real uint64) (uint64, error) { +func HideUint64(ctx context.Context, tx PrivacyMapTx, real uint64) (uint64, + error) { + str := Uint64ToStr(real) pseudo, err := tx.RealToPseudo(str) if err != nil && err != ErrNoSuchKeyFound { @@ -371,7 +375,7 @@ func HideUint64(tx PrivacyMapTx, real uint64) (uint64, error) { } pseudoUint64, pseudoUint64Str := NewPseudoUint64() - if err := tx.NewPair(str, pseudoUint64Str); err != nil { + if err := tx.NewPair(ctx, str, pseudoUint64Str); err != nil { return 0, err } @@ -391,8 +395,8 @@ func RevealUint64(tx PrivacyMapTx, pseudo uint64) (uint64, error) { return StrToUint64(real) } -func HideChanPoint(tx PrivacyMapTx, txid string, index uint32) (string, - uint32, error) { +func HideChanPoint(ctx context.Context, tx PrivacyMapTx, txid string, + index uint32) (string, uint32, error) { cp := fmt.Sprintf("%s:%d", txid, index) pseudo, err := tx.RealToPseudo(cp) @@ -408,7 +412,7 @@ func HideChanPoint(tx PrivacyMapTx, txid string, index uint32) (string, return "", 0, err } - if err := tx.NewPair(cp, newCp); err != nil { + if err := tx.NewPair(ctx, cp, newCp); err != nil { return "", 0, err } @@ -444,13 +448,15 @@ func NewPseudoUint32() uint32 { return binary.BigEndian.Uint32(b) } -func HideChanPointStr(tx PrivacyMapTx, cp string) (string, error) { +func HideChanPointStr(ctx context.Context, tx PrivacyMapTx, cp string) (string, + error) { + txid, index, err := DecodeChannelPoint(cp) if err != nil { return "", err } - newTxid, newIndex, err := HideChanPoint(tx, txid, index) + newTxid, newIndex, err := HideChanPoint(ctx, tx, txid, index) if err != nil { return "", err } @@ -458,10 +464,12 @@ func HideChanPointStr(tx PrivacyMapTx, cp string) (string, error) { return fmt.Sprintf("%s:%d", newTxid, newIndex), nil } -func HideBytes(tx PrivacyMapTx, realBytes []byte) ([]byte, error) { +func HideBytes(ctx context.Context, tx PrivacyMapTx, realBytes []byte) ([]byte, + error) { + real := hex.EncodeToString(realBytes) - pseudo, err := HideString(tx, real) + pseudo, err := HideString(ctx, tx, real) if err != nil { return nil, err } diff --git a/firewalldb/privacy_mapper_test.go b/firewalldb/privacy_mapper_test.go index 7a48881a3..03a8584ea 100644 --- a/firewalldb/privacy_mapper_test.go +++ b/firewalldb/privacy_mapper_test.go @@ -29,7 +29,7 @@ func TestPrivacyMapStorage(t *testing.T) { _, err = tx.PseudoToReal("pseudo") require.ErrorIs(t, err, ErrNoSuchKeyFound) - err = tx.NewPair("real", "pseudo") + err = tx.NewPair(ctx, "real", "pseudo") require.NoError(t, err) pseudo, err := tx.RealToPseudo("real") @@ -59,7 +59,7 @@ func TestPrivacyMapStorage(t *testing.T) { _, err = tx.PseudoToReal("pseudo") require.ErrorIs(t, err, ErrNoSuchKeyFound) - err = tx.NewPair("real 2", "pseudo 2") + err = tx.NewPair(ctx, "real 2", "pseudo 2") require.NoError(t, err) pseudo, err := tx.RealToPseudo("real 2") @@ -90,29 +90,29 @@ func TestPrivacyMapStorage(t *testing.T) { require.Empty(t, m.pairs) // Add a new pair. - err = tx.NewPair("real 1", "pseudo 1") + err = tx.NewPair(ctx, "real 1", "pseudo 1") require.NoError(t, err) // Try to add a new pair that has the same real value as the // first pair. This should fail. - err = tx.NewPair("real 1", "pseudo 2") + err = tx.NewPair(ctx, "real 1", "pseudo 2") require.ErrorContains(t, err, "an entry already exists for "+ "real value") // Try to add a new pair that has the same pseudo value as the // first pair. This should fail. - err = tx.NewPair("real 2", "pseudo 1") + err = tx.NewPair(ctx, "real 2", "pseudo 1") require.ErrorContains(t, err, "an entry already exists for "+ "pseudo value") // Add a few more pairs. - err = tx.NewPair("real 2", "pseudo 2") + err = tx.NewPair(ctx, "real 2", "pseudo 2") require.NoError(t, err) - err = tx.NewPair("real 3", "pseudo 3") + err = tx.NewPair(ctx, "real 3", "pseudo 3") require.NoError(t, err) - err = tx.NewPair("real 4", "pseudo 4") + err = tx.NewPair(ctx, "real 4", "pseudo 4") require.NoError(t, err) // Check that FetchAllPairs correctly returns all the pairs. @@ -201,7 +201,7 @@ func TestPrivacyMapTxs(t *testing.T) { err = pdb1.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { - err := tx.NewPair("real", "pseudo") + err := tx.NewPair(ctx, "real", "pseudo") if err != nil { return err } diff --git a/session_rpcserver.go b/session_rpcserver.go index 43055f657..139a14b89 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -1226,11 +1226,11 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, // Register all the privacy map pairs for this session ID. privDB := s.cfg.privMap(sess.GroupID) - err = privDB.Update(ctx, func(_ context.Context, + err = privDB.Update(ctx, func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { for r, p := range newPrivMapPairs { - err := tx.NewPair(r, p) + err := tx.NewPair(ctx, r, p) if err != nil { return err } From 7e8e4a9920c0afa165314744f230d88aa4faf141 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Mar 2025 16:28:39 -0500 Subject: [PATCH 3/7] firewalldb: thread context to PseudoToReal Update the PseudoToReal method of the PrivacyMapTx interface to take a context. --- firewall/privacy_mapper.go | 14 +++++++------- firewall/privacy_mapper_test.go | 4 +++- firewalldb/privacy_mapper.go | 30 +++++++++++++++++++----------- firewalldb/privacy_mapper_test.go | 8 ++++---- rules/channel_restrictions.go | 2 +- rules/peer_restrictions.go | 6 ++++-- session_rpcserver.go | 4 ++-- 7 files changed, 40 insertions(+), 28 deletions(-) diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index 91848cb08..cbd8a8da4 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -559,7 +559,7 @@ func handleListChannelsRequest(db firewalldb.PrivacyMapDB, err := db.View(ctx, func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { - peer, err := firewalldb.RevealBytes(tx, r.Peer) + peer, err := firewalldb.RevealBytes(ctx, tx, r.Peer) if err != nil { return err } @@ -778,8 +778,8 @@ func handleUpdatePolicyRequest(db firewalldb.PrivacyMapDB, tx firewalldb.PrivacyMapTx) error { var err error - newTxid, newIndex, err = firewalldb.RevealChanPoint( - tx, newTxid, newIndex, + newTxid, newIndex, err = firewalldb.RevealChanPoint( //nolint:lll + ctx, tx, newTxid, newIndex, ) return err }) @@ -1380,7 +1380,7 @@ func handleBatchOpenChannelRequest(db firewalldb.PrivacyMapDB, nodePubkey := c.NodePubkey if !flags.Contains(session.ClearPubkeys) { nodePubkey, err = firewalldb.RevealBytes( - tx, c.NodePubkey, + ctx, tx, c.NodePubkey, ) if err != nil { return err @@ -1518,7 +1518,7 @@ func handleChannelOpenRequest(db firewalldb.PrivacyMapDB, if !flags.Contains(session.ClearPubkeys) { nodePubkey, err = firewalldb.RevealBytes( - tx, nodePubkey, + ctx, tx, nodePubkey, ) if err != nil { return err @@ -1665,7 +1665,7 @@ func handleConnectPeerRequest(db firewalldb.PrivacyMapDB, pubkey := r.Addr.Pubkey if !flags.Contains(session.ClearPubkeys) { pubkey, err = firewalldb.RevealString( - tx, r.Addr.Pubkey, + ctx, tx, r.Addr.Pubkey, ) if err != nil { return err @@ -1675,7 +1675,7 @@ func handleConnectPeerRequest(db firewalldb.PrivacyMapDB, host := r.Addr.Host if !flags.Contains(session.ClearNetworkAddresses) { host, err = firewalldb.RevealString( - tx, r.Addr.Host, + ctx, tx, r.Addr.Host, ) if err != nil { return err diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 7c84054b6..49b8837bd 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -1129,7 +1129,9 @@ func (m *mockPrivacyMapDB) NewPair(_ context.Context, real, return nil } -func (m *mockPrivacyMapDB) PseudoToReal(pseudo string) (string, error) { +func (m *mockPrivacyMapDB) PseudoToReal(_ context.Context, pseudo string) ( + string, error) { + r, ok := m.p2r[pseudo] if !ok { return "", firewalldb.ErrNoSuchKeyFound diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index fb9524b46..f4155516c 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -75,7 +75,7 @@ type PrivacyMapTx interface { // PseudoToReal returns the real value associated with the given pseudo // value. If no such pair is found, then ErrNoSuchKeyFound is returned. - PseudoToReal(pseudo string) (string, error) + PseudoToReal(ctx context.Context, pseudo string) (string, error) // RealToPseudo returns the pseudo value associated with the given real // value. If no such pair is found, then ErrNoSuchKeyFound is returned. @@ -228,7 +228,9 @@ func (p *privacyMapTx) NewPair(_ context.Context, real, pseudo string) error { // it does then the real value is returned, else an error is returned. // // NOTE: this is part of the PrivacyMapTx interface. -func (p *privacyMapTx) PseudoToReal(pseudo string) (string, error) { +func (p *privacyMapTx) PseudoToReal(_ context.Context, pseudo string) (string, + error) { + privacyBucket, err := getBucket(p.boltTx, privacyBucketKey) if err != nil { return "", err @@ -354,12 +356,14 @@ func NewPseudoStr(n int) (string, error) { return string(b), nil } -func RevealString(tx PrivacyMapTx, pseudo string) (string, error) { +func RevealString(ctx context.Context, tx PrivacyMapTx, pseudo string) (string, + error) { + if pseudo == "" { return pseudo, nil } - return tx.PseudoToReal(pseudo) + return tx.PseudoToReal(ctx, pseudo) } func HideUint64(ctx context.Context, tx PrivacyMapTx, real uint64) (uint64, @@ -382,12 +386,14 @@ func HideUint64(ctx context.Context, tx PrivacyMapTx, real uint64) (uint64, return pseudoUint64, nil } -func RevealUint64(tx PrivacyMapTx, pseudo uint64) (uint64, error) { +func RevealUint64(ctx context.Context, tx PrivacyMapTx, pseudo uint64) (uint64, + error) { + if pseudo == 0 { return 0, nil } - real, err := tx.PseudoToReal(Uint64ToStr(pseudo)) + real, err := tx.PseudoToReal(ctx, Uint64ToStr(pseudo)) if err != nil { return 0, err } @@ -429,11 +435,11 @@ func NewPseudoChanPoint() (string, error) { return fmt.Sprintf("%s:%d", pseudoTXID, pseudoIndex), nil } -func RevealChanPoint(tx PrivacyMapTx, txid string, index uint32) (string, - uint32, error) { +func RevealChanPoint(ctx context.Context, tx PrivacyMapTx, txid string, + index uint32) (string, uint32, error) { fakePoint := fmt.Sprintf("%s:%d", txid, index) - real, err := tx.PseudoToReal(fakePoint) + real, err := tx.PseudoToReal(ctx, fakePoint) if err != nil { return "", 0, err } @@ -477,13 +483,15 @@ func HideBytes(ctx context.Context, tx PrivacyMapTx, realBytes []byte) ([]byte, return hex.DecodeString(pseudo) } -func RevealBytes(tx PrivacyMapTx, pseudoBytes []byte) ([]byte, error) { +func RevealBytes(ctx context.Context, tx PrivacyMapTx, + pseudoBytes []byte) ([]byte, error) { + if pseudoBytes == nil { return nil, nil } pseudo := hex.EncodeToString(pseudoBytes) - pseudo, err := RevealString(tx, pseudo) + pseudo, err := RevealString(ctx, tx, pseudo) if err != nil { return nil, err } diff --git a/firewalldb/privacy_mapper_test.go b/firewalldb/privacy_mapper_test.go index 03a8584ea..bb4462e91 100644 --- a/firewalldb/privacy_mapper_test.go +++ b/firewalldb/privacy_mapper_test.go @@ -26,7 +26,7 @@ func TestPrivacyMapStorage(t *testing.T) { _, err = tx.RealToPseudo("real") require.ErrorIs(t, err, ErrNoSuchKeyFound) - _, err = tx.PseudoToReal("pseudo") + _, err = tx.PseudoToReal(ctx, "pseudo") require.ErrorIs(t, err, ErrNoSuchKeyFound) err = tx.NewPair(ctx, "real", "pseudo") @@ -36,7 +36,7 @@ func TestPrivacyMapStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, "pseudo", pseudo) - real, err := tx.PseudoToReal("pseudo") + real, err := tx.PseudoToReal(ctx, "pseudo") require.NoError(t, err) require.Equal(t, "real", real) @@ -56,7 +56,7 @@ func TestPrivacyMapStorage(t *testing.T) { _, err = tx.RealToPseudo("real") require.ErrorIs(t, err, ErrNoSuchKeyFound) - _, err = tx.PseudoToReal("pseudo") + _, err = tx.PseudoToReal(ctx, "pseudo") require.ErrorIs(t, err, ErrNoSuchKeyFound) err = tx.NewPair(ctx, "real 2", "pseudo 2") @@ -66,7 +66,7 @@ func TestPrivacyMapStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, "pseudo 2", pseudo) - real, err := tx.PseudoToReal("pseudo 2") + real, err := tx.PseudoToReal(ctx, "pseudo 2") require.NoError(t, err) require.Equal(t, "real 2", real) diff --git a/rules/channel_restrictions.go b/rules/channel_restrictions.go index 7481ce3e9..c6cb134d8 100644 --- a/rules/channel_restrictions.go +++ b/rules/channel_restrictions.go @@ -353,7 +353,7 @@ func (c *ChannelRestrict) PseudoToReal(ctx context.Context, tx firewalldb.PrivacyMapTx) error { for i, chanID := range c.DenyList { - real, err := firewalldb.RevealUint64(tx, chanID) + real, err := firewalldb.RevealUint64(ctx, tx, chanID) if err != nil { return err } diff --git a/rules/peer_restrictions.go b/rules/peer_restrictions.go index cb5e40f10..009ee8ab8 100644 --- a/rules/peer_restrictions.go +++ b/rules/peer_restrictions.go @@ -394,11 +394,13 @@ func (c *PeerRestrict) PseudoToReal(ctx context.Context, return &PeerRestrict{DenyList: restrictList}, nil } - err := db.View(ctx, func(_ context.Context, + err := db.View(ctx, func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { for i, peerPubKey := range c.DenyList { - real, err := firewalldb.RevealString(tx, peerPubKey) + real, err := firewalldb.RevealString( + ctx, tx, peerPubKey, + ) if err != nil { return err } diff --git a/session_rpcserver.go b/session_rpcserver.go index 139a14b89..a70eeb095 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -629,7 +629,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context, var res string privMap := s.cfg.privMap(groupID) - err = privMap.View(ctx, func(_ context.Context, + err = privMap.View(ctx, func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { var err error @@ -638,7 +638,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context, return err } - res, err = tx.PseudoToReal(req.Input) + res, err = tx.PseudoToReal(ctx, req.Input) return err }) if err != nil { From 5b31f16446a97ec93f3c4a4c8b3e1e8f96ea7fdd Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Mar 2025 16:30:35 -0500 Subject: [PATCH 4/7] firewalldb: thread context to RealToPseudo Update the RealToPseudo method of the PrivacyMapTx interface to take a context. --- firewall/privacy_mapper_test.go | 4 +++- firewalldb/privacy_mapper.go | 12 +++++++----- firewalldb/privacy_mapper_test.go | 12 ++++++------ session_rpcserver.go | 2 +- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 49b8837bd..685c7cb3b 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -1140,7 +1140,9 @@ func (m *mockPrivacyMapDB) PseudoToReal(_ context.Context, pseudo string) ( return r, nil } -func (m *mockPrivacyMapDB) RealToPseudo(real string) (string, error) { +func (m *mockPrivacyMapDB) RealToPseudo(_ context.Context, real string) (string, + error) { + p, ok := m.r2p[real] if !ok { return "", firewalldb.ErrNoSuchKeyFound diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index f4155516c..6ae8b9faa 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -79,7 +79,7 @@ type PrivacyMapTx interface { // RealToPseudo returns the pseudo value associated with the given real // value. If no such pair is found, then ErrNoSuchKeyFound is returned. - RealToPseudo(real string) (string, error) + RealToPseudo(ctx context.Context, real string) (string, error) // FetchAllPairs loads and returns the real-to-pseudo pairs in the form // of a PrivacyMapPairs struct. @@ -258,7 +258,9 @@ func (p *privacyMapTx) PseudoToReal(_ context.Context, pseudo string) (string, // it does then the pseudo value is returned, else an error is returned. // // NOTE: this is part of the PrivacyMapTx interface. -func (p *privacyMapTx) RealToPseudo(real string) (string, error) { +func (p *privacyMapTx) RealToPseudo(_ context.Context, real string) (string, + error) { + privacyBucket, err := getBucket(p.boltTx, privacyBucketKey) if err != nil { return "", err @@ -319,7 +321,7 @@ func (p *privacyMapTx) FetchAllPairs() (*PrivacyMapPairs, error) { func HideString(ctx context.Context, tx PrivacyMapTx, real string) (string, error) { - pseudo, err := tx.RealToPseudo(real) + pseudo, err := tx.RealToPseudo(ctx, real) if err != nil && err != ErrNoSuchKeyFound { return "", err } @@ -370,7 +372,7 @@ func HideUint64(ctx context.Context, tx PrivacyMapTx, real uint64) (uint64, error) { str := Uint64ToStr(real) - pseudo, err := tx.RealToPseudo(str) + pseudo, err := tx.RealToPseudo(ctx, str) if err != nil && err != ErrNoSuchKeyFound { return 0, err } @@ -405,7 +407,7 @@ func HideChanPoint(ctx context.Context, tx PrivacyMapTx, txid string, index uint32) (string, uint32, error) { cp := fmt.Sprintf("%s:%d", txid, index) - pseudo, err := tx.RealToPseudo(cp) + pseudo, err := tx.RealToPseudo(ctx, cp) if err != nil && err != ErrNoSuchKeyFound { return "", 0, err } diff --git a/firewalldb/privacy_mapper_test.go b/firewalldb/privacy_mapper_test.go index bb4462e91..95c08aeb0 100644 --- a/firewalldb/privacy_mapper_test.go +++ b/firewalldb/privacy_mapper_test.go @@ -23,7 +23,7 @@ func TestPrivacyMapStorage(t *testing.T) { pdb1 := db.PrivacyDB([4]byte{1, 1, 1, 1}) _ = pdb1.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { - _, err = tx.RealToPseudo("real") + _, err = tx.RealToPseudo(ctx, "real") require.ErrorIs(t, err, ErrNoSuchKeyFound) _, err = tx.PseudoToReal(ctx, "pseudo") @@ -32,7 +32,7 @@ func TestPrivacyMapStorage(t *testing.T) { err = tx.NewPair(ctx, "real", "pseudo") require.NoError(t, err) - pseudo, err := tx.RealToPseudo("real") + pseudo, err := tx.RealToPseudo(ctx, "real") require.NoError(t, err) require.Equal(t, "pseudo", pseudo) @@ -53,7 +53,7 @@ func TestPrivacyMapStorage(t *testing.T) { pdb2 := db.PrivacyDB([4]byte{2, 2, 2, 2}) _ = pdb2.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { - _, err = tx.RealToPseudo("real") + _, err = tx.RealToPseudo(ctx, "real") require.ErrorIs(t, err, ErrNoSuchKeyFound) _, err = tx.PseudoToReal(ctx, "pseudo") @@ -62,7 +62,7 @@ func TestPrivacyMapStorage(t *testing.T) { err = tx.NewPair(ctx, "real 2", "pseudo 2") require.NoError(t, err) - pseudo, err := tx.RealToPseudo("real 2") + pseudo, err := tx.RealToPseudo(ctx, "real 2") require.NoError(t, err) require.Equal(t, "pseudo 2", pseudo) @@ -206,7 +206,7 @@ func TestPrivacyMapTxs(t *testing.T) { return err } - p, err := tx.RealToPseudo("real") + p, err := tx.RealToPseudo(ctx, "real") if err != nil { return err } @@ -218,7 +218,7 @@ func TestPrivacyMapTxs(t *testing.T) { require.Error(t, err) err = pdb1.View(ctx, func(ctx context.Context, tx PrivacyMapTx) error { - _, err := tx.RealToPseudo("real") + _, err := tx.RealToPseudo(ctx, "real") return err }) require.ErrorIs(t, err, ErrNoSuchKeyFound) diff --git a/session_rpcserver.go b/session_rpcserver.go index a70eeb095..a3a3b339c 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -634,7 +634,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context, var err error if req.RealToPseudo { - res, err = tx.RealToPseudo(req.Input) + res, err = tx.RealToPseudo(ctx, req.Input) return err } From ef936114892f2cd9daeeedf03af070cb757a1c32 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 13 Mar 2025 16:32:23 -0500 Subject: [PATCH 5/7] firewalldb: thread contexts to FetchAllPairs Update the FetchAllPairs method of the PrivacyMapTx interface to take a context. --- firewall/privacy_mapper_test.go | 4 ++-- firewalldb/privacy_mapper.go | 6 ++++-- firewalldb/privacy_mapper_test.go | 8 ++++---- session_rpcserver.go | 4 ++-- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 685c7cb3b..1998d1280 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -1151,8 +1151,8 @@ func (m *mockPrivacyMapDB) RealToPseudo(_ context.Context, real string) (string, return p, nil } -func (m *mockPrivacyMapDB) FetchAllPairs() (*firewalldb.PrivacyMapPairs, - error) { +func (m *mockPrivacyMapDB) FetchAllPairs(_ context.Context) ( + *firewalldb.PrivacyMapPairs, error) { return firewalldb.NewPrivacyMapPairs(m.r2p), nil } diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index 6ae8b9faa..eadb49339 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -83,7 +83,7 @@ type PrivacyMapTx interface { // FetchAllPairs loads and returns the real-to-pseudo pairs in the form // of a PrivacyMapPairs struct. - FetchAllPairs() (*PrivacyMapPairs, error) + FetchAllPairs(ctx context.Context) (*PrivacyMapPairs, error) } // privacyMapDB is an implementation of PrivacyMapDB. @@ -287,7 +287,9 @@ func (p *privacyMapTx) RealToPseudo(_ context.Context, real string) (string, // FetchAllPairs loads and returns the real-to-pseudo pairs. // // NOTE: this is part of the PrivacyMapTx interface. -func (p *privacyMapTx) FetchAllPairs() (*PrivacyMapPairs, error) { +func (p *privacyMapTx) FetchAllPairs(_ context.Context) (*PrivacyMapPairs, + error) { + privacyBucket, err := getBucket(p.boltTx, privacyBucketKey) if err != nil { return nil, err diff --git a/firewalldb/privacy_mapper_test.go b/firewalldb/privacy_mapper_test.go index 95c08aeb0..7be4d3b64 100644 --- a/firewalldb/privacy_mapper_test.go +++ b/firewalldb/privacy_mapper_test.go @@ -40,7 +40,7 @@ func TestPrivacyMapStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, "real", real) - pairs, err := tx.FetchAllPairs() + pairs, err := tx.FetchAllPairs(ctx) require.NoError(t, err) require.EqualValues(t, pairs.pairs, map[string]string{ @@ -70,7 +70,7 @@ func TestPrivacyMapStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, "real 2", real) - pairs, err := tx.FetchAllPairs() + pairs, err := tx.FetchAllPairs(ctx) require.NoError(t, err) require.EqualValues(t, pairs.pairs, map[string]string{ @@ -85,7 +85,7 @@ func TestPrivacyMapStorage(t *testing.T) { _ = pdb3.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { // Check that calling FetchAllPairs returns an empty map if // nothing exists in the DB yet. - m, err := tx.FetchAllPairs() + m, err := tx.FetchAllPairs(ctx) require.NoError(t, err) require.Empty(t, m.pairs) @@ -116,7 +116,7 @@ func TestPrivacyMapStorage(t *testing.T) { require.NoError(t, err) // Check that FetchAllPairs correctly returns all the pairs. - pairs, err := tx.FetchAllPairs() + pairs, err := tx.FetchAllPairs(ctx) require.NoError(t, err) require.EqualValues(t, pairs.pairs, map[string]string{ diff --git a/session_rpcserver.go b/session_rpcserver.go index a3a3b339c..652196f59 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -901,10 +901,10 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, linkedGroupSession = groupSess privDB := s.cfg.privMap(groupID) - err = privDB.View(ctx, func(_ context.Context, + err = privDB.View(ctx, func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { - knownPrivMapPairs, err = tx.FetchAllPairs() + knownPrivMapPairs, err = tx.FetchAllPairs(ctx) return err }) From 62723bd2f3ecc1ea55d5100591a98fcbfcbad61b Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 11 Mar 2025 15:01:26 -0500 Subject: [PATCH 6/7] firewalldb: use DBExecutor for PrivacyMapDB In this commit, we delete the PrivacyMapDB interface definition and instead use the generic DBExecutor interface parameterised by a PrivacyMapTx to define the PrivacyMapDB interface. --- firewalldb/privacy_mapper.go | 88 +++++++----------------------------- 1 file changed, 17 insertions(+), 71 deletions(-) diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index eadb49339..2db18594c 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -43,7 +43,7 @@ type NewPrivacyMapDB func(groupID session.ID) PrivacyMapDB // group ID key. func (db *DB) PrivacyDB(groupID session.ID) PrivacyMapDB { return &privacyMapDB{ - DB: db, + db: db, groupID: groupID, } } @@ -51,21 +51,7 @@ func (db *DB) PrivacyDB(groupID session.ID) PrivacyMapDB { // PrivacyMapDB provides an Update and View method that will allow the caller // to perform atomic read and write transactions defined by PrivacyMapTx on the // underlying DB. -type PrivacyMapDB interface { - // Update opens a database read/write transaction and executes the - // function f with the transaction passed as a parameter. After f exits, - // if f did not error, the transaction is committed. Otherwise, if f did - // error, the transaction is rolled back. If the rollback fails, the - // original error returned by f is still returned. If the commit fails, - // the commit error is returned. - Update(context.Context, func(context.Context, PrivacyMapTx) error) error - - // View opens a database read transaction and executes the function f - // with the transaction passed as a parameter. After f exits, the - // transaction is rolled back. If f errors, its error is returned, not a - // rollback error (if any occur). - View(context.Context, func(context.Context, PrivacyMapTx) error) error -} +type PrivacyMapDB = DBExecutor[PrivacyMapTx] // PrivacyMapTx represents a db that can be used to create, store and fetch // real-pseudo pairs. @@ -88,23 +74,10 @@ type PrivacyMapTx interface { // privacyMapDB is an implementation of PrivacyMapDB. type privacyMapDB struct { - *DB + db *DB groupID session.ID } -// beginTx starts db transaction. The transaction will be a read or read-write -// transaction depending on the value of the `writable` parameter. -func (p *privacyMapDB) beginTx(writable bool) (*privacyMapTx, error) { - boltTx, err := p.Begin(writable) - if err != nil { - return nil, err - } - return &privacyMapTx{ - privacyMapDB: p, - boltTx: boltTx, - }, nil -} - // Update opens a database read/write transaction and executes the function f // with the transaction passed as a parameter. After f exits, if f did not // error, the transaction is committed. Otherwise, if f did error, the @@ -113,30 +86,17 @@ func (p *privacyMapDB) beginTx(writable bool) (*privacyMapTx, error) { // returned. // // NOTE: this is part of the PrivacyMapDB interface. -func (p *privacyMapDB) Update(ctx context.Context, f func(ctx context.Context, +func (p *privacyMapDB) Update(ctx context.Context, fn func(ctx context.Context, tx PrivacyMapTx) error) error { - tx, err := p.beginTx(true) - if err != nil { - return err - } - - // Make sure the transaction rolls back in the event of a panic. - defer func() { - if tx != nil { - _ = tx.boltTx.Rollback() + return p.db.Update(func(tx *bbolt.Tx) error { + boltTx := &privacyMapTx{ + privacyMapDB: p, + boltTx: tx, } - }() - - err = f(ctx, tx) - if err != nil { - // Want to return the original error, not a rollback error if - // any occur. - _ = tx.boltTx.Rollback() - return err - } - return tx.boltTx.Commit() + return fn(ctx, boltTx) + }) } // View opens a database read transaction and executes the function f with the @@ -145,31 +105,17 @@ func (p *privacyMapDB) Update(ctx context.Context, f func(ctx context.Context, // occur). // // NOTE: this is part of the PrivacyMapDB interface. -func (p *privacyMapDB) View(ctx context.Context, f func(ctx context.Context, +func (p *privacyMapDB) View(ctx context.Context, fn func(ctx context.Context, tx PrivacyMapTx) error) error { - tx, err := p.beginTx(false) - if err != nil { - return err - } - - // Make sure the transaction rolls back in the event of a panic. - defer func() { - if tx != nil { - _ = tx.boltTx.Rollback() + return p.db.View(func(tx *bbolt.Tx) error { + boltTx := &privacyMapTx{ + privacyMapDB: p, + boltTx: tx, } - }() - err = f(ctx, tx) - rollbackErr := tx.boltTx.Rollback() - if err != nil { - return err - } - - if rollbackErr != nil { - return rollbackErr - } - return nil + return fn(ctx, boltTx) + }) } // privacyMapTx is an implementation of PrivacyMapTx. From a395cb9b3fef3b20b2c2f00aef94e155a44599b2 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 11 Mar 2025 15:37:42 -0500 Subject: [PATCH 7/7] firewaldb: single implementation of BBolt DBExecutor The `kvStores` and `privacyMapDB` types have very similar looking `Update` and `View` methods. Instead of the duplication, here we let things be more generic by defining a generic `kvdbExecutor` which has Update and View methods defined. --- firewalldb/kvdb_store.go | 44 +++++++++++++++++++++++++++ firewalldb/kvstores.go | 58 ++++++++---------------------------- firewalldb/privacy_mapper.go | 54 +++++++-------------------------- 3 files changed, 66 insertions(+), 90 deletions(-) create mode 100644 firewalldb/kvdb_store.go diff --git a/firewalldb/kvdb_store.go b/firewalldb/kvdb_store.go new file mode 100644 index 000000000..d4ce79f20 --- /dev/null +++ b/firewalldb/kvdb_store.go @@ -0,0 +1,44 @@ +package firewalldb + +import ( + "context" + + "go.etcd.io/bbolt" +) + +// kvdbExecutor is a concrete implementation of the DBExecutor interface that +// uses a bbolt database as its backing store. +type kvdbExecutor[T any] struct { + db *bbolt.DB + wrapTx func(tx *bbolt.Tx) T +} + +// Update opens a database read/write transaction and executes the function f +// with the transaction passed as a parameter. After f exits, if f did not +// error, the transaction is committed. Otherwise, if f did error, the +// transaction is rolled back. If the rollback fails, the original error +// returned by f is still returned. If the commit fails, the commit error is +// returned. +// +// NOTE: this is part of the DBExecutor interface. +func (e *kvdbExecutor[T]) Update(ctx context.Context, + fn func(ctx context.Context, tx T) error) error { + + return e.db.Update(func(tx *bbolt.Tx) error { + return fn(ctx, e.wrapTx(tx)) + }) +} + +// View opens a database read transaction and executes the function f with the +// transaction passed as a parameter. After f exits, the transaction is rolled +// back. If f errors, its error is returned, not a rollback error (if any +// occur). +// +// NOTE: this is part of the DBExecutor interface. +func (e *kvdbExecutor[T]) View(ctx context.Context, + fn func(ctx context.Context, tx T) error) error { + + return e.db.View(func(tx *bbolt.Tx) error { + return fn(ctx, e.wrapTx(tx)) + }) +} diff --git a/firewalldb/kvstores.go b/firewalldb/kvstores.go index 1dffd54cf..9dad0a0cc 100644 --- a/firewalldb/kvstores.go +++ b/firewalldb/kvstores.go @@ -107,62 +107,28 @@ type RulesDB interface { func (db *DB) GetKVStores(rule string, groupID session.ID, feature string) KVStores { - return &kvStores{ - db: db.DB, - ruleName: rule, - groupID: groupID, - featureName: feature, + return &kvdbExecutor[KVStoreTx]{ + db: db.DB, + wrapTx: func(tx *bbolt.Tx) KVStoreTx { + return &kvStoreTx{ + boltTx: tx, + kvStores: &kvStores{ + ruleName: rule, + groupID: groupID, + featureName: feature, + }, + } + }, } } // kvStores implements the rules.KVStores interface. type kvStores struct { - db *bbolt.DB ruleName string groupID session.ID featureName string } -// Update opens a database read/write transaction and executes the function f -// with the transaction passed as a parameter. After f exits, if f did not -// error, the transaction is committed. Otherwise, if f did error, the -// transaction is rolled back. If the rollback fails, the original error -// returned by f is still returned. If the commit fails, the commit error is -// returned. -// -// NOTE: this is part of the KVStores interface. -func (s *kvStores) Update(ctx context.Context, fn func(ctx context.Context, - tx KVStoreTx) error) error { - - return s.db.Update(func(tx *bbolt.Tx) error { - boltTx := &kvStoreTx{ - boltTx: tx, - kvStores: s, - } - - return fn(ctx, boltTx) - }) -} - -// View opens a database read transaction and executes the function f with the -// transaction passed as a parameter. After f exits, the transaction is rolled -// back. If f errors, its error is returned, not a rollback error (if any -// occur). -// -// NOTE: this is part of the KVStores interface. -func (s *kvStores) View(ctx context.Context, fn func(ctx context.Context, - tx KVStoreTx) error) error { - - return s.db.View(func(tx *bbolt.Tx) error { - boltTx := &kvStoreTx{ - boltTx: tx, - kvStores: s, - } - - return fn(ctx, boltTx) - }) -} - // getBucketFunc defines the signature of the bucket creation/fetching function // required by kvStoreTx. If create is true, then all the bucket (and all // buckets leading up to the bucket) should be created if they do not already diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index 2db18594c..ab8e60e40 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -42,9 +42,16 @@ type NewPrivacyMapDB func(groupID session.ID) PrivacyMapDB // PrivacyDB constructs a PrivacyMapDB that will be indexed under the given // group ID key. func (db *DB) PrivacyDB(groupID session.ID) PrivacyMapDB { - return &privacyMapDB{ - db: db, - groupID: groupID, + return &kvdbExecutor[PrivacyMapTx]{ + db: db.DB, + wrapTx: func(tx *bbolt.Tx) PrivacyMapTx { + return &privacyMapTx{ + boltTx: tx, + privacyMapDB: &privacyMapDB{ + groupID: groupID, + }, + } + }, } } @@ -74,50 +81,9 @@ type PrivacyMapTx interface { // privacyMapDB is an implementation of PrivacyMapDB. type privacyMapDB struct { - db *DB groupID session.ID } -// Update opens a database read/write transaction and executes the function f -// with the transaction passed as a parameter. After f exits, if f did not -// error, the transaction is committed. Otherwise, if f did error, the -// transaction is rolled back. If the rollback fails, the original error -// returned by f is still returned. If the commit fails, the commit error is -// returned. -// -// NOTE: this is part of the PrivacyMapDB interface. -func (p *privacyMapDB) Update(ctx context.Context, fn func(ctx context.Context, - tx PrivacyMapTx) error) error { - - return p.db.Update(func(tx *bbolt.Tx) error { - boltTx := &privacyMapTx{ - privacyMapDB: p, - boltTx: tx, - } - - return fn(ctx, boltTx) - }) -} - -// View opens a database read transaction and executes the function f with the -// transaction passed as a parameter. After f exits, the transaction is rolled -// back. If f errors, its error is returned, not a rollback error (if any -// occur). -// -// NOTE: this is part of the PrivacyMapDB interface. -func (p *privacyMapDB) View(ctx context.Context, fn func(ctx context.Context, - tx PrivacyMapTx) error) error { - - return p.db.View(func(tx *bbolt.Tx) error { - boltTx := &privacyMapTx{ - privacyMapDB: p, - boltTx: tx, - } - - return fn(ctx, boltTx) - }) -} - // privacyMapTx is an implementation of PrivacyMapTx. type privacyMapTx struct { *privacyMapDB