Skip to content

Commit 75327ca

Browse files
committed
multi: decode zero-length onion message payloads
Since the onion message payload can be zero-length, we need to decode it correctly. This commit adds a boolean flag to the HopPayload Decode that tells whether the payload is an onion message payload or not. If it is, the payload is decoded as a tlv payload also if the first byte is 0x00. sphinx_test: Add zero-length payload om test
1 parent a5115c8 commit 75327ca

File tree

4 files changed

+141
-45
lines changed

4 files changed

+141
-45
lines changed

error.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package sphinx
22

3-
import "fmt"
3+
import (
4+
"errors"
5+
"fmt"
6+
)
47

58
var (
69
// ErrReplayedPacket is an error returned when a packet is rejected
@@ -24,4 +27,7 @@ var (
2427
// ErrLogEntryNotFound is an error returned when a packet lookup in a replay
2528
// log fails because it is missing.
2629
ErrLogEntryNotFound = fmt.Errorf("sphinx packet is not in log")
30+
31+
// ErrIOReadFull is returned when an io read full operation fails.
32+
ErrIOReadFull = errors.New("io read full error")
2733
)

payload.go

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -86,49 +86,61 @@ func (hp *HopPayload) Encode(w io.Writer) error {
8686
return encodeTLVHopPayload(hp, w)
8787
}
8888

89-
// Decode unpacks an encoded HopPayload from the passed reader into the target
90-
// HopPayload.
91-
func (hp *HopPayload) Decode(r io.Reader) error {
92-
bufReader := bufio.NewReader(r)
93-
94-
// In order to properly parse the payload, we'll need to check the
95-
// first byte. We'll use a bufio reader to peek at it without consuming
96-
// it from the buffer.
89+
// DecodeHopPayload unpacks an encoded HopPayload from the passed reader into
90+
// the target HopPayload. tlvGuaranteed should be set to true if the caller only
91+
// wishes to accept TLV encoded payloads. By doing so, zero-lengt tlv payloads
92+
// are supported. If set to false, then the function will inspect the first byte
93+
// to determine the type of payload.
94+
func DecodeHopPayload(r io.Reader, tlvGuaranteed bool) (*HopPayload, error) {
95+
var (
96+
payloadSize uint16
97+
payloadType = PayloadTLV
98+
hmac [HMACSize]byte
99+
bufReader = bufio.NewReader(r)
100+
)
101+
102+
// If we are not sure if this is a TLV or legacy payload, then we need
103+
// to inspect the first byte to determine the type of payload. The first
104+
// byte is either a realm (legacy) or the beginning of a var-int
105+
// encoding the length of the payload (TLV). We'll use a bufio reader to
106+
// peek at it without consuming it from the buffer.
97107
peekByte, err := bufReader.Peek(1)
98108
if err != nil {
99-
return err
109+
return nil, fmt.Errorf("peek first payload byte: %w", err)
100110
}
101111

102-
var (
103-
legacyPayload = isLegacyPayloadByte(peekByte[0])
104-
payloadSize uint16
105-
)
106-
107-
if legacyPayload {
112+
if !tlvGuaranteed && isLegacyPayloadByte(peekByte[0]) {
113+
// If we're not guaranteed TLV, and the first byte indicates a
114+
// legacy payload, then we treat this as a legacy payload.
115+
payloadType = PayloadLegacy
108116
payloadSize = legacyPayloadSize()
109-
hp.Type = PayloadLegacy
110117
} else {
118+
// Otherwise, we treat this as a TLV payload.
111119
payloadSize, err = tlvPayloadSize(bufReader)
112120
if err != nil {
113-
return err
121+
return nil, err
114122
}
115-
116-
hp.Type = PayloadTLV
117123
}
118124

119-
// Now that we know the payload size, we'll create a new buffer to
120-
// read it out in full.
121-
//
122-
// TODO(roasbeef): can avoid all these copies
123-
hp.Payload = make([]byte, payloadSize)
124-
if _, err := io.ReadFull(bufReader, hp.Payload[:]); err != nil {
125-
return err
125+
// Now that we know the payload size, we'll create a new buffer to read
126+
// it out in full.
127+
payload := make([]byte, payloadSize)
128+
129+
_, err = io.ReadFull(bufReader, payload)
130+
if err != nil {
131+
return nil, fmt.Errorf("%w: %w", ErrIOReadFull, err)
126132
}
127-
if _, err := io.ReadFull(bufReader, hp.HMAC[:]); err != nil {
128-
return err
133+
134+
_, err = io.ReadFull(bufReader, hmac[:])
135+
if err != nil {
136+
return nil, fmt.Errorf("%w: %w", ErrIOReadFull, err)
129137
}
130138

131-
return nil
139+
return &HopPayload{
140+
Type: payloadType,
141+
Payload: payload,
142+
HMAC: hmac,
143+
}, nil
132144
}
133145

134146
// HopData attempts to extract a set of forwarding instructions from the target
@@ -314,8 +326,12 @@ func legacyNumBytes() int {
314326
return LegacyHopDataSize
315327
}
316328

317-
// isLegacyPayload returns true if the given byte is equal to the 0x00 byte
318-
// which indicates that the payload should be decoded as a legacy payload.
329+
// isLegacyPayloadByte determines if the first byte of a hop payload indicates
330+
// that it is a legacy payload. The first byte of a legacy payload will always
331+
// be 0x00, as this is the realm. For TLV payloads, the first byte is a
332+
// var-int encoding the length of the payload. A TLV stream can be empty, in
333+
// which case its length is 0, which is also encoded as a 0x00 byte. This
334+
// creates an ambiguity between a legacy payload and an empty TLV payload.
319335
func isLegacyPayloadByte(b byte) bool {
320336
return b == 0x00
321337
}

sphinx.go

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,8 @@ func (r *Router) Stop() {
510510
// processOnionCfg is a set of config values that can be used to modify how an
511511
// onion is processed.
512512
type processOnionCfg struct {
513-
blindingPoint *btcec.PublicKey
513+
blindingPoint *btcec.PublicKey
514+
tlvPayloadOnly bool
514515
}
515516

516517
// ProcessOnionOpt defines the signature of a function option that can be used
@@ -525,6 +526,14 @@ func WithBlindingPoint(point *btcec.PublicKey) ProcessOnionOpt {
525526
}
526527
}
527528

529+
// WithTLVPayloadOnly is a functional option that signals that the onion packet
530+
// being processed is an onion_message_packet.
531+
func WithTLVPayloadOnly() ProcessOnionOpt {
532+
return func(cfg *processOnionCfg) {
533+
cfg.tlvPayloadOnly = true
534+
}
535+
}
536+
528537
// ProcessOnionPacket processes an incoming onion packet which has been forward
529538
// to the target Sphinx router. If the encoded ephemeral key isn't on the
530539
// target Elliptic Curve, then the packet is rejected. Similarly, if the
@@ -560,7 +569,9 @@ func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte,
560569
// Continue to optimistically process this packet, deferring replay
561570
// protection until the end to reduce the penalty of multiple IO
562571
// operations.
563-
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
572+
packet, err := processOnionPacket(
573+
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
574+
)
564575
if err != nil {
565576
return nil, err
566577
}
@@ -594,7 +605,9 @@ func (r *Router) ReconstructOnionPacket(onionPkt *OnionPacket, assocData []byte,
594605
return nil, err
595606
}
596607

597-
return processOnionPacket(onionPkt, &sharedSecret, assocData)
608+
return processOnionPacket(
609+
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
610+
)
598611
}
599612

600613
// DecryptBlindedHopData uses the router's private key to decrypt data encrypted
@@ -625,7 +638,8 @@ func (r *Router) OnionPublicKey() *btcec.PublicKey {
625638
// packet. This function returns the next inner onion packet layer, along with
626639
// the hop data extracted from the outer onion packet.
627640
func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
628-
assocData []byte) (*OnionPacket, *HopPayload, error) {
641+
assocData []byte, tlvPayloadOnly bool) (*OnionPacket, *HopPayload,
642+
error) {
629643

630644
dhKey := onionPkt.EphemeralKey
631645
routeInfo := onionPkt.RoutingInfo
@@ -649,8 +663,8 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
649663
zeroBytes := bytes.Repeat([]byte{0}, MaxPayloadSize)
650664
headerWithPadding := append(routeInfo[:], zeroBytes...)
651665

652-
var hopInfo [numStreamBytes]byte
653-
xor(hopInfo[:], headerWithPadding, streamBytes)
666+
hopInfo := make([]byte, numStreamBytes)
667+
xor(hopInfo, headerWithPadding, streamBytes)
654668

655669
// Randomize the DH group element for the next hop using the
656670
// deterministic blinding factor.
@@ -660,8 +674,10 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
660674
// With the MAC checked, and the payload decrypted, we can now parse
661675
// out the payload so we can derive the specified forwarding
662676
// instructions.
663-
var hopPayload HopPayload
664-
if err := hopPayload.Decode(bytes.NewReader(hopInfo[:])); err != nil {
677+
hopPayload, err := DecodeHopPayload(
678+
bytes.NewReader(hopInfo), tlvPayloadOnly,
679+
)
680+
if err != nil {
665681
return nil, nil, err
666682
}
667683

@@ -676,14 +692,14 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
676692
HeaderMAC: hopPayload.HMAC,
677693
}
678694

679-
return innerPkt, &hopPayload, nil
695+
return innerPkt, hopPayload, nil
680696
}
681697

682698
// processOnionPacket performs the primary key derivation and handling of onion
683699
// packets. The processed packets returned from this method should only be used
684700
// if the packet was not flagged as a replayed packet.
685701
func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
686-
assocData []byte) (*ProcessedPacket, error) {
702+
assocData []byte, tlvPayloadOnly bool) (*ProcessedPacket, error) {
687703

688704
// First, we'll unwrap an initial layer of the onion packet. Typically,
689705
// we'll only have a single layer to unwrap, However, if the sender has
@@ -693,7 +709,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
693709
// they can properly check the HMAC and unwrap a layer for their
694710
// handoff hop.
695711
innerPkt, outerHopPayload, err := unwrapPacket(
696-
onionPkt, sharedSecret, assocData,
712+
onionPkt, sharedSecret, assocData, tlvPayloadOnly,
697713
)
698714
if err != nil {
699715
return nil, err
@@ -703,7 +719,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
703719
// However if the uncovered 'nextMac' is all zeroes, then this
704720
// indicates that we're the final hop in the route.
705721
var action ProcessCode = MoreHops
706-
if bytes.Compare(zeroHMAC[:], outerHopPayload.HMAC[:]) == 0 {
722+
if bytes.Equal(zeroHMAC[:], outerHopPayload.HMAC[:]) {
707723
action = ExitNode
708724
}
709725

@@ -794,7 +810,9 @@ func (t *Tx) ProcessOnionPacket(seqNum uint16, onionPkt *OnionPacket,
794810
// Continue to optimistically process this packet, deferring replay
795811
// protection until the end to reduce the penalty of multiple IO
796812
// operations.
797-
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
813+
packet, err := processOnionPacket(
814+
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
815+
)
798816
if err != nil {
799817
return err
800818
}

sphinx_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,62 @@ func TestTLVPayloadMessagePacket(t *testing.T) {
299299
hex.EncodeToString(finalPacket), hex.EncodeToString(b.Bytes()))
300300
}
301301

302+
// TestProcessOnionMessageZeroLengthPayload tests that we can properly process
303+
// an onion message that has a zero-length payload.
304+
func TestProcessOnionMessageZeroLengthPayload(t *testing.T) {
305+
t.Parallel()
306+
307+
// First, create a router that will be the destination of the onion
308+
// message.
309+
privKey, err := btcec.NewPrivateKey()
310+
require.NoError(t, err)
311+
312+
router := NewRouter(&PrivKeyECDH{privKey}, NewMemoryReplayLog())
313+
err = router.Start()
314+
require.NoError(t, err)
315+
316+
defer router.Stop()
317+
318+
// Next, create a session key for the onion packet.
319+
sessionKey, err := btcec.NewPrivateKey()
320+
require.NoError(t, err)
321+
322+
// We'll create a simple one-hop path.
323+
path := &PaymentPath{
324+
{
325+
NodePub: *privKey.PubKey(),
326+
},
327+
}
328+
329+
// The hop payload will be an empty TLV payload.
330+
payload, err := NewTLVHopPayload(nil)
331+
require.NoError(t, err)
332+
require.Empty(t, payload.Payload)
333+
path[0].HopPayload = payload
334+
335+
// Now, create the onion packet.
336+
onionPacket, err := NewOnionPacket(
337+
path, sessionKey, nil, DeterministicPacketFiller,
338+
)
339+
require.NoError(t, err)
340+
341+
// We'll now process the packet, making sure to indicate that this is
342+
// an onion message.
343+
processedPacket, err := router.ProcessOnionPacket(
344+
onionPacket, nil, 0, WithTLVPayloadOnly(),
345+
)
346+
require.NoError(t, err)
347+
348+
// The packet should be decoded as an exit node.
349+
require.EqualValues(t, ExitNode, processedPacket.Action)
350+
351+
// The payload should be of type TLV.
352+
require.Equal(t, PayloadTLV, processedPacket.Payload.Type)
353+
354+
// And the payload should be empty.
355+
require.Empty(t, processedPacket.Payload.Payload)
356+
}
357+
302358
func TestSphinxCorrectness(t *testing.T) {
303359
nodes, _, hopDatas, fwdMsg, err := newTestRoute(testLegacyRouteNumHops)
304360
if err != nil {

0 commit comments

Comments
 (0)