@@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash;
16
16
use crate :: config:: { get_config, get_idle_client_in_transaction_timeout, Address , PoolMode } ;
17
17
use crate :: constants:: * ;
18
18
use crate :: messages:: * ;
19
+ use crate :: plugins:: PluginOutput ;
19
20
use crate :: pool:: { get_pool, ClientServerMap , ConnectionPool } ;
20
21
use crate :: query_router:: { Command , QueryRouter } ;
21
22
use crate :: server:: Server ;
@@ -765,6 +766,9 @@ where
765
766
766
767
self . stats . register ( self . stats . clone ( ) ) ;
767
768
769
+ // Error returned by one of the plugins.
770
+ let mut plugin_output = None ;
771
+
768
772
// Our custom protocol loop.
769
773
// We expect the client to either start a transaction with regular queries
770
774
// or issue commands for our sharding and server selection protocol.
@@ -816,6 +820,22 @@ where
816
820
'Q' => {
817
821
if query_router. query_parser_enabled ( ) {
818
822
if let Ok ( ast) = QueryRouter :: parse ( & message) {
823
+ let plugin_result = query_router. execute_plugins ( & ast) . await ;
824
+
825
+ match plugin_result {
826
+ Ok ( PluginOutput :: Deny ( error) ) => {
827
+ error_response ( & mut self . write , & error) . await ?;
828
+ continue ;
829
+ }
830
+
831
+ Ok ( PluginOutput :: Intercept ( result) ) => {
832
+ write_all ( & mut self . write , result) . await ?;
833
+ continue ;
834
+ }
835
+
836
+ _ => ( ) ,
837
+ } ;
838
+
819
839
let _ = query_router. infer ( & ast) ;
820
840
}
821
841
}
@@ -826,6 +846,10 @@ where
826
846
827
847
if query_router. query_parser_enabled ( ) {
828
848
if let Ok ( ast) = QueryRouter :: parse ( & message) {
849
+ if let Ok ( output) = query_router. execute_plugins ( & ast) . await {
850
+ plugin_output = Some ( output) ;
851
+ }
852
+
829
853
let _ = query_router. infer ( & ast) ;
830
854
}
831
855
}
@@ -861,6 +885,18 @@ where
861
885
continue ;
862
886
}
863
887
888
+ // Check on plugin results.
889
+ match plugin_output {
890
+ Some ( PluginOutput :: Deny ( error) ) => {
891
+ self . buffer . clear ( ) ;
892
+ error_response ( & mut self . write , & error) . await ?;
893
+ plugin_output = None ;
894
+ continue ;
895
+ }
896
+
897
+ _ => ( ) ,
898
+ } ;
899
+
864
900
// Get a pool instance referenced by the most up-to-date
865
901
// pointer. This ensures we always read the latest config
866
902
// when starting a query.
@@ -1089,6 +1125,27 @@ where
1089
1125
match code {
1090
1126
// Query
1091
1127
'Q' => {
1128
+ if query_router. query_parser_enabled ( ) {
1129
+ if let Ok ( ast) = QueryRouter :: parse ( & message) {
1130
+ let plugin_result = query_router. execute_plugins ( & ast) . await ;
1131
+
1132
+ match plugin_result {
1133
+ Ok ( PluginOutput :: Deny ( error) ) => {
1134
+ error_response ( & mut self . write , & error) . await ?;
1135
+ continue ;
1136
+ }
1137
+
1138
+ Ok ( PluginOutput :: Intercept ( result) ) => {
1139
+ write_all ( & mut self . write , result) . await ?;
1140
+ continue ;
1141
+ }
1142
+
1143
+ _ => ( ) ,
1144
+ } ;
1145
+
1146
+ let _ = query_router. infer ( & ast) ;
1147
+ }
1148
+ }
1092
1149
debug ! ( "Sending query to server" ) ;
1093
1150
1094
1151
self . send_and_receive_loop (
@@ -1128,6 +1185,14 @@ where
1128
1185
// Parse
1129
1186
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
1130
1187
'P' => {
1188
+ if query_router. query_parser_enabled ( ) {
1189
+ if let Ok ( ast) = QueryRouter :: parse ( & message) {
1190
+ if let Ok ( output) = query_router. execute_plugins ( & ast) . await {
1191
+ plugin_output = Some ( output) ;
1192
+ }
1193
+ }
1194
+ }
1195
+
1131
1196
self . buffer . put ( & message[ ..] ) ;
1132
1197
}
1133
1198
@@ -1159,6 +1224,24 @@ where
1159
1224
'S' => {
1160
1225
debug ! ( "Sending query to server" ) ;
1161
1226
1227
+ match plugin_output {
1228
+ Some ( PluginOutput :: Deny ( error) ) => {
1229
+ error_response ( & mut self . write , & error) . await ?;
1230
+ plugin_output = None ;
1231
+ self . buffer . clear ( ) ;
1232
+ continue ;
1233
+ }
1234
+
1235
+ Some ( PluginOutput :: Intercept ( result) ) => {
1236
+ write_all ( & mut self . write , result) . await ?;
1237
+ plugin_output = None ;
1238
+ self . buffer . clear ( ) ;
1239
+ continue ;
1240
+ }
1241
+
1242
+ _ => ( ) ,
1243
+ } ;
1244
+
1162
1245
self . buffer . put ( & message[ ..] ) ;
1163
1246
1164
1247
let first_message_code = ( * self . buffer . get ( 0 ) . unwrap_or ( & 0 ) ) as char ;
0 commit comments