@@ -13,7 +13,7 @@ use diesel::result::{ConnectionError, ConnectionResult};
1313use diesel:: QueryResult ;
1414use futures_core:: future:: BoxFuture ;
1515use futures_core:: stream:: BoxStream ;
16- use futures_util :: stream ;
16+ use futures_core :: Stream ;
1717use futures_util:: { FutureExt , StreamExt , TryStreamExt } ;
1818use mysql_async:: prelude:: Queryable ;
1919use mysql_async:: { Opts , OptsBuilder , Statement } ;
@@ -36,6 +36,7 @@ pub struct AsyncMysqlConnection {
3636 stmt_cache : StatementCache < Mysql , Statement > ,
3737 transaction_manager : AnsiTransactionManager ,
3838 instrumentation : DynInstrumentation ,
39+ stmt_to_free : Vec < mysql_async:: Statement > ,
3940}
4041
4142impl SimpleAsyncConnection for AsyncMysqlConnection {
@@ -81,48 +82,7 @@ impl AsyncConnectionCore for AsyncMysqlConnection {
8182 + ' query ,
8283 {
8384 self . with_prepared_statement ( source. as_query ( ) , |conn, stmt, binds| async move {
84- let stmt_for_exec = match stmt {
85- MaybeCached :: Cached ( ref s) => ( * s) . clone ( ) ,
86- MaybeCached :: CannotCache ( ref s) => s. clone ( ) ,
87- _ => unreachable ! (
88- "Diesel has only two variants here at the time of writing.\n \
89- If you ever see this error message please open in issue in the diesel-async issue tracker"
90- ) ,
91- } ;
92-
93- let ( tx, rx) = futures_channel:: mpsc:: channel ( 0 ) ;
94-
95- let yielder = async move {
96- let r = Self :: poll_result_stream ( conn, stmt_for_exec, binds, tx) . await ;
97- // We need to close any non-cached statement explicitly here as otherwise
98- // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
99- // for details
100- //
101- // This might be problematic for cases where the stream is dropped before the end is reached
102- //
103- // Such behaviour might happen if users:
104- // * Just drop the future/stream after polling at least once (timeouts!!)
105- // * Users only fetch a fixed number of elements from the stream
106- //
107- // For now there is not really a good solution to this problem as this would require something like async drop
108- // (and even with async drop that would be really hard to solve due to the involved lifetimes)
109- if let MaybeCached :: CannotCache ( stmt) = stmt {
110- conn. close ( stmt) . await . map_err ( ErrorHelper ) ?;
111- }
112- r
113- } ;
114-
115- let fake_stream = stream:: once ( yielder) . filter_map ( |e : QueryResult < ( ) > | async move {
116- if let Err ( e) = e {
117- Some ( Err ( e) )
118- } else {
119- None
120- }
121- } ) ;
122-
123- let stream = stream:: select ( fake_stream, rx) . boxed ( ) ;
124-
125- Ok ( stream)
85+ Ok ( Self :: poll_result_stream ( conn, stmt, binds) . await ?. boxed ( ) )
12686 } )
12787 . boxed ( )
12888 }
@@ -139,20 +99,6 @@ impl AsyncConnectionCore for AsyncMysqlConnection {
13999 self . with_prepared_statement ( source, |conn, stmt, binds| async move {
140100 let params = mysql_async:: Params :: try_from ( binds) ?;
141101 conn. exec_drop ( & * stmt, params) . await . map_err ( ErrorHelper ) ?;
142- // We need to close any non-cached statement explicitly here as otherwise
143- // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
144- // for details
145- //
146- // This might be problematic for cases where the stream is dropped before the end is reached
147- //
148- // Such behaviour might happen if users:
149- // * Just drop the future after polling at least once (timeouts!!)
150- //
151- // For now there is not really a good solution to this problem as this would require something like async drop
152- // (and even with async drop that would be really hard to solve due to the involved lifetimes)
153- if let MaybeCached :: CannotCache ( stmt) = stmt {
154- conn. close ( stmt) . await . map_err ( ErrorHelper ) ?;
155- }
156102 conn. affected_rows ( )
157103 . try_into ( )
158104 . map_err ( |e| diesel:: result:: Error :: DeserializationError ( Box :: new ( e) ) )
@@ -244,6 +190,7 @@ impl AsyncMysqlConnection {
244190 stmt_cache : StatementCache :: new ( ) ,
245191 transaction_manager : AnsiTransactionManager :: default ( ) ,
246192 instrumentation : DynInstrumentation :: default_instrumentation ( ) ,
193+ stmt_to_free : Vec :: new ( ) ,
247194 } ;
248195
249196 for stmt in CONNECTION_SETUP_QUERIES {
@@ -290,6 +237,7 @@ impl AsyncMysqlConnection {
290237 ref mut stmt_cache,
291238 ref mut transaction_manager,
292239 ref mut instrumentation,
240+ ref mut stmt_to_free,
293241 ..
294242 } = self ;
295243
@@ -299,6 +247,19 @@ impl AsyncMysqlConnection {
299247 let query_id = T :: query_id ( ) ;
300248
301249 async move {
250+ // We need to close any non-cached statement explicitly here as otherwise
251+ // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
252+ // and https://github.com/weiznich/diesel_async/issues/269 for details
253+ //
254+ // We remember these statements from the last run as there is currenly no relaible way to
255+ // run this as destruction step after the execution finished. Users might abort polling the future, etc
256+ //
257+ // The overhead for this is keeping one additional statement open until the connection is used
258+ // next, so you would need to have `max_prepared_stmt_count - 1` other statements open for this to cause issues.
259+ // This is hopefully not a problem in practice
260+ for stmt in std:: mem:: take ( stmt_to_free) {
261+ conn. close ( stmt) . await . map_err ( ErrorHelper ) ?;
262+ }
302263 let RawBytesBindCollector {
303264 metadata, binds, ..
304265 } = bind_collector?;
@@ -320,6 +281,10 @@ impl AsyncMysqlConnection {
320281 & mut * * instrumentation,
321282 )
322283 . await ?;
284+ // for any not cached statement we need to remember to close them on the next connection usage
285+ if let MaybeCached :: CannotCache ( stmt) = & stmt {
286+ stmt_to_free. push ( stmt. clone ( ) ) ;
287+ }
323288 callback ( conn, stmt, ToSqlHelper { metadata, binds } ) . await
324289 } ;
325290 let r = update_transaction_manager_status ( inner. await , transaction_manager) ;
@@ -332,21 +297,31 @@ impl AsyncMysqlConnection {
332297 . boxed ( )
333298 }
334299
335- async fn poll_result_stream (
336- conn : & mut mysql_async:: Conn ,
337- stmt_for_exec : mysql_async:: Statement ,
300+ async fn poll_result_stream < ' conn > (
301+ conn : & ' conn mut mysql_async:: Conn ,
302+ stmt : MaybeCached < ' _ , mysql_async:: Statement > ,
338303 binds : ToSqlHelper ,
339- mut tx : futures_channel:: mpsc:: Sender < QueryResult < MysqlRow > > ,
340- ) -> QueryResult < ( ) > {
341- use futures_util:: sink:: SinkExt ;
304+ ) -> QueryResult < impl Stream < Item = QueryResult < MysqlRow > > + Send + use < ' conn > > {
342305 let params = mysql_async:: Params :: try_from ( binds) ?;
306+ let stmt_for_exec = match stmt {
307+ MaybeCached :: Cached ( ref s) => {
308+ ( * s) . clone ( )
309+ } ,
310+ MaybeCached :: CannotCache ( ref s) => {
311+ s. clone ( )
312+ } ,
313+ _ => unreachable ! (
314+ "Diesel has only two variants here at the time of writing.\n \
315+ If you ever see this error message please open in issue in the diesel-async issue tracker"
316+ ) ,
317+ } ;
343318
344319 let res = conn
345320 . exec_iter ( stmt_for_exec, params)
346321 . await
347322 . map_err ( ErrorHelper ) ?;
348323
349- let mut stream = res
324+ let stream = res
350325 . stream_and_drop :: < MysqlRow > ( )
351326 . await
352327 . map_err ( ErrorHelper ) ?
@@ -357,14 +332,7 @@ impl AsyncMysqlConnection {
357332 } ) ?
358333 . map_err ( |e| diesel:: result:: Error :: from ( ErrorHelper ( e) ) ) ;
359334
360- while let Some ( row) = stream. next ( ) . await {
361- let row = row?;
362- tx. send ( Ok ( row) )
363- . await
364- . map_err ( |e| diesel:: result:: Error :: DeserializationError ( Box :: new ( e) ) ) ?;
365- }
366-
367- Ok ( ( ) )
335+ Ok ( stream)
368336 }
369337
370338 async fn establish_connection_inner (
@@ -384,6 +352,7 @@ impl AsyncMysqlConnection {
384352 stmt_cache : StatementCache :: new ( ) ,
385353 transaction_manager : AnsiTransactionManager :: default ( ) ,
386354 instrumentation : DynInstrumentation :: none ( ) ,
355+ stmt_to_free : Vec :: new ( ) ,
387356 } )
388357 }
389358}
@@ -404,39 +373,78 @@ mod tests {
404373 }
405374 include ! ( "../doctest_setup.rs" ) ;
406375
376+ const STMT_COUNT : usize = 16382 + 1000 ;
377+
378+ #[ derive( Queryable ) ]
379+ #[ expect( dead_code, reason = "used for the test as loading target" ) ]
380+ struct User {
381+ id : i32 ,
382+ name : String ,
383+ }
384+
407385 #[ tokio:: test]
408- async fn check_statements_are_dropped ( ) {
386+ async fn check_cached_statements_are_dropped ( ) {
409387 use self :: schema:: users;
410388
411389 let mut conn = establish_connection ( ) . await ;
412- // we cannot set a lower limit here without admin privileges
413- // which makes this test really slow
414- let stmt_count = 16382 + 10 ;
415390
416- for i in 0 ..stmt_count {
417- diesel :: insert_into ( users:: table)
418- . values ( Some ( users:: name . eq ( format ! ( "User{i}" ) ) ) )
419- . execute ( & mut conn)
391+ for _i in 0 ..STMT_COUNT {
392+ users:: table
393+ . select ( users:: id )
394+ . load :: < i32 > ( & mut conn)
420395 . await
421396 . unwrap ( ) ;
422397 }
398+ }
423399
424- #[ derive( QueryableByName ) ]
425- #[ diesel( table_name = users) ]
426- #[ allow( dead_code) ]
427- struct User {
428- id : i32 ,
429- name : String ,
430- }
400+ #[ tokio:: test]
401+ async fn check_uncached_statements_are_dropped ( ) {
402+ use self :: schema:: users;
403+
404+ let mut conn = establish_connection ( ) . await ;
431405
432- for i in 0 ..stmt_count {
433- diesel :: sql_query ( "SELECT id, name FROM users WHERE name = ?" )
434- . bind :: < diesel :: sql_types :: Text , _ > ( format ! ( "User{i}" ) )
406+ for _i in 0 ..STMT_COUNT {
407+ users :: table
408+ . filter ( users :: dsl :: id . eq_any ( & [ 1 , 2 ] ) )
435409 . load :: < User > ( & mut conn)
436410 . await
437411 . unwrap ( ) ;
438412 }
439413 }
414+
415+ #[ tokio:: test]
416+ async fn check_cached_statements_are_dropped_get_result ( ) {
417+ use self :: schema:: users;
418+ use diesel:: OptionalExtension ;
419+
420+ let mut conn = establish_connection ( ) . await ;
421+
422+ for _i in 0 ..STMT_COUNT {
423+ users:: table
424+ . select ( users:: id)
425+ . get_result :: < i32 > ( & mut conn)
426+ . await
427+ . optional ( )
428+ . unwrap ( ) ;
429+ }
430+ }
431+
432+ #[ tokio:: test]
433+ async fn check_uncached_statements_are_dropped_get_result ( ) {
434+ use self :: schema:: users;
435+ use diesel:: OptionalExtension ;
436+
437+ let mut conn = establish_connection ( ) . await ;
438+
439+ for _i in 0 ..STMT_COUNT {
440+ users:: table
441+ . filter ( users:: dsl:: id. eq_any ( & [ 1 , 2 ] ) )
442+ . get_result :: < User > ( & mut conn)
443+ . await
444+ . optional ( )
445+ . unwrap ( ) ;
446+ }
447+ }
440448}
441449
442450impl QueryFragmentForCachedStatement < Mysql > for QueryFragmentHelper {
0 commit comments