@@ -24,6 +24,7 @@ import (
24
24
"os"
25
25
"reflect"
26
26
"runtime"
27
+ "strconv"
27
28
"strings"
28
29
"sync"
29
30
"sync/atomic"
@@ -3377,11 +3378,31 @@ func TestConnectionAttributes(t *testing.T) {
3377
3378
t .Skipf ("MySQL server not running on %s" , netAddr )
3378
3379
}
3379
3380
3380
- attr1 := "attr1"
3381
- value1 := "value1"
3382
- attr2 := "foo"
3383
- value2 := "boo"
3384
- dsn += fmt .Sprintf ("&connectionAttributes=%s:%s,%s:%s" , attr1 , value1 , attr2 , value2 )
3381
+ defaultAttrs := []string {
3382
+ connAttrClientName ,
3383
+ connAttrOS ,
3384
+ connAttrPlatform ,
3385
+ connAttrPid ,
3386
+ connAttrServerHost ,
3387
+ }
3388
+ host , _ , _ := net .SplitHostPort (addr )
3389
+ defaultAttrValues := []string {
3390
+ connAttrClientNameValue ,
3391
+ connAttrOSValue ,
3392
+ connAttrPlatformValue ,
3393
+ strconv .Itoa (os .Getpid ()),
3394
+ host ,
3395
+ }
3396
+
3397
+ customAttrs := []string {"attr1" , "attr2" }
3398
+ customAttrValues := []string {"foo" , "bar" }
3399
+
3400
+ customAttrStrs := make ([]string , len (customAttrs ))
3401
+ for i := range customAttrs {
3402
+ customAttrStrs [i ] = fmt .Sprintf ("%s:%s" , customAttrs [i ], customAttrValues [i ])
3403
+ }
3404
+
3405
+ dsn += fmt .Sprintf ("&connectionAttributes=%s" , strings .Join (customAttrStrs , "," ))
3385
3406
3386
3407
var db * sql.DB
3387
3408
if _ , err := ParseDSN (dsn ); err != errInvalidDSNUnsafeCollation {
@@ -3394,27 +3415,22 @@ func TestConnectionAttributes(t *testing.T) {
3394
3415
3395
3416
dbt := & DBTest {t , db }
3396
3417
3397
- var attrValue string
3398
- queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
3399
- rows := dbt .mustQuery (queryString , connAttrClientName )
3400
- if rows .Next () {
3401
- rows .Scan (& attrValue )
3402
- if attrValue != connAttrClientNameValue {
3403
- dbt .Errorf ("expected %q, got %q" , connAttrClientNameValue , attrValue )
3404
- }
3405
- } else {
3406
- dbt .Errorf ("no data" )
3418
+ queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()"
3419
+ rows := dbt .mustQuery (queryString )
3420
+ defer rows .Close ()
3421
+
3422
+ rowsMap := make (map [string ]string )
3423
+ for rows .Next () {
3424
+ var attrName , attrValue string
3425
+ rows .Scan (& attrName , & attrValue )
3426
+ rowsMap [attrName ] = attrValue
3407
3427
}
3408
- rows .Close ()
3409
3428
3410
- rows = dbt . mustQuery ( queryString , attr2 )
3411
- if rows . Next () {
3412
- rows . Scan ( & attrValue )
3413
- if attrValue != value2 {
3414
- dbt .Errorf ("expected %q , got %q " , value2 , attrValue )
3429
+ connAttrs := append ( append ([] string {}, defaultAttrs ... ), customAttrs ... )
3430
+ expectedAttrValues := append ( append ([] string {}, defaultAttrValues ... ), customAttrValues ... )
3431
+ for i := range connAttrs {
3432
+ if gotValue := rowsMap [ connAttrs [ i ]]; gotValue != expectedAttrValues [ i ] {
3433
+ dbt .Errorf ("expected %s , got %s " , expectedAttrValues [ i ], gotValue )
3415
3434
}
3416
- } else {
3417
- dbt .Errorf ("no data" )
3418
3435
}
3419
- rows .Close ()
3420
3436
}
0 commit comments