@@ -12,6 +12,7 @@ use sqlparser::dialect::PostgreSqlDialect;
1212use sqlparser:: parser:: Parser ;
1313
1414use crate :: config:: Role ;
15+ use crate :: errors:: Error ;
1516use crate :: messages:: BytesMutReader ;
1617use crate :: pool:: PoolSettings ;
1718use crate :: sharding:: Sharder ;
@@ -324,10 +325,7 @@ impl QueryRouter {
324325 Some ( ( command, value) )
325326 }
326327
327- /// Try to infer which server to connect to based on the contents of the query.
328- pub fn infer ( & mut self , message : & BytesMut ) -> bool {
329- debug ! ( "Inferring role" ) ;
330-
328+ pub fn parse ( message : & BytesMut ) -> Result < Vec < sqlparser:: ast:: Statement > , Error > {
331329 let mut message_cursor = Cursor :: new ( message) ;
332330
333331 let code = message_cursor. get_u8 ( ) as char ;
@@ -353,28 +351,33 @@ impl QueryRouter {
353351 query
354352 }
355353
356- _ => return false ,
354+ _ => return Err ( Error :: UnsupportedStatement ) ,
357355 } ;
358356
359- let ast = match Parser :: parse_sql ( & PostgreSqlDialect { } , & query) {
360- Ok ( ast) => ast,
357+ match Parser :: parse_sql ( & PostgreSqlDialect { } , & query) {
358+ Ok ( ast) => {
359+ debug ! ( "AST: {:?}" , ast) ;
360+ Ok ( ast)
361+ }
362+
361363 Err ( err) => {
362- // SELECT ... FOR UPDATE won't get parsed correctly.
363364 debug ! ( "{}: {}" , err, query) ;
364- self . active_role = Some ( Role :: Primary ) ;
365- return false ;
365+ Err ( Error :: QueryRouterParserError ( err. to_string ( ) ) )
366366 }
367- } ;
367+ }
368+ }
368369
369- debug ! ( "AST: {:?}" , ast) ;
370+ /// Try to infer which server to connect to based on the contents of the query.
371+ pub fn infer ( & mut self , ast : & Vec < sqlparser:: ast:: Statement > ) -> Result < ( ) , Error > {
372+ debug ! ( "Inferring role" ) ;
370373
371374 if ast. is_empty ( ) {
372375 // That's weird, no idea, let's go to primary
373376 self . active_role = Some ( Role :: Primary ) ;
374- return false ;
377+ return Err ( Error :: QueryRouterParserError ( "empty query" . into ( ) ) ) ;
375378 }
376379
377- for q in & ast {
380+ for q in ast {
378381 match q {
379382 // All transactions go to the primary, probably a write.
380383 StartTransaction { .. } => {
@@ -418,7 +421,7 @@ impl QueryRouter {
418421 } ;
419422 }
420423
421- true
424+ Ok ( ( ) )
422425 }
423426
424427 /// Parse the shard number from the Bind message
@@ -862,7 +865,7 @@ mod test {
862865
863866 for query in queries {
864867 // It's a recognized query
865- assert ! ( qr. infer( & query) ) ;
868+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
866869 assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
867870 }
868871 }
@@ -881,7 +884,7 @@ mod test {
881884
882885 for query in queries {
883886 // It's a recognized query
884- assert ! ( qr. infer( & query) ) ;
887+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
885888 assert_eq ! ( qr. role( ) , Some ( Role :: Primary ) ) ;
886889 }
887890 }
@@ -893,7 +896,7 @@ mod test {
893896 let query = simple_query ( "SELECT * FROM items WHERE id = 5" ) ;
894897 assert ! ( qr. try_execute_command( & simple_query( "SET PRIMARY READS TO on" ) ) != None ) ;
895898
896- assert ! ( qr. infer( & query) ) ;
899+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
897900 assert_eq ! ( qr. role( ) , None ) ;
898901 }
899902
@@ -913,7 +916,7 @@ mod test {
913916 res. put ( prepared_stmt) ;
914917 res. put_i16 ( 0 ) ;
915918
916- assert ! ( qr. infer( & res) ) ;
919+ assert ! ( qr. infer( & QueryRouter :: parse ( & res) . unwrap ( ) ) . is_ok ( ) ) ;
917920 assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
918921 }
919922
@@ -1077,11 +1080,11 @@ mod test {
10771080 assert_eq ! ( qr. role( ) , None ) ;
10781081
10791082 let query = simple_query ( "INSERT INTO test_table VALUES (1)" ) ;
1080- assert ! ( qr. infer( & query) ) ;
1083+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
10811084 assert_eq ! ( qr. role( ) , Some ( Role :: Primary ) ) ;
10821085
10831086 let query = simple_query ( "SELECT * FROM test_table" ) ;
1084- assert ! ( qr. infer( & query) ) ;
1087+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
10851088 assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
10861089
10871090 assert ! ( qr. query_parser_enabled( ) ) ;
@@ -1142,15 +1145,24 @@ mod test {
11421145 QueryRouter :: setup ( ) ;
11431146
11441147 let mut qr = QueryRouter :: new ( ) ;
1145- assert ! ( qr. infer( & simple_query( "BEGIN; SELECT 1; COMMIT;" ) ) ) ;
1148+ assert ! ( qr
1149+ . infer( & QueryRouter :: parse( & simple_query( "BEGIN; SELECT 1; COMMIT;" ) ) . unwrap( ) )
1150+ . is_ok( ) ) ;
11461151 assert_eq ! ( qr. role( ) , Role :: Primary ) ;
11471152
1148- assert ! ( qr. infer( & simple_query( "SELECT 1; SELECT 2;" ) ) ) ;
1153+ assert ! ( qr
1154+ . infer( & QueryRouter :: parse( & simple_query( "SELECT 1; SELECT 2;" ) ) . unwrap( ) )
1155+ . is_ok( ) ) ;
11491156 assert_eq ! ( qr. role( ) , Role :: Replica ) ;
11501157
1151- assert ! ( qr. infer( & simple_query(
1152- "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
1153- ) ) ) ;
1158+ assert ! ( qr
1159+ . infer(
1160+ & QueryRouter :: parse( & simple_query(
1161+ "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
1162+ ) )
1163+ . unwrap( )
1164+ )
1165+ . is_ok( ) ) ;
11541166 assert_eq ! ( qr. role( ) , Role :: Primary ) ;
11551167 }
11561168
@@ -1208,47 +1220,84 @@ mod test {
12081220 qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
12091221 qr. pool_settings . shards = 3 ;
12101222
1211- assert ! ( qr. infer( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) ) ;
1223+ assert ! ( qr
1224+ . infer( & QueryRouter :: parse( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) . unwrap( ) )
1225+ . is_ok( ) ) ;
12121226 assert_eq ! ( qr. shard( ) , 2 ) ;
12131227
1214- assert ! ( qr. infer( & simple_query(
1215- "SELECT one, two, three FROM public.data WHERE id = 6"
1216- ) ) ) ;
1228+ assert ! ( qr
1229+ . infer(
1230+ & QueryRouter :: parse( & simple_query(
1231+ "SELECT one, two, three FROM public.data WHERE id = 6"
1232+ ) )
1233+ . unwrap( )
1234+ )
1235+ . is_ok( ) ) ;
12171236 assert_eq ! ( qr. shard( ) , 0 ) ;
12181237
1219- assert ! ( qr. infer( & simple_query(
1220- "SELECT * FROM data
1238+ assert ! ( qr
1239+ . infer(
1240+ & QueryRouter :: parse( & simple_query(
1241+ "SELECT * FROM data
12211242 INNER JOIN t2 ON data.id = 5
12221243 AND t2.data_id = data.id
12231244 WHERE data.id = 5"
1224- ) ) ) ;
1245+ ) )
1246+ . unwrap( )
1247+ )
1248+ . is_ok( ) ) ;
12251249 assert_eq ! ( qr. shard( ) , 2 ) ;
12261250
12271251 // Shard did not move because we couldn't determine the sharding key since it could be ambiguous
12281252 // in the query.
1229- assert ! ( qr. infer( & simple_query(
1230- "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1231- ) ) ) ;
1253+ assert ! ( qr
1254+ . infer(
1255+ & QueryRouter :: parse( & simple_query(
1256+ "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1257+ ) )
1258+ . unwrap( )
1259+ )
1260+ . is_ok( ) ) ;
12321261 assert_eq ! ( qr. shard( ) , 2 ) ;
12331262
1234- assert ! ( qr. infer( & simple_query(
1235- r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1236- ) ) ) ;
1263+ assert ! ( qr
1264+ . infer(
1265+ & QueryRouter :: parse( & simple_query(
1266+ r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1267+ ) )
1268+ . unwrap( )
1269+ )
1270+ . is_ok( ) ) ;
12371271 assert_eq ! ( qr. shard( ) , 0 ) ;
12381272
1239- assert ! ( qr. infer( & simple_query(
1240- r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1241- ) ) ) ;
1273+ assert ! ( qr
1274+ . infer(
1275+ & QueryRouter :: parse( & simple_query(
1276+ r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1277+ ) )
1278+ . unwrap( )
1279+ )
1280+ . is_ok( ) ) ;
12421281 assert_eq ! ( qr. shard( ) , 2 ) ;
12431282
12441283 // Super unique sharding key
12451284 qr. pool_settings . automatic_sharding_key = Some ( "*.unique_enough_column_name" . to_string ( ) ) ;
1246- assert ! ( qr. infer( & simple_query(
1247- "SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1248- ) ) ) ;
1285+ assert ! ( qr
1286+ . infer(
1287+ & QueryRouter :: parse( & simple_query(
1288+ "SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1289+ ) )
1290+ . unwrap( )
1291+ )
1292+ . is_ok( ) ) ;
12491293 assert_eq ! ( qr. shard( ) , 0 ) ;
12501294
1251- assert ! ( qr. infer( & simple_query( "SELECT * FROM table_y WHERE another_key = 5" ) ) ) ;
1295+ assert ! ( qr
1296+ . infer(
1297+ & QueryRouter :: parse( & simple_query( "SELECT * FROM table_y WHERE another_key = 5" ) )
1298+ . unwrap( )
1299+ )
1300+ . is_ok( ) ) ;
12521301 assert_eq ! ( qr. shard( ) , 0 ) ;
12531302 }
12541303
@@ -1272,11 +1321,21 @@ mod test {
12721321 qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
12731322 qr. pool_settings . shards = 3 ;
12741323
1275- assert ! ( qr. infer( & simple_query( stmt) ) ) ;
1324+ assert ! ( qr
1325+ . infer( & QueryRouter :: parse( & simple_query( stmt) ) . unwrap( ) )
1326+ . is_ok( ) ) ;
12761327 assert_eq ! ( qr. placeholders. len( ) , 1 ) ;
12771328
12781329 assert ! ( qr. infer_shard_from_bind( & bind) ) ;
12791330 assert_eq ! ( qr. shard( ) , 2 ) ;
12801331 assert ! ( qr. placeholders. is_empty( ) ) ;
12811332 }
1333+
1334+ #[ test]
1335+ fn test_parse ( ) {
1336+ let query = simple_query ( "SELECT * FROM pg_database" ) ;
1337+ let ast = QueryRouter :: parse ( & query) ;
1338+
1339+ assert ! ( ast. is_ok( ) ) ;
1340+ }
12821341}
0 commit comments