@@ -12,7 +12,9 @@ import {
1212} from '../bson' ;
1313import { type ProxyOptions } from '../cmap/connection' ;
1414import { getSocks , type SocksLib } from '../deps' ;
15+ import { MongoOperationTimeoutError } from '../error' ;
1516import { type MongoClient , type MongoClientOptions } from '../mongo_client' ;
17+ import { Timeout , type TimeoutContext , TimeoutError } from '../timeout' ;
1618import { BufferPool , MongoDBCollectionNamespace , promiseWithResolvers } from '../utils' ;
1719import { autoSelectSocketOptions , type DataKey } from './client_encryption' ;
1820import { MongoCryptError } from './errors' ;
@@ -173,6 +175,7 @@ export type StateMachineOptions = {
173175 * An internal class that executes across a MongoCryptContext until either
174176 * a finishing state or an error is reached. Do not instantiate directly.
175177 */
178+ // TODO(DRIVERS-2671): clarify CSOT behavior for FLE APIs
176179export class StateMachine {
177180 constructor (
178181 private options : StateMachineOptions ,
@@ -182,7 +185,11 @@ export class StateMachine {
182185 /**
183186 * Executes the state machine according to the specification
184187 */
185- async execute ( executor : StateMachineExecutable , context : MongoCryptContext ) : Promise < Uint8Array > {
188+ async execute (
189+ executor : StateMachineExecutable ,
190+ context : MongoCryptContext ,
191+ timeoutContext ?: TimeoutContext
192+ ) : Promise < Uint8Array > {
186193 const keyVaultNamespace = executor . _keyVaultNamespace ;
187194 const keyVaultClient = executor . _keyVaultClient ;
188195 const metaDataClient = executor . _metaDataClient ;
@@ -201,8 +208,13 @@ export class StateMachine {
201208 'unreachable state machine state: entered MONGOCRYPT_CTX_NEED_MONGO_COLLINFO but metadata client is undefined'
202209 ) ;
203210 }
204- const collInfo = await this . fetchCollectionInfo ( metaDataClient , context . ns , filter ) ;
205211
212+ const collInfo = await this . fetchCollectionInfo (
213+ metaDataClient ,
214+ context . ns ,
215+ filter ,
216+ timeoutContext
217+ ) ;
206218 if ( collInfo ) {
207219 context . addMongoOperationResponse ( collInfo ) ;
208220 }
@@ -222,9 +234,9 @@ export class StateMachine {
222234 // When we are using the shared library, we don't have a mongocryptd manager.
223235 const markedCommand : Uint8Array = mongocryptdManager
224236 ? await mongocryptdManager . withRespawn (
225- this . markCommand . bind ( this , mongocryptdClient , context . ns , command )
237+ this . markCommand . bind ( this , mongocryptdClient , context . ns , command , timeoutContext )
226238 )
227- : await this . markCommand ( mongocryptdClient , context . ns , command ) ;
239+ : await this . markCommand ( mongocryptdClient , context . ns , command , timeoutContext ) ;
228240
229241 context . addMongoOperationResponse ( markedCommand ) ;
230242 context . finishMongoOperation ( ) ;
@@ -233,7 +245,12 @@ export class StateMachine {
233245
234246 case MONGOCRYPT_CTX_NEED_MONGO_KEYS : {
235247 const filter = context . nextMongoOperation ( ) ;
236- const keys = await this . fetchKeys ( keyVaultClient , keyVaultNamespace , filter ) ;
248+ const keys = await this . fetchKeys (
249+ keyVaultClient ,
250+ keyVaultNamespace ,
251+ filter ,
252+ timeoutContext
253+ ) ;
237254
238255 if ( keys . length === 0 ) {
239256 // See docs on EMPTY_V
@@ -255,9 +272,7 @@ export class StateMachine {
255272 }
256273
257274 case MONGOCRYPT_CTX_NEED_KMS : {
258- const requests = Array . from ( this . requests ( context ) ) ;
259- await Promise . all ( requests ) ;
260-
275+ await Promise . all ( this . requests ( context , timeoutContext ) ) ;
261276 context . finishKMSRequests ( ) ;
262277 break ;
263278 }
@@ -299,7 +314,7 @@ export class StateMachine {
299314 * @param kmsContext - A C++ KMS context returned from the bindings
300315 * @returns A promise that resolves when the KMS reply has be fully parsed
301316 */
302- async kmsRequest ( request : MongoCryptKMSRequest ) : Promise < void > {
317+ async kmsRequest ( request : MongoCryptKMSRequest , timeoutContext ?: TimeoutContext ) : Promise < void > {
303318 const parsedUrl = request . endpoint . split ( ':' ) ;
304319 const port = parsedUrl [ 1 ] != null ? Number . parseInt ( parsedUrl [ 1 ] , 10 ) : HTTPS_PORT ;
305320 const socketOptions = autoSelectSocketOptions ( this . options . socketOptions || { } ) ;
@@ -329,10 +344,6 @@ export class StateMachine {
329344 }
330345 }
331346
332- function ontimeout ( ) {
333- return new MongoCryptError ( 'KMS request timed out' ) ;
334- }
335-
336347 function onerror ( cause : Error ) {
337348 return new MongoCryptError ( 'KMS request failed' , { cause } ) ;
338349 }
@@ -364,7 +375,6 @@ export class StateMachine {
364375 resolve : resolveOnNetSocketConnect
365376 } = promiseWithResolvers < void > ( ) ;
366377 netSocket
367- . once ( 'timeout' , ( ) => rejectOnNetSocketError ( ontimeout ( ) ) )
368378 . once ( 'error' , err => rejectOnNetSocketError ( onerror ( err ) ) )
369379 . once ( 'close' , ( ) => rejectOnNetSocketError ( onclose ( ) ) )
370380 . once ( 'connect' , ( ) => resolveOnNetSocketConnect ( ) ) ;
@@ -410,8 +420,8 @@ export class StateMachine {
410420 reject : rejectOnTlsSocketError ,
411421 resolve
412422 } = promiseWithResolvers < void > ( ) ;
423+
413424 socket
414- . once ( 'timeout' , ( ) => rejectOnTlsSocketError ( ontimeout ( ) ) )
415425 . once ( 'error' , err => rejectOnTlsSocketError ( onerror ( err ) ) )
416426 . once ( 'close' , ( ) => rejectOnTlsSocketError ( onclose ( ) ) )
417427 . on ( 'data' , data => {
@@ -425,20 +435,26 @@ export class StateMachine {
425435 resolve ( ) ;
426436 }
427437 } ) ;
428- await willResolveKmsRequest ;
438+ await ( timeoutContext ?. csotEnabled ( )
439+ ? Promise . all ( [ willResolveKmsRequest , Timeout . expires ( timeoutContext ?. remainingTimeMS ) ] )
440+ : willResolveKmsRequest ) ;
441+ } catch ( error ) {
442+ if ( error instanceof TimeoutError )
443+ throw new MongoOperationTimeoutError ( 'KMS request timed out' ) ;
444+ throw error ;
429445 } finally {
430446 // There's no need for any more activity on this socket at this point.
431447 destroySockets ( ) ;
432448 }
433449 }
434450
435- * requests ( context : MongoCryptContext ) {
451+ * requests ( context : MongoCryptContext , timeoutContext ?: TimeoutContext ) {
436452 for (
437453 let request = context . nextKMSRequest ( ) ;
438454 request != null ;
439455 request = context . nextKMSRequest ( )
440456 ) {
441- yield this . kmsRequest ( request ) ;
457+ yield this . kmsRequest ( request , timeoutContext ) ;
442458 }
443459 }
444460
@@ -498,15 +514,19 @@ export class StateMachine {
498514 async fetchCollectionInfo (
499515 client : MongoClient ,
500516 ns : string ,
501- filter : Document
517+ filter : Document ,
518+ timeoutContext ?: TimeoutContext
502519 ) : Promise < Uint8Array | null > {
503520 const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
504521
505522 const collections = await client
506523 . db ( db )
507524 . listCollections ( filter , {
508525 promoteLongs : false ,
509- promoteValues : false
526+ promoteValues : false ,
527+ ...( timeoutContext ?. csotEnabled ( )
528+ ? { timeoutMS : timeoutContext ?. remainingTimeMS , timeoutMode : 'cursorLifetime' }
529+ : { } )
510530 } )
511531 . toArray ( ) ;
512532
@@ -522,12 +542,22 @@ export class StateMachine {
522542 * @param command - The command to execute.
523543 * @param callback - Invoked with the serialized and marked bson command, or with an error
524544 */
525- async markCommand ( client : MongoClient , ns : string , command : Uint8Array ) : Promise < Uint8Array > {
526- const options = { promoteLongs : false , promoteValues : false } ;
545+ async markCommand (
546+ client : MongoClient ,
547+ ns : string ,
548+ command : Uint8Array ,
549+ timeoutContext ?: TimeoutContext
550+ ) : Promise < Uint8Array > {
527551 const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
528- const rawCommand = deserialize ( command , options ) ;
552+ const bsonOptions = { promoteLongs : false , promoteValues : false } ;
553+ const rawCommand = deserialize ( command , bsonOptions ) ;
529554
530- const response = await client . db ( db ) . command ( rawCommand , options ) ;
555+ const response = await client . db ( db ) . command ( rawCommand , {
556+ ...bsonOptions ,
557+ ...( timeoutContext ?. csotEnabled ( )
558+ ? { timeoutMS : timeoutContext ?. remainingTimeMS }
559+ : undefined )
560+ } ) ;
531561
532562 return serialize ( response , this . bsonOptions ) ;
533563 }
@@ -543,15 +573,21 @@ export class StateMachine {
543573 fetchKeys (
544574 client : MongoClient ,
545575 keyVaultNamespace : string ,
546- filter : Uint8Array
576+ filter : Uint8Array ,
577+ timeoutContext ?: TimeoutContext
547578 ) : Promise < Array < DataKey > > {
548579 const { db : dbName , collection : collectionName } =
549580 MongoDBCollectionNamespace . fromString ( keyVaultNamespace ) ;
550581
551582 return client
552583 . db ( dbName )
553584 . collection < DataKey > ( collectionName , { readConcern : { level : 'majority' } } )
554- . find ( deserialize ( filter ) )
585+ . find (
586+ deserialize ( filter ) ,
587+ timeoutContext ?. csotEnabled ( )
588+ ? { timeoutMS : timeoutContext ?. remainingTimeMS , timeoutMode : 'cursorLifetime' }
589+ : { }
590+ )
555591 . toArray ( ) ;
556592 }
557593}
0 commit comments