diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b59b9da9..c2d03b21f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. package-level function getNumberLength (#219) - Datetime location after encode + decode is unequal (#217) - Wrong interval arithmetic with timezones (#221) +- Invalid MsgPack if STREAM_ID > 127 (#224) ## [1.8.0] - 2022-08-17 diff --git a/connection.go b/connection.go index 3fe6fec26..a4ae8cc36 100644 --- a/connection.go +++ b/connection.go @@ -6,10 +6,12 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "errors" "fmt" "io" "log" + "math" "net" "runtime" "sync" @@ -531,23 +533,38 @@ func (conn *Connection) dial() (err error) { func pack(h *smallWBuf, enc *encoder, reqid uint32, req Request, streamId uint64, res SchemaResolver) (err error) { + const uint32Code = 0xce + const uint64Code = 0xcf + const streamBytesLenUint64 = 10 + const streamBytesLenUint32 = 6 + hl := h.Len() + var streamBytesLen = 0 + var streamBytes [streamBytesLenUint64]byte hMapLen := byte(0x82) // 2 element map. if streamId != ignoreStreamId { hMapLen = byte(0x83) // 3 element map. + streamBytes[0] = KeyStreamId + if streamId > math.MaxUint32 { + streamBytesLen = streamBytesLenUint64 + streamBytes[1] = uint64Code + binary.BigEndian.PutUint64(streamBytes[2:], streamId) + } else { + streamBytesLen = streamBytesLenUint32 + streamBytes[1] = uint32Code + binary.BigEndian.PutUint32(streamBytes[2:], uint32(streamId)) + } } - hBytes := []byte{ - 0xce, 0, 0, 0, 0, // Length. + + hBytes := append([]byte{ + uint32Code, 0, 0, 0, 0, // Length. hMapLen, KeyCode, byte(req.Code()), // Request code. - KeySync, 0xce, + KeySync, uint32Code, byte(reqid >> 24), byte(reqid >> 16), byte(reqid >> 8), byte(reqid), - } - if streamId != ignoreStreamId { - hBytes = append(hBytes, KeyStreamId, byte(streamId)) - } + }, streamBytes[:streamBytesLen]...) h.Write(hBytes) diff --git a/tarantool_test.go b/tarantool_test.go index a86fbb716..0037aea8c 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "math" "os" "reflect" "runtime" @@ -2442,6 +2443,38 @@ func TestComplexStructs(t *testing.T) { } } +func TestStream_IdValues(t *testing.T) { + test_helpers.SkipIfStreamsUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + cases := []uint64{ + 1, + 128, + math.MaxUint8, + math.MaxUint8 + 1, + math.MaxUint16, + math.MaxUint16 + 1, + math.MaxUint32, + math.MaxUint32 + 1, + math.MaxUint64, + } + + stream, _ := conn.NewStream() + req := NewPingRequest() + + for _, id := range cases { + t.Run(fmt.Sprintf("%d", id), func(t *testing.T) { + stream.Id = id + _, err := stream.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Ping: %s", err.Error()) + } + }) + } +} + func TestStream_Commit(t *testing.T) { var req Request var resp *Response