@@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor';
1515import { getSocks , type SocksLib } from '../deps' ;
1616import { MongoOperationTimeoutError } from '../error' ;
1717import { type MongoClient , type MongoClientOptions } from '../mongo_client' ;
18+ import { type Abortable } from '../mongo_types' ;
1819import { Timeout , type TimeoutContext , TimeoutError } from '../timeout' ;
19- import { BufferPool , MongoDBCollectionNamespace , promiseWithResolvers } from '../utils' ;
20+ import {
21+ addAbortListener ,
22+ BufferPool ,
23+ kDispose ,
24+ MongoDBCollectionNamespace ,
25+ promiseWithResolvers
26+ } from '../utils' ;
2027import { autoSelectSocketOptions , type DataKey } from './client_encryption' ;
2128import { MongoCryptError } from './errors' ;
2229import { type MongocryptdManager } from './mongocryptd_manager' ;
@@ -189,7 +196,7 @@ export class StateMachine {
189196 async execute (
190197 executor : StateMachineExecutable ,
191198 context : MongoCryptContext ,
192- timeoutContext ?: TimeoutContext
199+ options : { timeoutContext ?: TimeoutContext } & Abortable
193200 ) : Promise < Uint8Array > {
194201 const keyVaultNamespace = executor . _keyVaultNamespace ;
195202 const keyVaultClient = executor . _keyVaultClient ;
@@ -199,6 +206,7 @@ export class StateMachine {
199206 let result : Uint8Array | null = null ;
200207
201208 while ( context . state !== MONGOCRYPT_CTX_DONE && context . state !== MONGOCRYPT_CTX_ERROR ) {
209+ options . signal ?. throwIfAborted ( ) ;
202210 debug ( `[context#${ context . id } ] ${ stateToString . get ( context . state ) || context . state } ` ) ;
203211
204212 switch ( context . state ) {
@@ -214,7 +222,7 @@ export class StateMachine {
214222 metaDataClient ,
215223 context . ns ,
216224 filter ,
217- timeoutContext
225+ options
218226 ) ;
219227 if ( collInfo ) {
220228 context . addMongoOperationResponse ( collInfo ) ;
@@ -235,9 +243,9 @@ export class StateMachine {
235243 // When we are using the shared library, we don't have a mongocryptd manager.
236244 const markedCommand : Uint8Array = mongocryptdManager
237245 ? await mongocryptdManager . withRespawn (
238- this . markCommand . bind ( this , mongocryptdClient , context . ns , command , timeoutContext )
246+ this . markCommand . bind ( this , mongocryptdClient , context . ns , command , options )
239247 )
240- : await this . markCommand ( mongocryptdClient , context . ns , command , timeoutContext ) ;
248+ : await this . markCommand ( mongocryptdClient , context . ns , command , options ) ;
241249
242250 context . addMongoOperationResponse ( markedCommand ) ;
243251 context . finishMongoOperation ( ) ;
@@ -246,12 +254,7 @@ export class StateMachine {
246254
247255 case MONGOCRYPT_CTX_NEED_MONGO_KEYS : {
248256 const filter = context . nextMongoOperation ( ) ;
249- const keys = await this . fetchKeys (
250- keyVaultClient ,
251- keyVaultNamespace ,
252- filter ,
253- timeoutContext
254- ) ;
257+ const keys = await this . fetchKeys ( keyVaultClient , keyVaultNamespace , filter , options ) ;
255258
256259 if ( keys . length === 0 ) {
257260 // See docs on EMPTY_V
@@ -273,7 +276,7 @@ export class StateMachine {
273276 }
274277
275278 case MONGOCRYPT_CTX_NEED_KMS : {
276- await Promise . all ( this . requests ( context , timeoutContext ) ) ;
279+ await Promise . all ( this . requests ( context , options ) ) ;
277280 context . finishKMSRequests ( ) ;
278281 break ;
279282 }
@@ -315,11 +318,13 @@ export class StateMachine {
315318 * @param kmsContext - A C++ KMS context returned from the bindings
316319 * @returns A promise that resolves when the KMS reply has be fully parsed
317320 */
318- async kmsRequest ( request : MongoCryptKMSRequest , timeoutContext ?: TimeoutContext ) : Promise < void > {
321+ async kmsRequest (
322+ request : MongoCryptKMSRequest ,
323+ options ?: { timeoutContext ?: TimeoutContext } & Abortable
324+ ) : Promise < void > {
319325 const parsedUrl = request . endpoint . split ( ':' ) ;
320326 const port = parsedUrl [ 1 ] != null ? Number . parseInt ( parsedUrl [ 1 ] , 10 ) : HTTPS_PORT ;
321- const socketOptions = autoSelectSocketOptions ( this . options . socketOptions || { } ) ;
322- const options : tls . ConnectionOptions & {
327+ const socketOptions : tls . ConnectionOptions & {
323328 host : string ;
324329 port : number ;
325330 autoSelectFamily ?: boolean ;
@@ -328,7 +333,7 @@ export class StateMachine {
328333 host : parsedUrl [ 0 ] ,
329334 servername : parsedUrl [ 0 ] ,
330335 port,
331- ...socketOptions
336+ ...autoSelectSocketOptions ( this . options . socketOptions || { } )
332337 } ;
333338 const message = request . message ;
334339 const buffer = new BufferPool ( ) ;
@@ -363,7 +368,7 @@ export class StateMachine {
363368 throw error ;
364369 }
365370 try {
366- await this . setTlsOptions ( providerTlsOptions , options ) ;
371+ await this . setTlsOptions ( providerTlsOptions , socketOptions ) ;
367372 } catch ( err ) {
368373 throw onerror ( err ) ;
369374 }
@@ -380,23 +385,25 @@ export class StateMachine {
380385 . once ( 'close' , ( ) => rejectOnNetSocketError ( onclose ( ) ) )
381386 . once ( 'connect' , ( ) => resolveOnNetSocketConnect ( ) ) ;
382387
388+ let abortListener ;
389+
383390 try {
384391 if ( this . options . proxyOptions && this . options . proxyOptions . proxyHost ) {
385392 const netSocketOptions = {
393+ ...socketOptions ,
386394 host : this . options . proxyOptions . proxyHost ,
387- port : this . options . proxyOptions . proxyPort || 1080 ,
388- ...socketOptions
395+ port : this . options . proxyOptions . proxyPort || 1080
389396 } ;
390397 netSocket . connect ( netSocketOptions ) ;
391398 await willConnect ;
392399
393400 try {
394401 socks ??= loadSocks ( ) ;
395- options . socket = (
402+ socketOptions . socket = (
396403 await socks . SocksClient . createConnection ( {
397404 existing_socket : netSocket ,
398405 command : 'connect' ,
399- destination : { host : options . host , port : options . port } ,
406+ destination : { host : socketOptions . host , port : socketOptions . port } ,
400407 proxy : {
401408 // host and port are ignored because we pass existing_socket
402409 host : 'iLoveJavaScript' ,
@@ -412,7 +419,7 @@ export class StateMachine {
412419 }
413420 }
414421
415- socket = tls . connect ( options , ( ) => {
422+ socket = tls . connect ( socketOptions , ( ) => {
416423 socket . write ( message ) ;
417424 } ) ;
418425
@@ -422,6 +429,11 @@ export class StateMachine {
422429 resolve
423430 } = promiseWithResolvers < void > ( ) ;
424431
432+ abortListener = addAbortListener ( options ?. signal , function ( ) {
433+ destroySockets ( ) ;
434+ rejectOnTlsSocketError ( this . reason ) ;
435+ } ) ;
436+
425437 socket
426438 . once ( 'error' , err => rejectOnTlsSocketError ( onerror ( err ) ) )
427439 . once ( 'close' , ( ) => rejectOnTlsSocketError ( onclose ( ) ) )
@@ -436,8 +448,11 @@ export class StateMachine {
436448 resolve ( ) ;
437449 }
438450 } ) ;
439- await ( timeoutContext ?. csotEnabled ( )
440- ? Promise . all ( [ willResolveKmsRequest , Timeout . expires ( timeoutContext ?. remainingTimeMS ) ] )
451+ await ( options ?. timeoutContext ?. csotEnabled ( )
452+ ? Promise . all ( [
453+ willResolveKmsRequest ,
454+ Timeout . expires ( options . timeoutContext ?. remainingTimeMS )
455+ ] )
441456 : willResolveKmsRequest ) ;
442457 } catch ( error ) {
443458 if ( error instanceof TimeoutError )
@@ -446,16 +461,17 @@ export class StateMachine {
446461 } finally {
447462 // There's no need for any more activity on this socket at this point.
448463 destroySockets ( ) ;
464+ abortListener ?. [ kDispose ] ( ) ;
449465 }
450466 }
451467
452- * requests ( context : MongoCryptContext , timeoutContext ?: TimeoutContext ) {
468+ * requests ( context : MongoCryptContext , options ?: { timeoutContext ?: TimeoutContext } & Abortable ) {
453469 for (
454470 let request = context . nextKMSRequest ( ) ;
455471 request != null ;
456472 request = context . nextKMSRequest ( )
457473 ) {
458- yield this . kmsRequest ( request , timeoutContext ) ;
474+ yield this . kmsRequest ( request , options ) ;
459475 }
460476 }
461477
@@ -516,14 +532,16 @@ export class StateMachine {
516532 client : MongoClient ,
517533 ns : string ,
518534 filter : Document ,
519- timeoutContext ?: TimeoutContext
535+ options ?: { timeoutContext ?: TimeoutContext } & Abortable
520536 ) : Promise < Uint8Array | null > {
521537 const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
522538
523539 const cursor = client . db ( db ) . listCollections ( filter , {
524540 promoteLongs : false ,
525541 promoteValues : false ,
526- timeoutContext : timeoutContext && new CursorTimeoutContext ( timeoutContext , Symbol ( ) )
542+ timeoutContext :
543+ options ?. timeoutContext && new CursorTimeoutContext ( options ?. timeoutContext , Symbol ( ) ) ,
544+ signal : options ?. signal
527545 } ) ;
528546
529547 // There is always exactly zero or one matching documents, so this should always exhaust the cursor
@@ -547,17 +565,30 @@ export class StateMachine {
547565 client : MongoClient ,
548566 ns : string ,
549567 command : Uint8Array ,
550- timeoutContext ?: TimeoutContext
568+ options ?: { timeoutContext ?: TimeoutContext } & Abortable
551569 ) : Promise < Uint8Array > {
552570 const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
553571 const bsonOptions = { promoteLongs : false , promoteValues : false } ;
554572 const rawCommand = deserialize ( command , bsonOptions ) ;
555573
574+ const commandOptions : {
575+ timeoutMS ?: number ;
576+ signal ?: AbortSignal ;
577+ } = {
578+ timeoutMS : undefined ,
579+ signal : undefined
580+ } ;
581+
582+ if ( options ?. timeoutContext ?. csotEnabled ( ) ) {
583+ commandOptions . timeoutMS = options . timeoutContext . remainingTimeMS ;
584+ }
585+ if ( options ?. signal ) {
586+ commandOptions . signal = options . signal ;
587+ }
588+
556589 const response = await client . db ( db ) . command ( rawCommand , {
557590 ...bsonOptions ,
558- ...( timeoutContext ?. csotEnabled ( )
559- ? { timeoutMS : timeoutContext ?. remainingTimeMS }
560- : undefined )
591+ ...commandOptions
561592 } ) ;
562593
563594 return serialize ( response , this . bsonOptions ) ;
@@ -575,17 +606,30 @@ export class StateMachine {
575606 client : MongoClient ,
576607 keyVaultNamespace : string ,
577608 filter : Uint8Array ,
578- timeoutContext ?: TimeoutContext
609+ options ?: { timeoutContext ?: TimeoutContext } & Abortable
579610 ) : Promise < Array < DataKey > > {
580611 const { db : dbName , collection : collectionName } =
581612 MongoDBCollectionNamespace . fromString ( keyVaultNamespace ) ;
582613
614+ const commandOptions : {
615+ timeoutContext ?: CursorTimeoutContext ;
616+ signal ?: AbortSignal ;
617+ } = {
618+ timeoutContext : undefined ,
619+ signal : undefined
620+ } ;
621+
622+ if ( options ?. timeoutContext != null ) {
623+ commandOptions . timeoutContext = new CursorTimeoutContext ( options . timeoutContext , Symbol ( ) ) ;
624+ }
625+ if ( options ?. signal != null ) {
626+ commandOptions . signal = options . signal ;
627+ }
628+
583629 return client
584630 . db ( dbName )
585631 . collection < DataKey > ( collectionName , { readConcern : { level : 'majority' } } )
586- . find ( deserialize ( filter ) , {
587- timeoutContext : timeoutContext && new CursorTimeoutContext ( timeoutContext , Symbol ( ) )
588- } )
632+ . find ( deserialize ( filter ) , commandOptions )
589633 . toArray ( ) ;
590634 }
591635}
0 commit comments