diff --git a/packets.go b/packets.go index 68f4378d8..21513b27b 100644 --- a/packets.go +++ b/packets.go @@ -1048,7 +1048,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { if rows.mc.parseTime { dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc) } else { - dest[i], err = formatBinaryDate(num, data[pos:]) + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], false) } if err == nil { @@ -1116,7 +1116,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { if rows.mc.parseTime { dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc) } else { - dest[i], err = formatBinaryDateTime(num, data[pos:]) + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], true) } if err == nil { diff --git a/utils.go b/utils.go index 9324806b2..b5100e465 100644 --- a/utils.go +++ b/utils.go @@ -503,55 +503,84 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) } -func formatBinaryDate(num uint64, data []byte) (driver.Value, error) { - switch num { - case 0: - return []byte("0000-00-00"), nil - case 4: - return []byte(fmt.Sprintf( - "%04d-%02d-%02d", - binary.LittleEndian.Uint16(data[:2]), - data[2], - data[3], - )), nil - } - return nil, fmt.Errorf("Invalid DATE-packet length %d", num) -} - -func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) { - switch num { - case 0: - return []byte("0000-00-00 00:00:00"), nil - case 4: - return []byte(fmt.Sprintf( - "%04d-%02d-%02d 00:00:00", - binary.LittleEndian.Uint16(data[:2]), - data[2], - data[3], - )), nil - case 7: - return []byte(fmt.Sprintf( - "%04d-%02d-%02d %02d:%02d:%02d", - binary.LittleEndian.Uint16(data[:2]), - data[2], - data[3], - data[4], - data[5], - data[6], - )), nil - case 11: - return []byte(fmt.Sprintf( - "%04d-%02d-%02d %02d:%02d:%02d.%06d", - binary.LittleEndian.Uint16(data[:2]), - data[2], - data[3], - data[4], - data[5], - data[6], - binary.LittleEndian.Uint32(data[7:11]), - )), nil +// zeroDateTime is used in formatBinaryDateTime to avoid an allocation +// if the DATE or DATETIME has the zero value. +// It must never be changed. +// The current behavior depends on database/sql copying the result. +var zeroDateTime = []byte("0000-00-00 00:00:00") + +func formatBinaryDateTime(src []byte, withTime bool) (driver.Value, error) { + if len(src) == 0 { + if withTime { + return zeroDateTime, nil + } + return zeroDateTime[:10], nil + } + var dst []byte + if withTime { + if len(src) == 11 { + dst = []byte("0000-00-00 00:00:00.000000") + } else { + dst = []byte("0000-00-00 00:00:00") + } + } else { + dst = []byte("0000-00-00") } - return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) + switch len(src) { + case 11: + microsecs := binary.LittleEndian.Uint32(src[7:11]) + tmp32 := microsecs / 10 + dst[25] += byte(microsecs - 10*tmp32) + tmp32, microsecs = tmp32/10, tmp32 + dst[24] += byte(microsecs - 10*tmp32) + tmp32, microsecs = tmp32/10, tmp32 + dst[23] += byte(microsecs - 10*tmp32) + tmp32, microsecs = tmp32/10, tmp32 + dst[22] += byte(microsecs - 10*tmp32) + tmp32, microsecs = tmp32/10, tmp32 + dst[21] += byte(microsecs - 10*tmp32) + dst[20] += byte(microsecs / 10) + fallthrough + case 7: + second := src[6] + tmp := second / 10 + dst[18] += second - 10*tmp + dst[17] += tmp + minute := src[5] + tmp = minute / 10 + dst[15] += minute - 10*tmp + dst[14] += tmp + hour := src[4] + tmp = hour / 10 + dst[12] += hour - 10*tmp + dst[11] += tmp + fallthrough + case 4: + day := src[3] + tmp := day / 10 + dst[9] += day - 10*tmp + dst[8] += tmp + month := src[2] + tmp = month / 10 + dst[6] += month - 10*tmp + dst[5] += tmp + year := binary.LittleEndian.Uint16(src[:2]) + tmp16 := year / 10 + dst[3] += byte(year - 10*tmp16) + tmp16, year = tmp16/10, tmp16 + dst[2] += byte(year - 10*tmp16) + tmp16, year = tmp16/10, tmp16 + dst[1] += byte(year - 10*tmp16) + dst[0] += byte(tmp16) + return dst, nil + } + var t string + if withTime { + t = "DATETIME" + } else { + t = "DATE" + } + return nil, fmt.Errorf("invalid %s-packet length %d", t, len(src)) } /****************************************************************************** diff --git a/utils_test.go b/utils_test.go index 233214d04..baf2b8c26 100644 --- a/utils_test.go +++ b/utils_test.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "encoding/binary" "fmt" "testing" "time" @@ -180,3 +181,32 @@ func TestOldPass(t *testing.T) { } } } + +func TestFormatBinaryDateTime(t *testing.T) { + rawDate := [11]byte{} + binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years + rawDate[2] = 12 // months + rawDate[3] = 30 // days + rawDate[4] = 15 // hours + rawDate[5] = 46 // minutes + rawDate[6] = 23 // seconds + binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds + expect := func(expected string, length int, withTime bool) { + actual, _ := formatBinaryDateTime(rawDate[:length], withTime) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for length %d, withTime %v", + bytes, actual, length, withTime, + ) + } + } + expect("0000-00-00", 0, false) + expect("0000-00-00 00:00:00", 0, true) + expect("1978-12-30", 4, false) + expect("1978-12-30 15:46:23", 7, true) + expect("1978-12-30 15:46:23.987654", 11, true) +}