@@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash;
1616use crate :: config:: { get_config, get_idle_client_in_transaction_timeout, Address , PoolMode } ;
1717use crate :: constants:: * ;
1818use crate :: messages:: * ;
19+ use crate :: plugins:: PluginOutput ;
1920use crate :: pool:: { get_pool, ClientServerMap , ConnectionPool } ;
2021use crate :: query_router:: { Command , QueryRouter } ;
2122use crate :: server:: Server ;
@@ -765,6 +766,9 @@ where
765766
766767 self . stats . register ( self . stats . clone ( ) ) ;
767768
769+ // Error returned by one of the plugins.
770+ let mut plugin_output = None ;
771+
768772 // Our custom protocol loop.
769773 // We expect the client to either start a transaction with regular queries
770774 // or issue commands for our sharding and server selection protocol.
@@ -816,6 +820,22 @@ where
816820 'Q' => {
817821 if query_router. query_parser_enabled ( ) {
818822 if let Ok ( ast) = QueryRouter :: parse ( & message) {
823+ let plugin_result = query_router. execute_plugins ( & ast) . await ;
824+
825+ match plugin_result {
826+ Ok ( PluginOutput :: Deny ( error) ) => {
827+ error_response ( & mut self . write , & error) . await ?;
828+ continue ;
829+ }
830+
831+ Ok ( PluginOutput :: Intercept ( result) ) => {
832+ write_all ( & mut self . write , result) . await ?;
833+ continue ;
834+ }
835+
836+ _ => ( ) ,
837+ } ;
838+
819839 let _ = query_router. infer ( & ast) ;
820840 }
821841 }
@@ -826,6 +846,10 @@ where
826846
827847 if query_router. query_parser_enabled ( ) {
828848 if let Ok ( ast) = QueryRouter :: parse ( & message) {
849+ if let Ok ( output) = query_router. execute_plugins ( & ast) . await {
850+ plugin_output = Some ( output) ;
851+ }
852+
829853 let _ = query_router. infer ( & ast) ;
830854 }
831855 }
@@ -861,6 +885,18 @@ where
861885 continue ;
862886 }
863887
888+ // Check on plugin results.
889+ match plugin_output {
890+ Some ( PluginOutput :: Deny ( error) ) => {
891+ self . buffer . clear ( ) ;
892+ error_response ( & mut self . write , & error) . await ?;
893+ plugin_output = None ;
894+ continue ;
895+ }
896+
897+ _ => ( ) ,
898+ } ;
899+
864900 // Get a pool instance referenced by the most up-to-date
865901 // pointer. This ensures we always read the latest config
866902 // when starting a query.
@@ -1089,6 +1125,27 @@ where
10891125 match code {
10901126 // Query
10911127 'Q' => {
1128+ if query_router. query_parser_enabled ( ) {
1129+ if let Ok ( ast) = QueryRouter :: parse ( & message) {
1130+ let plugin_result = query_router. execute_plugins ( & ast) . await ;
1131+
1132+ match plugin_result {
1133+ Ok ( PluginOutput :: Deny ( error) ) => {
1134+ error_response ( & mut self . write , & error) . await ?;
1135+ continue ;
1136+ }
1137+
1138+ Ok ( PluginOutput :: Intercept ( result) ) => {
1139+ write_all ( & mut self . write , result) . await ?;
1140+ continue ;
1141+ }
1142+
1143+ _ => ( ) ,
1144+ } ;
1145+
1146+ let _ = query_router. infer ( & ast) ;
1147+ }
1148+ }
10921149 debug ! ( "Sending query to server" ) ;
10931150
10941151 self . send_and_receive_loop (
@@ -1128,6 +1185,14 @@ where
11281185 // Parse
11291186 // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
11301187 'P' => {
1188+ if query_router. query_parser_enabled ( ) {
1189+ if let Ok ( ast) = QueryRouter :: parse ( & message) {
1190+ if let Ok ( output) = query_router. execute_plugins ( & ast) . await {
1191+ plugin_output = Some ( output) ;
1192+ }
1193+ }
1194+ }
1195+
11311196 self . buffer . put ( & message[ ..] ) ;
11321197 }
11331198
@@ -1159,6 +1224,24 @@ where
11591224 'S' => {
11601225 debug ! ( "Sending query to server" ) ;
11611226
1227+ match plugin_output {
1228+ Some ( PluginOutput :: Deny ( error) ) => {
1229+ error_response ( & mut self . write , & error) . await ?;
1230+ plugin_output = None ;
1231+ self . buffer . clear ( ) ;
1232+ continue ;
1233+ }
1234+
1235+ Some ( PluginOutput :: Intercept ( result) ) => {
1236+ write_all ( & mut self . write , result) . await ?;
1237+ plugin_output = None ;
1238+ self . buffer . clear ( ) ;
1239+ continue ;
1240+ }
1241+
1242+ _ => ( ) ,
1243+ } ;
1244+
11621245 self . buffer . put ( & message[ ..] ) ;
11631246
11641247 let first_message_code = ( * self . buffer . get ( 0 ) . unwrap_or ( & 0 ) ) as char ;
0 commit comments