@@ -57,7 +57,8 @@ def _apply_compat(events):
57
57
yield event
58
58
59
59
60
- def _create_mock_client ():
60
+
61
+ def _create_mock_client (grpc_exception = None ):
61
62
# Create a stub instance (using a test channel) in order to derive a mock
62
63
# from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself
63
64
# doesn't work with autospec because grpc constructs stubs via metaclassing.
@@ -66,6 +67,30 @@ def _create_mock_client():
66
67
)
67
68
stub = write_service_pb2_grpc .TensorBoardWriterServiceStub (test_channel )
68
69
mock_client = mock .create_autospec (stub )
70
+
71
+ # Some surgery is required to make the mock object returned from the call to
72
+ # gRPC's `future` async endpoint act like the `grpc.Future` it is supposed
73
+ # to be. It needs to somewhat faithfully reproduce the following methods
74
+ # which are used in the grpc_util.async_call_with_retries machinery.
75
+ # * `done`
76
+ # * `result`
77
+ # * `exception`
78
+ grpc_future = mock .MagicMock ()
79
+ grpc_future .done .return_value = True
80
+ if grpc_exception :
81
+ grpc_future .result .side_effect = grpc_exception
82
+ grpc_future .exception .return_value = grpc_exception
83
+ else :
84
+ grpc_future .exception .return_value = None
85
+ mock_client .WriteScalar .future .return_value = grpc_future
86
+
87
+ # This is reimplementing the gRPC future behavior wherein a callback
88
+ # can be invoked when the future is complete. In this case, the callback
89
+ # must be invoked with the grpc_future as its argument.
90
+ def _invoke_callback (handler ):
91
+ handler (grpc_future )
92
+
93
+ grpc_future .add_done_callback .side_effect = _invoke_callback
69
94
return mock_client
70
95
71
96
@@ -117,8 +142,9 @@ def _add_events_and_flush(self, events):
117
142
)
118
143
self ._add_events (sender , "" , events )
119
144
sender .flush ()
145
+ sender .complete_all_pending_futures ()
120
146
121
- requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
147
+ requests = [c [0 ][0 ] for c in mock_client .WriteScalar .future . call_args_list ]
122
148
self .assertLen (requests , 1 )
123
149
self .assertLen (requests [0 ].runs , 1 )
124
150
return requests [0 ].runs [0 ]
@@ -205,18 +231,22 @@ def test_v2_summary(self):
205
231
)
206
232
self .assertProtoEquals (run_proto , expected_run_proto )
207
233
208
- def test_propagates_experiment_deletion (self ):
234
+ def test_async_propagates_experiment_deletion (self ):
235
+ # Setup: A client which return a grpc_error when asked to write.
209
236
event = event_pb2 .Event (step = 1 )
210
237
event .summary .value .add (tag = "foo" , simple_value = 1.0 )
211
238
212
- mock_client = _create_mock_client ()
239
+ error = test_util .grpc_error (grpc .StatusCode .NOT_FOUND , "nope" )
240
+ mock_client = _create_mock_client (grpc_exception = error )
213
241
sender = _create_scalar_request_sender ("123" , mock_client )
242
+ # Execute: sender.add_event on all the events.
214
243
self ._add_events (sender , "run" , _apply_compat ([event ]))
215
-
216
- error = test_util .grpc_error (grpc .StatusCode .NOT_FOUND , "nope" )
217
- mock_client .WriteScalar .side_effect = error
244
+ # Expect: Exception is raised, but it is the transforemd
245
+ # uploader exception, not the original grpc exception.
218
246
with self .assertRaises (uploader_errors .ExperimentNotFoundError ):
219
247
sender .flush ()
248
+ sender .complete_all_pending_futures ()
249
+
220
250
221
251
def test_no_budget_for_base_request (self ):
222
252
mock_client = _create_mock_client ()
@@ -263,7 +293,8 @@ def test_break_at_run_boundary(self):
263
293
self ._add_events (sender , long_run_1 , _apply_compat ([event_1 ]))
264
294
self ._add_events (sender , long_run_2 , _apply_compat ([event_2 ]))
265
295
sender .flush ()
266
- requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
296
+ sender .complete_all_pending_futures ()
297
+ requests = [c [0 ][0 ] for c in mock_client .WriteScalar .future .call_args_list ]
267
298
268
299
for request in requests :
269
300
_clear_wall_times (request )
@@ -306,7 +337,8 @@ def test_break_at_tag_boundary(self):
306
337
)
307
338
self ._add_events (sender , "train" , _apply_compat ([event ]))
308
339
sender .flush ()
309
- requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
340
+ sender .complete_all_pending_futures ()
341
+ requests = [c [0 ][0 ] for c in mock_client .WriteScalar .future .call_args_list ]
310
342
for request in requests :
311
343
_clear_wall_times (request )
312
344
@@ -352,7 +384,8 @@ def test_break_at_scalar_point_boundary(self):
352
384
)
353
385
self ._add_events (sender , "train" , _apply_compat (events ))
354
386
sender .flush ()
355
- requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
387
+ sender .complete_all_pending_futures ()
388
+ requests = [c [0 ][0 ] for c in mock_client .WriteScalar .future .call_args_list ]
356
389
for request in requests :
357
390
_clear_wall_times (request )
358
391
@@ -407,7 +440,8 @@ def mock_add_point(byte_budget_manager_self, point):
407
440
self ._add_events (sender , "train" , _apply_compat ([event_1 ]))
408
441
self ._add_events (sender , "test" , _apply_compat ([event_2 ]))
409
442
sender .flush ()
410
- requests = [c [0 ][0 ] for c in mock_client .WriteScalar .call_args_list ]
443
+ sender .complete_all_pending_futures ()
444
+ requests = [c [0 ][0 ] for c in mock_client .WriteScalar .future .call_args_list ]
411
445
for request in requests :
412
446
_clear_wall_times (request )
413
447
0 commit comments