diff --git a/utils.go b/utils.go index a20da9957..9ce8e503b 100644 --- a/utils.go +++ b/utils.go @@ -657,7 +657,7 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { case 0xfe: return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | - uint64(b[7])<<48 | uint64(b[8])<<54, + uint64(b[7])<<48 | uint64(b[8])<<56, false, 9 } @@ -677,5 +677,6 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { case n <= 0xffffff: return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) } - return b + return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), + byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) } diff --git a/utils_test.go b/utils_test.go index d19c92731..088077a7c 100644 --- a/utils_test.go +++ b/utils_test.go @@ -11,6 +11,7 @@ package mysql import ( "fmt" "testing" + "bytes" "time" ) @@ -122,3 +123,42 @@ func TestScanNullTime(t *testing.T) { } } } + +func TestLengthEncodedInteger(t *testing.T) { + var integerTests = []struct { + num uint64 + encoded []byte + }{ + {0x0000000000000000, []byte{0x00}}, + {0x0000000000000012, []byte{0x12}}, + {0x00000000000000fa, []byte{0xfa}}, + {0x0000000000000100, []byte{0xfc, 0x00, 0x01}}, + {0x0000000000001234, []byte{0xfc, 0x34, 0x12}}, + {0x000000000000ffff, []byte{0xfc, 0xff, 0xff}}, + {0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}}, + {0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}}, + {0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}}, + {0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}}, + {0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}}, + {0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + } + + for _, tst := range integerTests { + num, isNull, numLen := readLengthEncodedInteger(tst.encoded) + if isNull { + t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num) + } + if num != tst.num { + t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num) + } + if numLen != len(tst.encoded) { + t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen) + } + encoded := appendLengthEncodedInteger(nil, num) + if (!bytes.Equal(encoded, tst.encoded)) { + t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded) + } + } + + +}