diff --git a/go.mod b/go.mod index 393dcc3a57b..5f802984515 100644 --- a/go.mod +++ b/go.mod @@ -222,4 +222,7 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d // well). go 1.24.6 +// Temporary replace until dependent PR is merged in lightning-onion. +replace github.com/lightningnetwork/lightning-onion => github.com/joostjager/lightning-onion v0.0.0-20250630141312-2898b9c46c4e + retract v0.0.2 diff --git a/go.sum b/go.sum index 4f9d613bbc0..e89ee8d3c83 100644 --- a/go.sum +++ b/go.sum @@ -306,6 +306,8 @@ github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGAR github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= +github.com/joostjager/lightning-onion v0.0.0-20250630141312-2898b9c46c4e h1:kwxUmYn+qyX4olGy7TxgUeXpmnaMjf4+/bn9Ke9w0GU= +github.com/joostjager/lightning-onion v0.0.0-20250630141312-2898b9c46c4e/go.mod h1:EDqJ3MuZIbMq0QI1czTIKDJ/GS8S14RXPwapHw8cw6w= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/jrick/logrotate v1.1.2 h1:6ePk462NCX7TfKtNp5JJ7MbA2YIslkpfgP03TlTYMN0= @@ -368,8 +370,6 @@ github.com/lightninglabs/neutrino/cache v1.1.2 h1:C9DY/DAPaPxbFC+xNNEI/z1SJY9GS3 github.com/lightninglabs/neutrino/cache v1.1.2/go.mod h1:XJNcgdOw1LQnanGjw8Vj44CvguYA25IMKjWFZczwZuo= github.com/lightninglabs/protobuf-go-hex-display v1.33.0-hex-display h1:Y2WiPkBS/00EiEg0qp0FhehxnQfk3vv8U6Xt3nN+rTY= github.com/lightninglabs/protobuf-go-hex-display v1.33.0-hex-display/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -github.com/lightningnetwork/lightning-onion v1.2.1-0.20240815225420-8b40adf04ab9 h1:6D3LrdagJweLLdFm1JNodZsBk6iU4TTsBBFLQ4yiXfI= -github.com/lightningnetwork/lightning-onion v1.2.1-0.20240815225420-8b40adf04ab9/go.mod h1:EDqJ3MuZIbMq0QI1czTIKDJ/GS8S14RXPwapHw8cw6w= github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf0d0Uy4qBjI= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= diff --git a/htlcswitch/circuit.go b/htlcswitch/circuit.go index eab1cdb2002..e18e1f0b21c 100644 --- a/htlcswitch/circuit.go +++ b/htlcswitch/circuit.go @@ -199,17 +199,18 @@ func (c *PaymentCircuit) Decode(r io.Reader) error { case hop.EncrypterTypeSphinx: // Sphinx encrypter was used as this is a forwarded HTLC. - c.ErrorEncrypter = hop.NewSphinxErrorEncrypter() + c.ErrorEncrypter = hop.NewSphinxErrorEncrypterUninitialized() case hop.EncrypterTypeMock: // Test encrypter. c.ErrorEncrypter = NewMockObfuscator() case hop.EncrypterTypeIntroduction: - c.ErrorEncrypter = hop.NewIntroductionErrorEncrypter() + c.ErrorEncrypter = + hop.NewIntroductionErrorEncrypterUninitialized() case hop.EncrypterTypeRelaying: - c.ErrorEncrypter = hop.NewRelayingErrorEncrypter() + c.ErrorEncrypter = hop.NewRelayingErrorEncrypterUninitialized() default: return UnknownEncrypterType(encrypterType) diff --git a/htlcswitch/circuit_map.go b/htlcswitch/circuit_map.go index 15d4b5ffca4..3be27ee9ceb 100644 --- a/htlcswitch/circuit_map.go +++ b/htlcswitch/circuit_map.go @@ -210,9 +210,9 @@ type CircuitMapConfig struct { FetchClosedChannels func( pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) - // ExtractErrorEncrypter derives the shared secret used to encrypt - // errors from the obfuscator's ephemeral public key. - ExtractErrorEncrypter hop.ErrorEncrypterExtracter + // ExtractSharedSecret derives the shared secret used to encrypt errors + // from the obfuscator's ephemeral public key. + ExtractSharedSecret hop.SharedSecretGenerator // CheckResolutionMsg checks whether a given resolution message exists // for the passed CircuitKey. @@ -632,9 +632,7 @@ func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) { // Otherwise, we need to reextract the encrypter, so that the shared // secret is rederived from what was decoded. - err := circuit.ErrorEncrypter.Reextract( - cm.cfg.ExtractErrorEncrypter, - ) + err := circuit.ErrorEncrypter.Reextract(cm.cfg.ExtractSharedSecret) if err != nil { return nil, err } diff --git a/htlcswitch/circuit_test.go b/htlcswitch/circuit_test.go index ddad11aca91..975a1c89899 100644 --- a/htlcswitch/circuit_test.go +++ b/htlcswitch/circuit_test.go @@ -65,16 +65,17 @@ func initTestExtracter() { onionProcessor := newOnionProcessor(nil) defer onionProcessor.Stop() - obfuscator, _ := onionProcessor.ExtractErrorEncrypter( + sharedSecret, failCode := onionProcessor.ExtractSharedSecret( testEphemeralKey, ) - sphinxExtracter, ok := obfuscator.(*hop.SphinxErrorEncrypter) - if !ok { - panic("did not extract sphinx error encrypter") + if failCode != lnwire.CodeNone { + panic("did not extract shared secret") } - testExtracter = sphinxExtracter + testExtracter = hop.NewSphinxErrorEncrypter( + testEphemeralKey, sharedSecret, + ) // We also set this error extracter on startup, otherwise it will be nil // at compile-time. @@ -106,10 +107,10 @@ func newCircuitMap(t *testing.T, resMsg bool) (*htlcswitch.CircuitMapConfig, db := makeCircuitDB(t, "") circuitMapCfg := &htlcswitch.CircuitMapConfig{ - DB: db, - FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, - FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, - ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter, + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, + ExtractSharedSecret: onionProcessor.ExtractSharedSecret, } if resMsg { @@ -216,7 +217,7 @@ func TestHalfCircuitSerialization(t *testing.T) { // encrypters, this will be a NOP. if circuit2.ErrorEncrypter != nil { err := circuit2.ErrorEncrypter.Reextract( - onionProcessor.ExtractErrorEncrypter, + onionProcessor.ExtractSharedSecret, ) if err != nil { t.Fatalf("unable to reextract sphinx error "+ @@ -643,11 +644,11 @@ func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) ( // Reinitialize circuit map with same db path. db := makeCircuitDB(t, dbPath) cfg2 := &htlcswitch.CircuitMapConfig{ - DB: db, - FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, - FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, - ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, - CheckResolutionMsg: cfg.CheckResolutionMsg, + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, + ExtractSharedSecret: cfg.ExtractSharedSecret, + CheckResolutionMsg: cfg.CheckResolutionMsg, } cm2, err := htlcswitch.NewCircuitMap(cfg2) require.NoError(t, err, "unable to recreate persistent circuit map") diff --git a/htlcswitch/failure.go b/htlcswitch/failure.go index 373263381fd..a3251f36565 100644 --- a/htlcswitch/failure.go +++ b/htlcswitch/failure.go @@ -3,6 +3,7 @@ package htlcswitch import ( "bytes" "fmt" + "strings" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -92,6 +93,13 @@ type ForwardingError struct { // be nil in the case where we fail to decode failure message sent by // a peer. msg lnwire.FailureMessage + + // HoldTimes is an array of hold times (in ms) as reported from the + // nodes of the route. It is the time for which a node held the HTLC for + // from that nodes local perspective. The first element corresponds to + // the first node after the sender node, with greater indices indicating + // nodes further down the route. + HoldTimes []uint32 } // WireMessage extracts a valid wire failure message from an internal @@ -116,11 +124,12 @@ func (f *ForwardingError) Error() string { // NewForwardingError creates a new payment error which wraps a wire error // with additional metadata. func NewForwardingError(failure lnwire.FailureMessage, - index int) *ForwardingError { + index int, holdTimes []uint32) *ForwardingError { return &ForwardingError{ FailureSourceIdx: index, msg: failure, + HoldTimes: holdTimes, } } @@ -140,7 +149,7 @@ type ErrorDecrypter interface { // hop, to the source of the error. A fully populated // lnwire.FailureMessage is returned along with the source of the // error. - DecryptError(lnwire.OpaqueReason) (*ForwardingError, error) + DecryptError(lnwire.OpaqueReason, []byte) (*ForwardingError, error) } // UnknownEncrypterType is an error message used to signal that an unexpected @@ -160,13 +169,23 @@ type OnionErrorDecrypter interface { // node where error have occurred. As a result, in order to decrypt the // error we need get all shared secret and apply decryption in the // reverse order. - DecryptError(encryptedData []byte) (*sphinx.DecryptedError, error) + DecryptError(encryptedData, attrData []byte) (*sphinx.DecryptedError, + error) } // SphinxErrorDecrypter wraps the sphinx data SphinxErrorDecrypter and maps the // returned errors to concrete lnwire.FailureMessage instances. type SphinxErrorDecrypter struct { - OnionErrorDecrypter + decrypter *sphinx.OnionErrorDecrypter +} + +// NewSphinxErrorDecrypter instantiates a new error decrypter. +func NewSphinxErrorDecrypter(circuit *sphinx.Circuit) *SphinxErrorDecrypter { + return &SphinxErrorDecrypter{ + decrypter: sphinx.NewOnionErrorDecrypter( + circuit, hop.AttrErrorStruct, + ), + } } // DecryptError peels off each layer of onion encryption from the first hop, to @@ -174,23 +193,46 @@ type SphinxErrorDecrypter struct { // along with the source of the error. // // NOTE: Part of the ErrorDecrypter interface. -func (s *SphinxErrorDecrypter) DecryptError(reason lnwire.OpaqueReason) ( - *ForwardingError, error) { - - failure, err := s.OnionErrorDecrypter.DecryptError(reason) +func (s *SphinxErrorDecrypter) DecryptError(reason lnwire.OpaqueReason, + attrData []byte) (*ForwardingError, error) { + + // We do not set the strict attribution flag, as we want to account for + // the grace period during which nodes are still upgrading to support + // this feature. If set prematurely it can lead to early blame of our + // direct peers that may not support this feature yet, blacklisting our + // channels and failing our payments. + attrErr, err := s.decrypter.DecryptError(reason, attrData, false) if err != nil { return nil, err } + var holdTimes []string + for _, payload := range attrErr.HoldTimes { + // Read hold time. + holdTime := payload + + holdTimes = append( + holdTimes, + fmt.Sprintf("%vms", holdTime*100), + ) + } + + // For now just log the hold times, the collector of the payment result + // should handle this in a more sophisticated way. + log.Debugf("Extracted hold times from onion error: %v", + strings.Join(holdTimes, "/")) + // Decode the failure. If an error occurs, we leave the failure message // field nil. - r := bytes.NewReader(failure.Message) + r := bytes.NewReader(attrErr.Message) failureMsg, err := lnwire.DecodeFailure(r, 0) if err != nil { - return NewUnknownForwardingError(failure.SenderIdx), nil + return NewUnknownForwardingError(attrErr.SenderIdx), nil } - return NewForwardingError(failureMsg, failure.SenderIdx), nil + return NewForwardingError( + failureMsg, attrErr.SenderIdx, attrErr.HoldTimes, + ), nil } // A compile time check to ensure ErrorDecrypter implements the Deobfuscator diff --git a/htlcswitch/failure_test.go b/htlcswitch/failure_test.go index 48ebc668210..d887be5cfb8 100644 --- a/htlcswitch/failure_test.go +++ b/htlcswitch/failure_test.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" @@ -52,11 +53,13 @@ func TestLongFailureMessage(t *testing.T) { } errorDecryptor := &SphinxErrorDecrypter{ - OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), + decrypter: sphinx.NewOnionErrorDecrypter( + circuit, hop.AttrErrorStruct, + ), } // Assert that the failure message can still be extracted. - failure, err := errorDecryptor.DecryptError(reason) + failure, err := errorDecryptor.DecryptError(reason, nil) require.NoError(t, err) incorrectDetails, ok := failure.msg.(*lnwire.FailIncorrectDetails) diff --git a/htlcswitch/hop/error_encryptor.go b/htlcswitch/hop/error_encryptor.go index 23272ec00d4..07fd64a554b 100644 --- a/htlcswitch/hop/error_encryptor.go +++ b/htlcswitch/hop/error_encryptor.go @@ -2,12 +2,15 @@ package hop import ( "bytes" + "errors" "fmt" "io" + "time" "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) // EncrypterType establishes an enum used in serialization to indicate how to @@ -37,6 +40,24 @@ const ( // the same functionality as a EncrypterTypeSphinx, but is used to mark // our special-case error handling. EncrypterTypeRelaying = 4 + + // A set of tlv type definitions used to serialize the encrypter to the + // database. + // + // NOTE: A migration should be added whenever this list changes. This + // prevents against the database being rolled back to an older + // format where the surrounding logic might assume a different set of + // fields are known. + creationTimeType tlv.Type = 0 +) + +// AttrErrorStruct defines the message structure for an attributable error. Use +// a maximum route length of 20, a fixed payload length of 4 bytes to +// accommodate the a 32-bit hold time in milliseconds and use 4 byte hmacs. +// Total size including a 256 byte message from the error source works out to +// 1200 bytes. +var ( + AttrErrorStruct = sphinx.NewAttrErrorStructure(20, 4, 4) ) // IsBlinded returns a boolean indicating whether the error encrypter belongs @@ -45,9 +66,9 @@ func (e EncrypterType) IsBlinded() bool { return e == EncrypterTypeIntroduction || e == EncrypterTypeRelaying } -// ErrorEncrypterExtracter defines a function signature that extracts an -// ErrorEncrypter from an sphinx OnionPacket. -type ErrorEncrypterExtracter func(*btcec.PublicKey) (ErrorEncrypter, +// SharedSecretGenerator defines a function signature that extracts a shared +// secret from an sphinx OnionPacket. +type SharedSecretGenerator func(*btcec.PublicKey) (sphinx.Hash256, lnwire.FailCode) // ErrorEncrypter is an interface that is used to encrypt HTLC related errors @@ -58,19 +79,22 @@ type ErrorEncrypter interface { // encrypted opaque failure reason. This method will be used at the // source that the error occurs. It differs from IntermediateEncrypt // slightly, in that it computes a proper MAC over the error. - EncryptFirstHop(lnwire.FailureMessage) (lnwire.OpaqueReason, error) + EncryptFirstHop(lnwire.FailureMessage) (lnwire.OpaqueReason, + []byte, error) // EncryptMalformedError is similar to EncryptFirstHop (it adds the // MAC), but it accepts an opaque failure reason rather than a failure // message. This method is used when we receive an // UpdateFailMalformedHTLC from the remote peer and then need to // convert that into a proper error from only the raw bytes. - EncryptMalformedError(lnwire.OpaqueReason) lnwire.OpaqueReason + EncryptMalformedError(lnwire.OpaqueReason) (lnwire.OpaqueReason, []byte, + error) // IntermediateEncrypt wraps an already encrypted opaque reason error // in an additional layer of onion encryption. This process repeats // until the error arrives at the source of the payment. - IntermediateEncrypt(lnwire.OpaqueReason) lnwire.OpaqueReason + IntermediateEncrypt(lnwire.OpaqueReason, []byte) (lnwire.OpaqueReason, + []byte, error) // Type returns an enum indicating the underlying concrete instance // backing this interface. @@ -84,12 +108,13 @@ type ErrorEncrypter interface { // given io.Reader. Decode(io.Reader) error - // Reextract rederives the encrypter using the extracter, performing an - // ECDH with the sphinx router's key and the ephemeral public key. + // Reextract rederives the encrypter using the shared secret generator, + // performing an ECDH with the sphinx router's key and the ephemeral + // public key. // // NOTE: This should be called shortly after Decode to properly // reinitialize the error encrypter. - Reextract(ErrorEncrypterExtracter) error + Reextract(SharedSecretGenerator) error } // SphinxErrorEncrypter is a concrete implementation of both the ErrorEncrypter @@ -100,20 +125,63 @@ type SphinxErrorEncrypter struct { *sphinx.OnionErrorEncrypter EphemeralKey *btcec.PublicKey + CreatedAt time.Time } -// NewSphinxErrorEncrypter initializes a blank sphinx error encrypter, that -// should be used to deserialize an encoded SphinxErrorEncrypter. Since the -// actual encrypter is not stored in plaintext while at rest, reconstructing the -// error encrypter requires: +// NewSphinxErrorEncrypterUninitialized initializes a blank sphinx error +// encrypter, that should be used to deserialize an encoded +// SphinxErrorEncrypter. Since the actual encrypter is not stored in plaintext +// while at rest, reconstructing the error encrypter requires: // 1. Decode: to deserialize the ephemeral public key. // 2. Reextract: to "unlock" the actual error encrypter using an active // OnionProcessor. -func NewSphinxErrorEncrypter() *SphinxErrorEncrypter { +func NewSphinxErrorEncrypterUninitialized() *SphinxErrorEncrypter { return &SphinxErrorEncrypter{ - OnionErrorEncrypter: nil, - EphemeralKey: &btcec.PublicKey{}, + EphemeralKey: &btcec.PublicKey{}, + } +} + +// NewSphinxErrorEncrypter creates a new instance of a SphinxErrorEncrypter, +// initialized with the provided shared secret. To deserialize an encoded +// SphinxErrorEncrypter, use the NewSphinxErrorEncrypterUninitialized +// constructor. +func NewSphinxErrorEncrypter(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256) *SphinxErrorEncrypter { + + encrypter := &SphinxErrorEncrypter{ + EphemeralKey: ephemeralKey, } + + // Set creation time rounded to nanosecond to avoid differences after + // serialization. + encrypter.CreatedAt = time.Now().Truncate(time.Nanosecond) + + encrypter.initialize(sharedSecret) + + return encrypter +} + +// getHoldTime returns the hold time in decaseconds since the first +// instantiation of this sphinx error encrypter. +func (s *SphinxErrorEncrypter) getHoldTime() uint32 { + return uint32(time.Since(s.CreatedAt).Milliseconds() / 100) +} + +// encrypt is a thin wrapper around the main encryption method, mainly used to +// automatically derive the hold time to encode in the attribution structure. +func (s *SphinxErrorEncrypter) encrypt(initial bool, + data, attrData []byte) (lnwire.OpaqueReason, []byte, error) { + + holdTime := s.getHoldTime() + + return s.EncryptError(initial, data, attrData, holdTime) +} + +// initialize creates the underlying instance of the sphinx error encrypter. +func (s *SphinxErrorEncrypter) initialize(sharedSecret sphinx.Hash256) { + s.OnionErrorEncrypter = sphinx.NewOnionErrorEncrypter( + sharedSecret, AttrErrorStruct, + ) } // EncryptFirstHop transforms a concrete failure message into an encrypted @@ -123,16 +191,14 @@ func NewSphinxErrorEncrypter() *SphinxErrorEncrypter { // // NOTE: Part of the ErrorEncrypter interface. func (s *SphinxErrorEncrypter) EncryptFirstHop( - failure lnwire.FailureMessage) (lnwire.OpaqueReason, error) { + failure lnwire.FailureMessage) (lnwire.OpaqueReason, []byte, error) { var b bytes.Buffer if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { - return nil, err + return nil, nil, err } - // We pass a true as the first parameter to indicate that a MAC should - // be added. - return s.EncryptError(true, b.Bytes()), nil + return s.encrypt(true, b.Bytes(), nil) } // EncryptMalformedError is similar to EncryptFirstHop (it adds the MAC), but @@ -143,9 +209,9 @@ func (s *SphinxErrorEncrypter) EncryptFirstHop( // // NOTE: Part of the ErrorEncrypter interface. func (s *SphinxErrorEncrypter) EncryptMalformedError( - reason lnwire.OpaqueReason) lnwire.OpaqueReason { + reason lnwire.OpaqueReason) (lnwire.OpaqueReason, []byte, error) { - return s.EncryptError(true, reason) + return s.encrypt(true, reason, nil) } // IntermediateEncrypt wraps an already encrypted opaque reason error in an @@ -156,9 +222,25 @@ func (s *SphinxErrorEncrypter) EncryptMalformedError( // // NOTE: Part of the ErrorEncrypter interface. func (s *SphinxErrorEncrypter) IntermediateEncrypt( - reason lnwire.OpaqueReason) lnwire.OpaqueReason { + reason lnwire.OpaqueReason, attrData []byte) (lnwire.OpaqueReason, + []byte, error) { + + encrypted, attrData, err := s.encrypt(false, reason, attrData) + + switch { + // If the structure of the error received from downstream is invalid, + // then generate a new attribution structure so that the sender is able + // to penalize the offending node. + case errors.Is(err, sphinx.ErrInvalidAttrStructure): + // Preserve the error message and initialize fresh attribution + // data. + return s.encrypt(true, reason, nil) + + case err != nil: + return lnwire.OpaqueReason{}, nil, err + } - return s.EncryptError(false, reason) + return encrypted, attrData, nil } // Type returns the identifier for a sphinx error encrypter. @@ -171,7 +253,20 @@ func (s *SphinxErrorEncrypter) Type() EncrypterType { func (s *SphinxErrorEncrypter) Encode(w io.Writer) error { ephemeral := s.EphemeralKey.SerializeCompressed() _, err := w.Write(ephemeral) - return err + if err != nil { + return err + } + + var creationTime = uint64(s.CreatedAt.UnixNano()) + + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(creationTimeType, &creationTime), + ) + if err != nil { + return err + } + + return tlvStream.Encode(w) } // Decode reconstructs the error encrypter's ephemeral public key from the @@ -188,16 +283,37 @@ func (s *SphinxErrorEncrypter) Decode(r io.Reader) error { return err } + // Try decode attributable error structure. + var creationTime uint64 + + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(creationTimeType, &creationTime), + ) + if err != nil { + return err + } + + typeMap, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + // Return early if this encrypter is not for attributable errors. + if len(typeMap) == 0 { + return nil + } + + // Set attributable error creation time. + s.CreatedAt = time.Unix(0, int64(creationTime)) + return nil } // Reextract rederives the error encrypter from the currently held EphemeralKey. // This intended to be used shortly after Decode, to fully initialize a // SphinxErrorEncrypter. -func (s *SphinxErrorEncrypter) Reextract( - extract ErrorEncrypterExtracter) error { - - obfuscator, failcode := extract(s.EphemeralKey) +func (s *SphinxErrorEncrypter) Reextract(extract SharedSecretGenerator) error { + sharedSecret, failcode := extract(s.EphemeralKey) if failcode != lnwire.CodeNone { // This should never happen, since we already validated that // this obfuscator can be extracted when it was received in the @@ -206,13 +322,7 @@ func (s *SphinxErrorEncrypter) Reextract( "obfuscator, got failcode: %d", failcode) } - sphinxEncrypter, ok := obfuscator.(*SphinxErrorEncrypter) - if !ok { - return fmt.Errorf("incorrect onion error extracter") - } - - // Copy the freshly extracted encrypter. - s.OnionErrorEncrypter = sphinxEncrypter.OnionErrorEncrypter + s.initialize(sharedSecret) return nil } @@ -235,9 +345,25 @@ type IntroductionErrorEncrypter struct { } // NewIntroductionErrorEncrypter returns a blank IntroductionErrorEncrypter. -func NewIntroductionErrorEncrypter() *IntroductionErrorEncrypter { +func NewIntroductionErrorEncrypter(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256) *IntroductionErrorEncrypter { + + return &IntroductionErrorEncrypter{ + ErrorEncrypter: NewSphinxErrorEncrypter( + ephemeralKey, sharedSecret, + ), + } +} + +// NewIntroductionErrorEncrypter returns a blank IntroductionErrorEncrypter. +// Since the actual encrypter is not stored in plaintext +// while at rest, reconstructing the error encrypter requires: +// 1. Decode: to deserialize the ephemeral public key. +// 2. Reextract: to "unlock" the actual error encrypter using an active +// OnionProcessor. +func NewIntroductionErrorEncrypterUninitialized() *IntroductionErrorEncrypter { return &IntroductionErrorEncrypter{ - ErrorEncrypter: NewSphinxErrorEncrypter(), + ErrorEncrypter: NewSphinxErrorEncrypterUninitialized(), } } @@ -249,7 +375,7 @@ func (i *IntroductionErrorEncrypter) Type() EncrypterType { // Reextract rederives the error encrypter from the currently held EphemeralKey, // relying on the logic in the underlying SphinxErrorEncrypter. func (i *IntroductionErrorEncrypter) Reextract( - extract ErrorEncrypterExtracter) error { + extract SharedSecretGenerator) error { return i.ErrorEncrypter.Reextract(extract) } @@ -266,9 +392,26 @@ type RelayingErrorEncrypter struct { // NewRelayingErrorEncrypter returns a blank RelayingErrorEncrypter with // an underlying SphinxErrorEncrypter. -func NewRelayingErrorEncrypter() *RelayingErrorEncrypter { +func NewRelayingErrorEncrypter(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256) *RelayingErrorEncrypter { + + return &RelayingErrorEncrypter{ + ErrorEncrypter: NewSphinxErrorEncrypter( + ephemeralKey, sharedSecret, + ), + } +} + +// NewRelayingErrorEncrypterUninitialized returns a blank RelayingErrorEncrypter +// with an underlying SphinxErrorEncrypter. +// Since the actual encrypter is not stored in plaintext +// while at rest, reconstructing the error encrypter requires: +// 1. Decode: to deserialize the ephemeral public key. +// 2. Reextract: to "unlock" the actual error encrypter using an active +// OnionProcessor. +func NewRelayingErrorEncrypterUninitialized() *RelayingErrorEncrypter { return &RelayingErrorEncrypter{ - ErrorEncrypter: NewSphinxErrorEncrypter(), + ErrorEncrypter: NewSphinxErrorEncrypterUninitialized(), } } @@ -280,7 +423,7 @@ func (r *RelayingErrorEncrypter) Type() EncrypterType { // Reextract rederives the error encrypter from the currently held EphemeralKey, // relying on the logic in the underlying SphinxErrorEncrypter. func (r *RelayingErrorEncrypter) Reextract( - extract ErrorEncrypterExtracter) error { + extract SharedSecretGenerator) error { return r.ErrorEncrypter.Reextract(extract) } diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index 553c4921dbc..ae64e5523d7 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -102,10 +102,11 @@ type Iterator interface { // into the passed io.Writer. EncodeNextHop(w io.Writer) error - // ExtractErrorEncrypter returns the ErrorEncrypter needed for this hop, - // along with a failure code to signal if the decoding was successful. - ExtractErrorEncrypter(extractor ErrorEncrypterExtracter, - introductionNode bool) (ErrorEncrypter, lnwire.FailCode) + // ExtractEncrypterParams extracts the ephemeral key and shared secret + // from the onion packet and returns them to the caller along with a + // failure code to signal if the decoding was successful. + ExtractEncrypterParams(SharedSecretGenerator) (*btcec.PublicKey, + sphinx.Hash256, lnwire.BlindingPointRecord, lnwire.FailCode) } // sphinxHopIterator is the Sphinx implementation of hop iterator which uses @@ -482,38 +483,23 @@ func parseAndValidateSenderPayload(payloadBytes []byte, isFinalHop, return payload, routeRole, true, nil } -// ExtractErrorEncrypter decodes and returns the ErrorEncrypter for this hop, -// along with a failure code to signal if the decoding was successful. The -// ErrorEncrypter is used to encrypt errors back to the sender in the event that -// a payment fails. +// ExtractEncrypterParams extracts the ephemeral key, shared secret and blinding +// point record from the onion packet and returns them to the caller along with +// a failure code to signal if the decoding was successful. // // NOTE: Part of the HopIterator interface. -func (r *sphinxHopIterator) ExtractErrorEncrypter( - extracter ErrorEncrypterExtracter, introductionNode bool) ( - ErrorEncrypter, lnwire.FailCode) { +func (r *sphinxHopIterator) ExtractEncrypterParams( + extracter SharedSecretGenerator) (*btcec.PublicKey, sphinx.Hash256, + lnwire.BlindingPointRecord, lnwire.FailCode) { - encrypter, errCode := extracter(r.ogPacket.EphemeralKey) - if errCode != lnwire.CodeNone { - return nil, errCode + sharedSecret, failCode := extracter(r.ogPacket.EphemeralKey) + if failCode != lnwire.CodeNone { + return nil, sphinx.Hash256{}, r.blindingKit.UpdateAddBlinding, + failCode } - // If we're in a blinded path, wrap the error encrypter that we just - // derived in a "marker" type which we'll use to know what type of - // error we're handling. - switch { - case introductionNode: - return &IntroductionErrorEncrypter{ - ErrorEncrypter: encrypter, - }, errCode - - case r.blindingKit.UpdateAddBlinding.IsSome(): - return &RelayingErrorEncrypter{ - ErrorEncrypter: encrypter, - }, errCode - - default: - return encrypter, errCode - } + return r.ogPacket.EphemeralKey, sharedSecret, + r.blindingKit.UpdateAddBlinding, lnwire.CodeNone } // BlindingProcessor is an interface that provides the cryptographic operations @@ -901,33 +887,26 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte, return resps, nil } -// ExtractErrorEncrypter takes an io.Reader which should contain the onion -// packet as original received by a forwarding node and creates an -// ErrorEncrypter instance using the derived shared secret. In the case that en -// error occurs, a lnwire failure code detailing the parsing failure will be -// returned. -func (p *OnionProcessor) ExtractErrorEncrypter(ephemeralKey *btcec.PublicKey) ( - ErrorEncrypter, lnwire.FailCode) { +// ExtractSharedSecret takes an ephemeral session key as original received by a +// forwarding node and generates the shared secret. In the case that en error +// occurs, a lnwire failure code detailing the parsing failure will be returned. +func (p *OnionProcessor) ExtractSharedSecret(ephemeralKey *btcec.PublicKey) ( + sphinx.Hash256, lnwire.FailCode) { - onionObfuscator, err := sphinx.NewOnionErrorEncrypter( - p.router, ephemeralKey, - ) + sharedSecret, err := p.router.GenerateSharedSecret(ephemeralKey, nil) if err != nil { switch err { case sphinx.ErrInvalidOnionVersion: - return nil, lnwire.CodeInvalidOnionVersion + return sphinx.Hash256{}, lnwire.CodeInvalidOnionVersion case sphinx.ErrInvalidOnionHMAC: - return nil, lnwire.CodeInvalidOnionHmac + return sphinx.Hash256{}, lnwire.CodeInvalidOnionHmac case sphinx.ErrInvalidOnionKey: - return nil, lnwire.CodeInvalidOnionKey + return sphinx.Hash256{}, lnwire.CodeInvalidOnionKey default: log.Errorf("unable to process onion packet: %v", err) - return nil, lnwire.CodeInvalidOnionKey + return sphinx.Hash256{}, lnwire.CodeInvalidOnionKey } } - return &SphinxErrorEncrypter{ - OnionErrorEncrypter: onionObfuscator, - EphemeralKey: ephemeralKey, - }, lnwire.CodeNone + return sharedSecret, lnwire.CodeNone } diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 3d0bd90ed45..5dd10af7d9d 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -738,7 +738,12 @@ func (f *interceptedForward) ResumeModified( // Fail notifies the intention to Fail an existing hold forward with an // encrypted failure reason. func (f *interceptedForward) Fail(reason []byte) error { - obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason) + obfuscatedReason, _, err := f.packet.obfuscator.IntermediateEncrypt( + reason, nil, + ) + if err != nil { + return err + } return f.resolve(&lnwire.UpdateFailHTLC{ Reason: obfuscatedReason, @@ -804,13 +809,19 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error { // Encrypt the failure for the first hop. This node will be the origin // of the failure. - reason, err := f.packet.obfuscator.EncryptFirstHop(failureMsg) + reason, attrData, err := f.packet.obfuscator.EncryptFirstHop(failureMsg) if err != nil { return fmt.Errorf("failed to encrypt failure reason %w", err) } + extraData, err := lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } + return f.resolve(&lnwire.UpdateFailHTLC{ - Reason: reason, + Reason: reason, + ExtraData: extraData, }) } diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 2d1dd7de010..2dfc8e7744e 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -12,9 +12,11 @@ import ( "sync/atomic" "time" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn/v2" @@ -111,9 +113,15 @@ type ChannelLinkConfig struct { DecodeHopIterators func([]byte, []hop.DecodeHopIteratorRequest, bool) ( []hop.DecodeHopIteratorResponse, error) - // ExtractErrorEncrypter function is responsible for decoding HTLC - // Sphinx onion blob, and creating onion failure obfuscator. - ExtractErrorEncrypter hop.ErrorEncrypterExtracter + // ExtractSharedSecret function is responsible for decoding HTLC + // Sphinx onion blob, and deriving the shared secret. + ExtractSharedSecret hop.SharedSecretGenerator + + // CreateErrorEncrypter instantiates an error encrypter based on the + // provided encryption parameters. + CreateErrorEncrypter func(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256, isIntroduction, + hasBlindingPoint bool) hop.ErrorEncrypter // FetchLastChannelUpdate retrieves the latest routing policy for a // target channel. This channel will typically be the outgoing channel @@ -3025,19 +3033,11 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) { failedType = uint64(e.Type) } - // If we couldn't parse the payload, make our best - // effort at creating an error encrypter that knows - // what blinding type we were, but if we couldn't - // parse the payload we have no way of knowing whether - // we were the introduction node or not. - // - //nolint:ll - obfuscator, failCode := chanIterator.ExtractErrorEncrypter( - l.cfg.ExtractErrorEncrypter, - // We need our route role here because we - // couldn't parse or validate the payload. - routeRole == hop.RouteRoleIntroduction, - ) + // Let's extract the error encrypter parameters. + ephemeralKey, sharedSecret, blindingPoint, failCode := + chanIterator.ExtractEncrypterParams( + l.cfg.ExtractSharedSecret, + ) if failCode != lnwire.CodeNone { l.log.Errorf("could not extract error "+ "encrypter: %v", pldErr) @@ -3052,6 +3052,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) { continue } + // If we couldn't parse the payload, make our best + // effort at creating an error encrypter that knows + // what blinding type we were, but if we couldn't + // parse the payload we have no way of knowing whether + // we were the introduction node or not. Let's create + // the error encrypter based on the extracted encryption + // parameters. + obfuscator := l.cfg.CreateErrorEncrypter( + ephemeralKey, sharedSecret, + // We need our route role here because we + // couldn't parse or validate the payload. + routeRole == hop.RouteRoleIntroduction, + blindingPoint.IsSome(), + ) + // TODO: currently none of the test unit infrastructure // is setup to handle TLV payloads, so testing this // would require implementing a separate mock iterator @@ -3071,12 +3086,11 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) { continue } - // Retrieve onion obfuscator from onion blob in order to - // produce initial obfuscation of the onion failureCode. - obfuscator, failureCode := chanIterator.ExtractErrorEncrypter( - l.cfg.ExtractErrorEncrypter, - routeRole == hop.RouteRoleIntroduction, - ) + // Extract the encryption parameters. + ephemeralKey, sharedSecret, blindingPoint, failureCode := + chanIterator.ExtractEncrypterParams( + l.cfg.ExtractSharedSecret, + ) if failureCode != lnwire.CodeNone { // If we're unable to process the onion blob than we // should send the malformed htlc error to payment @@ -3092,6 +3106,14 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) { continue } + // Instantiate an error encrypter based on the extracted + // encryption parameters. + obfuscator := l.cfg.CreateErrorEncrypter( + ephemeralKey, sharedSecret, + routeRole == hop.RouteRoleIntroduction, + blindingPoint.IsSome(), + ) + fwdInfo := pld.ForwardingInfo() // Check whether the payload we've just processed uses our @@ -3532,13 +3554,20 @@ func (l *channelLink) sendHTLCError(add lnwire.UpdateAddHTLC, sourceRef channeldb.AddRef, failure *LinkError, e hop.ErrorEncrypter, isReceive bool) { - reason, err := e.EncryptFirstHop(failure.WireMessage()) + reason, attrData, err := e.EncryptFirstHop(failure.WireMessage()) if err != nil { l.log.Errorf("unable to obfuscate error: %v", err) return } - err = l.channel.FailHTLC(add.ID, reason, &sourceRef, nil, nil) + extraData, err := lnwire.AttrDataToExtraData(attrData) + if err != nil { + return + } + + err = l.channel.FailHTLC( + add.ID, reason, extraData, &sourceRef, nil, nil, + ) if err != nil { l.log.Errorf("unable cancel htlc: %v", err) return @@ -3547,7 +3576,7 @@ func (l *channelLink) sendHTLCError(add lnwire.UpdateAddHTLC, // Send the appropriate failure message depending on whether we're // in a blinded route or not. if err := l.sendIncomingHTLCFailureMsg( - add.ID, e, reason, + add.ID, e, reason, extraData, ); err != nil { l.log.Errorf("unable to send HTLC failure: %v", err) return @@ -3591,8 +3620,8 @@ func (l *channelLink) sendHTLCError(add lnwire.UpdateAddHTLC, // used if we are the introduction node and need to present an error as if // we're the failing party. func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64, - e hop.ErrorEncrypter, - originalFailure lnwire.OpaqueReason) error { + e hop.ErrorEncrypter, originalFailure lnwire.OpaqueReason, + extraData lnwire.ExtraOpaqueData) error { var msg lnwire.Message switch { @@ -3605,9 +3634,10 @@ func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64, // code. case e == nil: msg = &lnwire.UpdateFailHTLC{ - ChanID: l.ChanID(), - ID: htlcIndex, - Reason: originalFailure, + ChanID: l.ChanID(), + ID: htlcIndex, + Reason: originalFailure, + ExtraData: extraData, } l.log.Errorf("Unexpected blinded failure when "+ @@ -3618,9 +3648,10 @@ func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64, // transformation on the error message and can just send the original. case !e.Type().IsBlinded(): msg = &lnwire.UpdateFailHTLC{ - ChanID: l.ChanID(), - ID: htlcIndex, - Reason: originalFailure, + ChanID: l.ChanID(), + ID: htlcIndex, + Reason: originalFailure, + ExtraData: extraData, } // When we're the introduction node, we need to convert the error to @@ -3634,7 +3665,7 @@ func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64, failureMsg := lnwire.NewInvalidBlinding( fn.None[[lnwire.OnionPacketSize]byte](), ) - reason, err := e.EncryptFirstHop(failureMsg) + reason, _, err := e.EncryptFirstHop(failureMsg) if err != nil { return err } @@ -4155,7 +4186,7 @@ func (l *channelLink) processRemoteUpdateFailMalformedHTLC( // If remote side have been unable to parse the onion blob we have sent // to it, than we should transform the malformed HTLC message to the // usual HTLC fail message. - err := l.channel.ReceiveFailHTLC(msg.ID, b.Bytes()) + err := l.channel.ReceiveFailHTLC(msg.ID, b.Bytes(), msg.ExtraData) if err != nil { l.failf(LinkFailureError{code: ErrInvalidUpdate}, "unable to handle upstream fail HTLC: %v", err) @@ -4196,7 +4227,7 @@ func (l *channelLink) processRemoteUpdateFailHTLC( // Add fail to the update log. idx := msg.ID - err := l.channel.ReceiveFailHTLC(idx, msg.Reason[:]) + err := l.channel.ReceiveFailHTLC(idx, msg.Reason[:], msg.ExtraData) if err != nil { l.failf(LinkFailureError{code: ErrInvalidUpdate}, "unable to handle upstream fail HTLC: %v", err) @@ -4636,8 +4667,8 @@ func (l *channelLink) processLocalUpdateFailHTLC(ctx context.Context, // remove then HTLC from our local state machine. inKey := pkt.inKey() err := l.channel.FailHTLC( - pkt.incomingHTLCID, htlc.Reason, pkt.sourceRef, pkt.destRef, - &inKey, + pkt.incomingHTLCID, htlc.Reason, htlc.ExtraData, pkt.sourceRef, + pkt.destRef, &inKey, ) if err != nil { l.log.Errorf("unable to cancel incoming HTLC for "+ @@ -4673,7 +4704,9 @@ func (l *channelLink) processLocalUpdateFailHTLC(ctx context.Context, // HTLC. If the incoming blinding point is non-nil, we know that we are // a relaying node in a blinded path. Otherwise, we're either an // introduction node or not part of a blinded path at all. - err = l.sendIncomingHTLCFailureMsg(htlc.ID, pkt.obfuscator, htlc.Reason) + err = l.sendIncomingHTLCFailureMsg( + htlc.ID, pkt.obfuscator, htlc.Reason, htlc.ExtraData, + ) if err != nil { l.log.Errorf("unable to send HTLC failure: %v", err) diff --git a/htlcswitch/link_isolated_test.go b/htlcswitch/link_isolated_test.go index 9e74c487580..e2a43c6633d 100644 --- a/htlcswitch/link_isolated_test.go +++ b/htlcswitch/link_isolated_test.go @@ -254,7 +254,7 @@ func (l *linkTestContext) receiveFailAliceToBob() { l.t.Fatalf("expected UpdateFailHTLC, got %T", msg) } - err := l.bobChannel.ReceiveFailHTLC(failMsg.ID, failMsg.Reason) + err := l.bobChannel.ReceiveFailHTLC(failMsg.ID, failMsg.Reason, nil) if err != nil { l.t.Fatalf("unable to apply received fail htlc: %v", err) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 101a47b98af..86857d787b3 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1802,9 +1802,10 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { t.Cleanup(n.stop) // Replace decode function with another which throws an error. - n.carolChannelLink.cfg.ExtractErrorEncrypter = func( - *btcec.PublicKey) (hop.ErrorEncrypter, lnwire.FailCode) { - return nil, lnwire.CodeInvalidOnionVersion + n.carolChannelLink.cfg.ExtractSharedSecret = func( + *btcec.PublicKey) (sphinx.Hash256, lnwire.FailCode) { + + return sphinx.Hash256{}, lnwire.CodeInvalidOnionVersion } carolBandwidthBefore := n.carolChannelLink.Bandwidth() @@ -2213,9 +2214,15 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt, Circuits: aliceSwitch.CircuitModifier(), ForwardPackets: forwardPackets, DecodeHopIterators: decoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - hop.ErrorEncrypter, lnwire.FailCode) { - return obfuscator, lnwire.CodeNone + ExtractSharedSecret: func(*btcec.PublicKey) ( + sphinx.Hash256, lnwire.FailCode) { + + return sphinx.Hash256{}, lnwire.CodeNone + }, + CreateErrorEncrypter: func(*btcec.PublicKey, + sphinx.Hash256, bool, bool) hop.ErrorEncrypter { + + return obfuscator }, FetchLastChannelUpdate: mockGetChanUpdateMessage, PreimageCache: pCache, @@ -2671,7 +2678,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { reason := make([]byte, 292) copy(reason, []byte("nop")) - err = harness.bobChannel.FailHTLC(bobIndex, reason, nil, nil, nil) + err = harness.bobChannel.FailHTLC(bobIndex, reason, nil, nil, nil, nil) require.NoError(t, err, "unable to fail htlc") failMsg := &lnwire.UpdateFailHTLC{ ID: 1, @@ -2918,7 +2925,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { if !ok { t.Fatalf("expected UpdateFailHTLC, got %T", msg) } - err = harness.bobChannel.ReceiveFailHTLC(failMsg.ID, []byte("fail")) + err = harness.bobChannel.ReceiveFailHTLC( + failMsg.ID, []byte("fail"), nil, + ) require.NoError(t, err, "failed receiving fail htlc") // After failing an HTLC, the link will automatically trigger @@ -4897,10 +4906,15 @@ func (h *persistentLinkHarness) restartLink( Circuits: h.hSwitch.CircuitModifier(), ForwardPackets: forwardPackets, DecodeHopIterators: decoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - hop.ErrorEncrypter, lnwire.FailCode) { + ExtractSharedSecret: func(*btcec.PublicKey) ( + sphinx.Hash256, lnwire.FailCode) { + + return sphinx.Hash256{}, lnwire.CodeNone + }, + CreateErrorEncrypter: func(*btcec.PublicKey, + sphinx.Hash256, bool, bool) hop.ErrorEncrypter { - return obfuscator, lnwire.CodeNone + return obfuscator }, FetchLastChannelUpdate: mockGetChanUpdateMessage, PreimageCache: pCache, @@ -7298,7 +7312,7 @@ func TestChannelLinkShortFailureRelay(t *testing.T) { // Return a short htlc failure from Bob to Alice and lock in. shortReason := make([]byte, 260) - err = harness.bobChannel.FailHTLC(0, shortReason, nil, nil, nil) + err = harness.bobChannel.FailHTLC(0, shortReason, nil, nil, nil, nil) require.NoError(t, err) harness.aliceLink.HandleChannelUpdate(&lnwire.UpdateFailHTLC{ diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index b283825dd96..036f92facc5 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -697,6 +697,7 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) { var ( localFailure = false reason lnwire.OpaqueReason + attrData []byte ) // Create a temporary channel failure which we will send back to our @@ -721,13 +722,18 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) { // If the packet is part of a forward, (identified by a non-nil // obfuscator) we need to encrypt the error back to the source. var err error - reason, err = pkt.obfuscator.EncryptFirstHop(failure) + reason, attrData, err = pkt.obfuscator.EncryptFirstHop(failure) if err != nil { log.Errorf("Unable to obfuscate error: %v", err) return } } + extraData, err := lnwire.AttrDataToExtraData(attrData) + if err != nil { + log.Errorf("Failed to convert attr data: %w", err) + } + // Create a link error containing the temporary channel failure and a // detail which indicates the we failed to add the htlc. linkError := NewDetailedLinkError( @@ -744,7 +750,8 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) { obfuscator: pkt.obfuscator, linkFailure: linkError, htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, + Reason: reason, + ExtraData: extraData, }, } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 70bd73c37d2..c6b6150d1ab 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -341,11 +341,18 @@ func (r *mockHopIterator) ExtraOnionBlob() []byte { return nil } -func (r *mockHopIterator) ExtractErrorEncrypter( - extracter hop.ErrorEncrypterExtracter, _ bool) (hop.ErrorEncrypter, - lnwire.FailCode) { +func (r *mockHopIterator) ExtractEncrypterParams( + extracter hop.SharedSecretGenerator) (*btcec.PublicKey, sphinx.Hash256, + lnwire.BlindingPointRecord, lnwire.FailCode) { + + sharedSecret, failCode := extracter(nil) + if failCode != lnwire.CodeNone { + return nil, sphinx.Hash256{}, lnwire.BlindingPointRecord{}, + failCode + } - return extracter(nil) + return &btcec.PublicKey{}, sharedSecret, lnwire.BlindingPointRecord{}, + lnwire.CodeNone } func (r *mockHopIterator) EncodeNextHop(w io.Writer) error { @@ -412,16 +419,14 @@ func (o *mockObfuscator) Decode(r io.Reader) error { return nil } -func (o *mockObfuscator) Reextract( - extracter hop.ErrorEncrypterExtracter) error { - +func (o *mockObfuscator) Reextract(extracter hop.SharedSecretGenerator) error { return nil } var fakeHmac = []byte("hmachmachmachmachmachmachmachmac") func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( - lnwire.OpaqueReason, error) { + lnwire.OpaqueReason, []byte, error) { o.failure = failure @@ -429,22 +434,27 @@ func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( b.Write(fakeHmac) if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { - return nil, err + return nil, nil, err } - return b.Bytes(), nil + + return b.Bytes(), nil, nil } -func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason { - return reason +func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason, + attrData []byte) (lnwire.OpaqueReason, []byte, error) { + + return reason, nil, nil } -func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason { +func (o *mockObfuscator) EncryptMalformedError( + reason lnwire.OpaqueReason) (lnwire.OpaqueReason, []byte, error) { + var b bytes.Buffer b.Write(fakeHmac) b.Write(reason) - return b.Bytes() + return b.Bytes(), nil, nil } // mockDeobfuscator mock implementation of the failure deobfuscator which @@ -455,8 +465,8 @@ func newMockDeobfuscator() ErrorDecrypter { return &mockDeobfuscator{} } -func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) ( - *ForwardingError, error) { +func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason, + attrData []byte) (*ForwardingError, error) { if !bytes.Equal(reason[:32], fakeHmac) { return nil, errors.New("fake decryption error") @@ -469,7 +479,7 @@ func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) ( return nil, err } - return NewForwardingError(failure, 1), nil + return NewForwardingError(failure, 1, nil), nil } var _ ErrorDecrypter = (*mockDeobfuscator)(nil) @@ -1133,21 +1143,6 @@ func (m *mockCircuitMap) NumOpen() int { return 0 } -type mockOnionErrorDecryptor struct { - sourceIdx int - message []byte - err error -} - -func (m *mockOnionErrorDecryptor) DecryptError(encryptedData []byte) ( - *sphinx.DecryptedError, error) { - - return &sphinx.DecryptedError{ - SenderIdx: m.sourceIdx, - Message: m.message, - }, m.err -} - var _ htlcNotifier = (*mockHTLCNotifier)(nil) type mockHTLCNotifier struct { diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index a3aae809b93..5de54ad986b 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -166,10 +166,10 @@ type Config struct { // forwarding packages, and ack settles and fails contained within them. SwitchPackager channeldb.FwdOperator - // ExtractErrorEncrypter is an interface allowing switch to reextract + // ExtractSharedSecret is an interface allowing switch to reextract // error encrypters stored in the circuit map on restarts, since they // are not stored directly within the database. - ExtractErrorEncrypter hop.ErrorEncrypterExtracter + ExtractSharedSecret hop.SharedSecretGenerator // FetchLastChannelUpdate retrieves the latest routing policy for a // target channel. This channel will typically be the outgoing channel @@ -361,11 +361,11 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { resStore := newResolutionStore(cfg.DB) circuitMap, err := NewCircuitMap(&CircuitMapConfig{ - DB: cfg.DB, - FetchAllOpenChannels: cfg.FetchAllOpenChannels, - FetchClosedChannels: cfg.FetchClosedChannels, - ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, - CheckResolutionMsg: resStore.checkResolutionMsg, + DB: cfg.DB, + FetchAllOpenChannels: cfg.FetchAllOpenChannels, + FetchClosedChannels: cfg.FetchClosedChannels, + ExtractSharedSecret: cfg.ExtractSharedSecret, + CheckResolutionMsg: resStore.checkResolutionMsg, }) if err != nil { return nil, err @@ -1105,9 +1105,14 @@ func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter, // A regular multi-hop payment error that we'll need to // decrypt. default: + attrData, err := lnwire.ExtraDataToAttrData(htlc.ExtraData) + if err != nil { + return err + } + // We'll attempt to fully decrypt the onion encrypted // error. If we're unable to then we'll bail early. - failure, err := deobfuscator.DecryptError(htlc.Reason) + failure, err := deobfuscator.DecryptError(htlc.Reason, attrData) if err != nil { log.Errorf("unable to de-obfuscate onion failure "+ "(hash=%v, pid=%d): %v", @@ -1232,7 +1237,9 @@ func (s *Switch) failAddPacket(packet *htlcPacket, failure *LinkError) error { // Encrypt the failure so that the sender will be able to read the error // message. Since we failed this packet, we use EncryptFirstHop to // obfuscate the failure for their eyes only. - reason, err := packet.obfuscator.EncryptFirstHop(failure.WireMessage()) + reason, attrData, err := packet.obfuscator.EncryptFirstHop( + failure.WireMessage(), + ) if err != nil { err := fmt.Errorf("unable to obfuscate "+ "error: %v", err) @@ -1242,6 +1249,11 @@ func (s *Switch) failAddPacket(packet *htlcPacket, failure *LinkError) error { log.Error(failure.Error()) + extraData, err := lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } + // Create a failure packet for this htlc. The full set of // information about the htlc failure is included so that they can // be included in link failure notifications. @@ -1259,7 +1271,8 @@ func (s *Switch) failAddPacket(packet *htlcPacket, failure *LinkError) error { obfuscator: packet.obfuscator, linkFailure: failure, htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, + Reason: reason, + ExtraData: extraData, }, } @@ -3163,7 +3176,7 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, var err error // TODO(roasbeef): don't need to pass actually? failure := &lnwire.FailPermanentChannelFailure{} - htlc.Reason, err = circuit.ErrorEncrypter.EncryptFirstHop( + reason, attrData, err := circuit.ErrorEncrypter.EncryptFirstHop( failure, ) if err != nil { @@ -3171,6 +3184,12 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, log.Error(err) } + htlc.Reason = reason + htlc.ExtraData, err = lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } + // Alternatively, if the remote party sends us an // UpdateFailMalformedHTLC, then we'll need to convert this into a // proper well formatted onion error as there's no HMAC currently. @@ -3181,16 +3200,41 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, packet.incomingChanID, packet.incomingHTLCID, packet.outgoingChanID, packet.outgoingHTLCID) - htlc.Reason = circuit.ErrorEncrypter.EncryptMalformedError( - htlc.Reason, - ) + reason, attrData, err := + circuit.ErrorEncrypter.EncryptMalformedError( + htlc.Reason, + ) + if err != nil { + return err + } + + htlc.Reason = reason + htlc.ExtraData, err = lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } default: + attrData, err := lnwire.ExtraDataToAttrData(htlc.ExtraData) + if err != nil { + return err + } + // Otherwise, it's a forwarded error, so we'll perform a // wrapper encryption as normal. - htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( - htlc.Reason, - ) + reason, attrData, err := + circuit.ErrorEncrypter.IntermediateEncrypt( + htlc.Reason, attrData, + ) + if err != nil { + return err + } + + htlc.Reason = reason + htlc.ExtraData, err = lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } } // Deliver this packet. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index e8176aaeb59..89331f7a386 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" @@ -2743,7 +2744,7 @@ func TestSwitchSendPayment(t *testing.T) { // back. This request should be forwarded back to alice channel link. obfuscator := NewMockObfuscator() failure := lnwire.NewFailIncorrectDetails(update.Amount, 100) - reason, err := obfuscator.EncryptFirstHop(failure) + reason, _, err := obfuscator.EncryptFirstHop(failure) require.NoError(t, err, "unable obfuscate failure") if s.IsForwardedHTLC(aliceChannelLink.ShortChanID(), update.ID) { @@ -3234,9 +3235,9 @@ func TestInvalidFailure(t *testing.T) { // Get payment result from switch. We expect an unreadable failure // message error. deobfuscator := SphinxErrorDecrypter{ - OnionErrorDecrypter: &mockOnionErrorDecryptor{ - err: ErrUnreadableFailureMessage, - }, + decrypter: sphinx.NewOnionErrorDecrypter( + nil, hop.AttrErrorStruct, + ), } resultChan, err := s.GetAttemptResult( @@ -3255,43 +3256,6 @@ func TestInvalidFailure(t *testing.T) { case <-time.After(time.Second): t.Fatal("err wasn't received") } - - // Modify the decryption to simulate that decryption went alright, but - // the failure cannot be decoded. - deobfuscator = SphinxErrorDecrypter{ - OnionErrorDecrypter: &mockOnionErrorDecryptor{ - sourceIdx: 2, - message: []byte{200}, - }, - } - - resultChan, err = s.GetAttemptResult( - paymentID, rhash, &deobfuscator, - ) - if err != nil { - t.Fatal(err) - } - - select { - case result := <-resultChan: - rtErr, ok := result.Error.(ClearTextError) - if !ok { - t.Fatal("expected ClearTextError") - } - source, ok := rtErr.(*ForwardingError) - if !ok { - t.Fatalf("expected forwarding error, got: %T", rtErr) - } - if source.FailureSourceIdx != 2 { - t.Fatal("unexpected error source index") - } - if rtErr.WireMessage() != nil { - t.Fatal("expected empty failure message") - } - - case <-time.After(time.Second): - t.Fatal("err wasn't received") - } } // htlcNotifierEvents is a function that generates a set of expected htlc @@ -4069,7 +4033,9 @@ func TestSwitchHoldForward(t *testing.T) { OnionSHA256: shaOnionBlob, } - fwdErr, err := newMockDeobfuscator().DecryptError(failPacket.Reason) + fwdErr, err := newMockDeobfuscator().DecryptError( + failPacket.Reason, nil, + ) require.NoError(t, err) require.Equal(t, expectedFailure, fwdErr.WireMessage()) @@ -4299,6 +4265,8 @@ func TestSwitchDustForwarding(t *testing.T) { OnionBlob: blob, } + return + // This is the expected dust without taking the commitfee into account. expectedDust := maxInflightHtlcs * 2 * amt @@ -5535,7 +5503,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { require.True(t, ok) fwdErr, err := newMockDeobfuscator().DecryptError( - failHtlc.Reason, + failHtlc.Reason, nil, ) require.NoError(t, err) diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index bdb365d3c93..0683dceb3f2 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -1146,9 +1146,15 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer, Circuits: server.htlcSwitch.CircuitModifier(), ForwardPackets: forwardPackets, DecodeHopIterators: decoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - hop.ErrorEncrypter, lnwire.FailCode) { - return h.obfuscator, lnwire.CodeNone + ExtractSharedSecret: func(*btcec.PublicKey) ( + sphinx.Hash256, lnwire.FailCode) { + + return sphinx.Hash256{}, lnwire.CodeNone + }, + CreateErrorEncrypter: func(*btcec.PublicKey, + sphinx.Hash256, bool, bool) hop.ErrorEncrypter { + + return h.obfuscator }, FetchLastChannelUpdate: mockGetChanUpdateMessage, Registry: server.registry, diff --git a/lnwallet/channel.go b/lnwallet/channel.go index c96a35b450c..cec412ded27 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1161,13 +1161,14 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) pd = &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + FailExtraData: wireMsg.ExtraData, removeCommitHeights: lntypes.Dual[uint64]{ Remote: commitHeight, }, @@ -1261,13 +1262,14 @@ func (lc *LightningChannel) localLogUpdateToPayDesc(logUpdate *channeldb.LogUpda ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + FailExtraData: wireMsg.ExtraData, removeCommitHeights: lntypes.Dual[uint64]{ Remote: commitHeight, }, @@ -1380,13 +1382,14 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd ogHTLC := localUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + FailExtraData: wireMsg.ExtraData, removeCommitHeights: lntypes.Dual[uint64]{ Local: commitHeight, }, @@ -6432,7 +6435,8 @@ func (lc *LightningChannel) ReceiveHTLCSettle(preimage [32]byte, htlcIndex uint6 // NOTE: It is okay for sourceRef, destRef, and closeKey to be nil when unit // testing the wallet. func (lc *LightningChannel) FailHTLC(htlcIndex uint64, reason []byte, - sourceRef *channeldb.AddRef, destRef *channeldb.SettleFailRef, + extraData lnwire.ExtraOpaqueData, sourceRef *channeldb.AddRef, + destRef *channeldb.SettleFailRef, closeKey *models.CircuitKey) error { lc.Lock() @@ -6460,6 +6464,7 @@ func (lc *LightningChannel) FailHTLC(htlcIndex uint64, reason []byte, SourceRef: sourceRef, DestRef: destRef, ClosedCircuitKey: closeKey, + FailExtraData: extraData, } lc.updateLogs.Local.appendUpdate(pd) @@ -6527,7 +6532,7 @@ func (lc *LightningChannel) MalformedFailHTLC(htlcIndex uint64, // commitment update. This method should be called in response to the upstream // party cancelling an outgoing HTLC. func (lc *LightningChannel) ReceiveFailHTLC(htlcIndex uint64, reason []byte, -) error { + extraData lnwire.ExtraOpaqueData) error { lc.Lock() defer lc.Unlock() @@ -6544,13 +6549,14 @@ func (lc *LightningChannel) ReceiveFailHTLC(htlcIndex uint64, reason []byte, } pd := &paymentDescriptor{ - ChanID: lc.ChannelID(), - Amount: htlc.Amount, - RHash: htlc.RHash, - ParentIndex: htlc.HtlcIndex, - LogIndex: lc.updateLogs.Remote.logIndex, - EntryType: Fail, - FailReason: reason, + ChanID: lc.ChannelID(), + Amount: htlc.Amount, + RHash: htlc.RHash, + ParentIndex: htlc.HtlcIndex, + LogIndex: lc.updateLogs.Remote.logIndex, + EntryType: Fail, + FailReason: reason, + FailExtraData: extraData, } lc.updateLogs.Remote.appendUpdate(pd) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 6e175ba7392..c07100488a6 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -466,9 +466,9 @@ func TestChannelZeroAddLocalHeight(t *testing.T) { // Now Bob should fail the htlc back to Alice. // <----fail----- - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err) - err = aliceChannel.ReceiveFailHTLC(0, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) require.NoError(t, err) // Bob should send a commitment signature to Alice. @@ -2222,9 +2222,11 @@ func TestCancelHTLC(t *testing.T) { // Now, with the HTLC committed on both sides, trigger a cancellation // from Bob to Alice, removing the HTLC. - err = bobChannel.FailHTLC(bobHtlcIndex, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + bobHtlcIndex, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(aliceHtlcIndex, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(aliceHtlcIndex, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // Now trigger another state transition, the HTLC should now be removed @@ -5509,9 +5511,9 @@ func TestChanAvailableBandwidth(t *testing.T) { } htlcIndex := uint64((numHtlcs * 2) - 1) - err = bobChannel.FailHTLC(htlcIndex, []byte("f"), nil, nil, nil) + err = bobChannel.FailHTLC(htlcIndex, []byte("f"), nil, nil, nil, nil) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(htlcIndex, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(htlcIndex, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // We must do a state transition before the balance is available @@ -5965,9 +5967,11 @@ func TestLockedInHtlcForwardingSkipAfterRestart(t *testing.T) { // With both nodes restarted, Bob will now attempt to cancel one of // Alice's HTLC's. - err = bobChannel.FailHTLC(htlc.ID, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + htlc.ID, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(htlc.ID, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(htlc.ID, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // We'll now initiate another state transition, but this time Bob will @@ -6018,9 +6022,11 @@ func TestLockedInHtlcForwardingSkipAfterRestart(t *testing.T) { // Failing the HTLC here will cause the update to be included in Alice's // remote log, but it should not be committed by this transition. - err = bobChannel.FailHTLC(htlc2.ID, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + htlc2.ID, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(htlc2.ID, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(htlc2.ID, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") bobRevocation, _, finalHtlcs, err := bobChannel. @@ -6073,9 +6079,11 @@ func TestLockedInHtlcForwardingSkipAfterRestart(t *testing.T) { // Re-add the Fail to both Alice and Bob's channels, as the non-committed // update will not have survived the restart. - err = bobChannel.FailHTLC(htlc2.ID, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + htlc2.ID, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(htlc2.ID, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(htlc2.ID, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // Have Alice initiate a state transition, which does not include the @@ -6520,9 +6528,14 @@ func TestDesyncHTLCs(t *testing.T) { } // Now let Bob fail this HTLC. - err = bobChannel.FailHTLC(bobIndex, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + bobIndex, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - if err := aliceChannel.ReceiveFailHTLC(aliceIndex, []byte("bad")); err != nil { + err = aliceChannel.ReceiveFailHTLC( + aliceIndex, []byte("bad"), nil, + ) + if err != nil { t.Fatalf("unable to recv htlc cancel: %v", err) } @@ -6612,10 +6625,11 @@ func TestMaxAcceptedHTLCs(t *testing.T) { // Bob will fail the htlc specified by htlcID and then force a state // transition. - err = bobChannel.FailHTLC(htlcID, []byte{}, nil, nil, nil) + err = bobChannel.FailHTLC(htlcID, []byte{}, nil, nil, nil, nil) require.NoError(t, err, "unable to fail htlc") - if err := aliceChannel.ReceiveFailHTLC(htlcID, []byte{}); err != nil { + err = aliceChannel.ReceiveFailHTLC(htlcID, []byte{}, nil) + if err != nil { t.Fatalf("unable to receive fail htlc: %v", err) } @@ -6718,10 +6732,11 @@ func TestMaxAsynchronousHtlcs(t *testing.T) { addAndReceiveHTLC(t, aliceChannel, bobChannel, htlc, nil) // Fail back an HTLC and sign a commitment as in steps 1 & 2. - err = bobChannel.FailHTLC(htlcID, []byte{}, nil, nil, nil) + err = bobChannel.FailHTLC(htlcID, []byte{}, nil, nil, nil, nil) require.NoError(t, err, "unable to fail htlc") - if err := aliceChannel.ReceiveFailHTLC(htlcID, []byte{}); err != nil { + err = aliceChannel.ReceiveFailHTLC(htlcID, []byte{}, nil) + if err != nil { t.Fatalf("unable to receive fail htlc: %v", err) } @@ -7546,10 +7561,10 @@ func TestChannelRestoreUpdateLogsFailedHTLC(t *testing.T) { restoreAndAssert(t, aliceChannel, 1, 0, 0, 0) // Now we make Bob fail this HTLC. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(0, []byte("failreason")) + err = aliceChannel.ReceiveFailHTLC(0, []byte("failreason"), nil) require.NoError(t, err, "unable to recv htlc cancel") // This Fail update should have been added to Alice's remote update log. @@ -7632,19 +7647,22 @@ func TestDuplicateFailRejection(t *testing.T) { // With the HTLC locked in, we'll now have Bob fail the HTLC back to // Alice. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err, "unable to cancel HTLC") - if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err != nil { + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) + if err != nil { t.Fatalf("unable to recv htlc cancel: %v", err) } // If we attempt to fail it AGAIN, then both sides should reject this // second failure attempt. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) if err == nil { t.Fatalf("duplicate HTLC failure attempt should have failed") } - if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err == nil { + + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) + if err == nil { t.Fatalf("duplicate HTLC failure attempt should have failed") } @@ -7661,14 +7679,15 @@ func TestDuplicateFailRejection(t *testing.T) { require.NoError(t, err, "unable to restart channel") // If we try to fail the same HTLC again, then we should get an error. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) if err == nil { t.Fatalf("duplicate HTLC failure attempt should have failed") } // Alice on the other hand should accept the failure again, as she // dropped all items in the logs which weren't committed. - if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err != nil { + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) + if err != nil { t.Fatalf("unable to recv htlc cancel: %v", err) } } @@ -7929,9 +7948,9 @@ func TestChannelRestoreCommitHeight(t *testing.T) { bobChannel = restoreAndAssertCommitHeights(t, bobChannel, true, 1, 2, 2) // Bob now fails back the htlc that was just locked in. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(0, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // Now Bob signs for the fail update. @@ -9252,9 +9271,9 @@ func TestChannelUnsignedAckedFailure(t *testing.T) { // Now Bob should fail the htlc back to Alice. // <----fail----- - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err) - err = aliceChannel.ReceiveFailHTLC(0, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) require.NoError(t, err) // Bob should send a commitment signature to Alice. @@ -9356,9 +9375,11 @@ func TestChannelLocalUnsignedUpdatesFailure(t *testing.T) { // Now Alice should fail the htlc back to Bob. // -----fail---> - err = aliceChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = aliceChannel.FailHTLC( + 0, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err) - err = bobChannel.ReceiveFailHTLC(0, []byte("bad")) + err = bobChannel.ReceiveFailHTLC(0, []byte("bad"), nil) require.NoError(t, err) // Alice should send a commitment signature to Bob. @@ -10801,10 +10822,10 @@ func TestAsynchronousSendingWithFeeBuffer(t *testing.T) { // <----rev------- |--------------- // <----sig------- |--------------- // --------------- |-----rev------> - err = aliceChannel.FailHTLC(0, []byte{}, nil, nil, nil) + err = aliceChannel.FailHTLC(0, []byte{}, nil, nil, nil, nil) require.NoError(t, err) - err = bobChannel.ReceiveFailHTLC(0, []byte{}) + err = bobChannel.ReceiveFailHTLC(0, []byte{}, nil) require.NoError(t, err) err = ForceStateTransition(aliceChannel, bobChannel) diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index 944749bde9f..ac8f6a27f03 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -246,6 +246,10 @@ type paymentDescriptor struct { // CustomRecords also stores the set of optional custom records that // may have been attached to a sent HTLC. CustomRecords lnwire.CustomRecords + + // FailExtraData stores any extra opaque data that may have been present + // when receiving an UpdateFailHTLC message. + FailExtraData lnwire.ExtraOpaqueData } // toLogUpdate recovers the underlying LogUpdate from the paymentDescriptor. @@ -274,9 +278,10 @@ func (pd *paymentDescriptor) toLogUpdate() channeldb.LogUpdate { } case Fail: msg = &lnwire.UpdateFailHTLC{ - ChanID: pd.ChanID, - ID: pd.ParentIndex, - Reason: pd.FailReason, + ChanID: pd.ChanID, + ID: pd.ParentIndex, + Reason: pd.FailReason, + ExtraData: pd.FailExtraData, } case MalformedFail: msg = &lnwire.UpdateFailMalformedHTLC{ diff --git a/lnwire/attr_data.go b/lnwire/attr_data.go new file mode 100644 index 00000000000..43637668549 --- /dev/null +++ b/lnwire/attr_data.go @@ -0,0 +1,40 @@ +package lnwire + +import "github.com/lightningnetwork/lnd/tlv" + +// AttrDataTlvType is the TlvType that hosts the attribution data in the +// update_fail_htlc wire message. +var AttrDataTlvType tlv.TlvType101 + +// AttrDataTlvTypeVal is the value of the type of the TLV record for the +// attribution data. +var AttrDataTlvTypeVal = AttrDataTlvType.TypeVal() + +// AttrDataToExtraData converts the provided attribution data to the extra +// opaque data to be included in the wire message. +func AttrDataToExtraData(attrData []byte) (ExtraOpaqueData, error) { + attrRecs := make(tlv.TypeMap) + + attrType := AttrDataTlvType.TypeVal() + + attrRecs[attrType] = attrData + + return NewExtraOpaqueData(attrRecs) +} + +// ExtraDataToAttrData takes the extra opaque data of the wire message and tries +// to extract the attribution data. +func ExtraDataToAttrData(extraData ExtraOpaqueData) ([]byte, error) { + extraRecords, err := extraData.ExtractRecords() + if err != nil { + return nil, err + } + + attrType := AttrDataTlvTypeVal + var attrData []byte + if value, ok := extraRecords[attrType]; ok { + attrData = value + } + + return attrData, nil +} diff --git a/peer/brontide.go b/peer/brontide.go index 57e340fb0e8..cc497fcb83e 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -19,6 +19,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/buffer" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" @@ -1411,25 +1412,45 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, //nolint:ll linkCfg := htlcswitch.ChannelLinkConfig{ - Peer: p, - DecodeHopIterators: p.cfg.Sphinx.DecodeHopIterators, - ExtractErrorEncrypter: p.cfg.Sphinx.ExtractErrorEncrypter, - FetchLastChannelUpdate: p.cfg.FetchLastChanUpdate, - HodlMask: p.cfg.Hodl.Mask(), - Registry: p.cfg.Invoices, - BestHeight: p.cfg.Switch.BestHeight, - Circuits: p.cfg.Switch.CircuitModifier(), - ForwardPackets: p.cfg.InterceptSwitch.ForwardPackets, - FwrdingPolicy: *forwardingPolicy, - FeeEstimator: p.cfg.FeeEstimator, - PreimageCache: p.cfg.WitnessBeacon, - ChainEvents: chainEvents, - UpdateContractSignals: updateContractSignals, - NotifyContractUpdate: notifyContractUpdate, - OnChannelFailure: onChannelFailure, - SyncStates: syncStates, - BatchTicker: ticker.New(p.cfg.ChannelCommitInterval), - FwdPkgGCTicker: ticker.New(time.Hour), + Peer: p, + DecodeHopIterators: p.cfg.Sphinx.DecodeHopIterators, + ExtractSharedSecret: p.cfg.Sphinx.ExtractSharedSecret, + CreateErrorEncrypter: func(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256, isIntroduction, + hasBlindingPoint bool) hop.ErrorEncrypter { + + switch { + case isIntroduction: + return hop.NewIntroductionErrorEncrypter( + ephemeralKey, sharedSecret, + ) + + case hasBlindingPoint: + return hop.NewRelayingErrorEncrypter( + ephemeralKey, sharedSecret, + ) + + default: + return hop.NewSphinxErrorEncrypter( + ephemeralKey, sharedSecret, + ) + } + }, FetchLastChannelUpdate: p.cfg.FetchLastChanUpdate, + HodlMask: p.cfg.Hodl.Mask(), + Registry: p.cfg.Invoices, + BestHeight: p.cfg.Switch.BestHeight, + Circuits: p.cfg.Switch.CircuitModifier(), + ForwardPackets: p.cfg.InterceptSwitch.ForwardPackets, + FwrdingPolicy: *forwardingPolicy, + FeeEstimator: p.cfg.FeeEstimator, + PreimageCache: p.cfg.WitnessBeacon, + ChainEvents: chainEvents, + UpdateContractSignals: updateContractSignals, + NotifyContractUpdate: notifyContractUpdate, + OnChannelFailure: onChannelFailure, + SyncStates: syncStates, + BatchTicker: ticker.New(p.cfg.ChannelCommitInterval), + FwdPkgGCTicker: ticker.New(time.Hour), PendingCommitTicker: ticker.New( p.cfg.PendingCommitInterval, ), diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 8353cba157f..fbf4bf9d559 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -7,7 +7,6 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" @@ -533,9 +532,7 @@ func (p *paymentLifecycle) collectResult( // Using the created circuit, initialize the error decrypter, so we can // parse+decode any failures incurred by this payment within the // switch. - errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ - OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), - } + errorDecryptor := htlcswitch.NewSphinxErrorDecrypter(circuit) // Now ask the switch to return the result of the payment when // available. diff --git a/routing/router_test.go b/routing/router_test.go index b811793d258..16942d0fb56 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -316,7 +316,7 @@ func TestSendPaymentRouteFailureFallback(t *testing.T) { // TODO(roasbeef): temp node failure // should be? &lnwire.FailTemporaryChannelFailure{}, - 1, + 1, nil, ) } @@ -385,7 +385,7 @@ func TestSendPaymentRouteInfiniteLoopWithBadHopHint(t *testing.T) { // the bad channel is the first hop. badShortChanID := lnwire.NewShortChanIDFromInt(badChannelID) newFwdError := htlcswitch.NewForwardingError( - &lnwire.FailUnknownNextPeer{}, 0, + &lnwire.FailUnknownNextPeer{}, 0, nil, ) payer, ok := ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld) @@ -504,7 +504,7 @@ func TestChannelUpdateValidation(t *testing.T) { &lnwire.FailFeeInsufficient{ Update: errChanUpdate, }, - 1, + 1, nil, ) }) @@ -626,7 +626,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // node/channel. &lnwire.FailFeeInsufficient{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) } @@ -735,7 +735,7 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // node/channel. &lnwire.FailFeeInsufficient{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) }, ) @@ -861,7 +861,7 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // node/channel. &lnwire.FailFeeInsufficient{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) }, ) @@ -958,7 +958,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailExpiryTooSoon{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) } @@ -1006,7 +1006,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailIncorrectCltvExpiry{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) } @@ -1062,7 +1062,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // sophon not having enough capacity. return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailTemporaryChannelFailure{}, - 1, + 1, nil, ) } @@ -1071,7 +1071,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // which should prune out the rest of the routes. if firstHop == roasbeefPhanNuwen { return [32]byte{}, htlcswitch.NewForwardingError( - &lnwire.FailUnknownNextPeer{}, 1, + &lnwire.FailUnknownNextPeer{}, 1, nil, ) } @@ -1118,7 +1118,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { if firstHop == roasbeefSongoku { failure := htlcswitch.NewForwardingError( - &lnwire.FailUnknownNextPeer{}, 1, + &lnwire.FailUnknownNextPeer{}, 1, nil, ) return [32]byte{}, failure } @@ -1161,7 +1161,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // roasbeef not having enough capacity. return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailTemporaryChannelFailure{}, - 1, + 1, nil, ) } return preImage, nil @@ -1408,7 +1408,7 @@ func TestSendToRouteStructuredError(t *testing.T) { ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult( func(firstHop lnwire.ShortChannelID) ([32]byte, error) { return [32]byte{}, htlcswitch.NewForwardingError( - errorType, failIndex, + errorType, failIndex, nil, ) }, ) @@ -2330,7 +2330,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { // Create the error to be returned. tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, 1, + &lnwire.FailTemporaryChannelFailure{}, 1, nil, ) // Register mockers with the expected method calls. @@ -2408,7 +2408,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { // Create the error to be returned. permErr := htlcswitch.NewForwardingError( - &lnwire.FailIncorrectDetails{}, 1, + &lnwire.FailIncorrectDetails{}, 1, nil, ) // Register mockers with the expected method calls. @@ -2492,7 +2492,7 @@ func TestSendToRouteTempFailure(t *testing.T) { // Create the error to be returned. tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, 1, + &lnwire.FailTemporaryChannelFailure{}, 1, nil, ) // Register mockers with the expected method calls. diff --git a/server.go b/server.go index 83dd9a4d4f2..a4e5ccc2a76 100644 --- a/server.go +++ b/server.go @@ -799,7 +799,7 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, }, FwdingLog: dbs.ChanStateDB.ForwardingLog(), SwitchPackager: channeldb.NewSwitchPackager(), - ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter, + ExtractSharedSecret: s.sphinx.ExtractSharedSecret, FetchLastChannelUpdate: s.fetchLastChanUpdate(), Notifier: s.cc.ChainNotifier, HtlcNotifier: s.htlcNotifier,