Skip to content

Commit ce24d3d

Browse files
committed
upload works (too fast?)
1 parent 57f97ba commit ce24d3d

File tree

3 files changed

+64
-15
lines changed

3 files changed

+64
-15
lines changed

tensorboard/uploader/orchestration/batched_request_sender.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,13 @@ def send_requests(self, run_to_events):
189189
run_name, event, value, metadata
190190
)
191191

192+
# Send requests corresponding to whatever remaining events.
192193
self._scalar_request_sender.flush()
193194
self._tensor_request_sender.flush()
194195
self._blob_request_sender.flush()
196+
# Wait for asynchronous calls to complete.
197+
self._scalar_request_sender.complete_all_pending_futures()
198+
195199

196200
def _run_values(self, run_to_events):
197201
"""Helper generator to create a single stream of work items.

tensorboard/uploader/orchestration/scalar_batched_request_sender.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
logger = tb_logging.get_logger()
3131

32+
# How long to wait for a response on a scalar write request.
33+
_SCALAR_WRITE_TIMEOUT_SECS = 30
3234

3335
def _prune_empty_tags_and_runs(request):
3436
for (run_idx, run) in reversed(list(enumerate(request.runs))):
@@ -149,7 +151,7 @@ def flush(self):
149151
return
150152

151153
self._rpc_rate_limiter.tick()
152-
self.complete_grpc_futures(grpc_futures)
154+
self._groom_grpc_futures()
153155

154156
with _request_logger(
155157
request, request.runs
@@ -160,8 +162,8 @@ def flush(self):
160162

161163
self._new_request()
162164

163-
def complete_grpc_futures(self, grpc_futures):
164-
"""Handle any excptions."""
165+
def _groom_grpc_futures(self):
166+
"""Handle any excptions, remove completed futures."""
165167
done_futures = []
166168
# Check if any exceptions raised, collect indicies of futures which can
167169
# be removed.
@@ -172,7 +174,7 @@ def complete_grpc_futures(self, grpc_futures):
172174
# WriteScalar RPCs, but if we did, it would go here. This
173175
# call to result will raise any exception caused in the
174176
# gRPC call.
175-
future.result()
177+
future.result(_SCALAR_WRITE_TIMEOUT_SECS)
176178
done_futures.append(i)
177179
except grpc.RpcError as e:
178180
if e.code() == grpc.StatusCode.NOT_FOUND:
@@ -183,6 +185,15 @@ def complete_grpc_futures(self, grpc_futures):
183185
for i in done_futures:
184186
self._grpc_futures.pop(i)
185187

188+
def complete_all_pending_futures(self):
189+
"""Continuously checks the futures until they are done.
190+
191+
This is guaranteed to complete if the underlying gRPC future
192+
requests are made with timeouts.
193+
"""
194+
while self._grpc_futures:
195+
self._groom_grpc_futures()
196+
186197

187198
def _create_run(self, run_name):
188199
"""Adds a run to the live request, if there's space.

tensorboard/uploader/orchestration/scalar_batched_request_sender_test.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def _apply_compat(events):
5757
yield event
5858

5959

60-
def _create_mock_client():
60+
61+
def _create_mock_client(grpc_exception=None):
6162
# Create a stub instance (using a test channel) in order to derive a mock
6263
# from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself
6364
# doesn't work with autospec because grpc constructs stubs via metaclassing.
@@ -66,6 +67,30 @@ def _create_mock_client():
6667
)
6768
stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel)
6869
mock_client = mock.create_autospec(stub)
70+
71+
# Some surgery is required to make the mock object returned from the call to
72+
# gRPC's `future` async endpoint act like the `grpc.Future` it is supposed
73+
# to be. It needs to somewhat faithfully reproduce the following methods
74+
# which are used in the grpc_util.async_call_with_retries machinery.
75+
# * `done`
76+
# * `result`
77+
# * `exception`
78+
grpc_future = mock.MagicMock()
79+
grpc_future.done.return_value = True
80+
if grpc_exception:
81+
grpc_future.result.side_effect = grpc_exception
82+
grpc_future.exception.return_value = grpc_exception
83+
else:
84+
grpc_future.exception.return_value = None
85+
mock_client.WriteScalar.future.return_value = grpc_future
86+
87+
# This is reimplementing the gRPC future behavior wherein a callback
88+
# can be invoked when the future is complete. In this case, the callback
89+
# must be invoked with the grpc_future as its argument.
90+
def _invoke_callback(handler):
91+
handler(grpc_future)
92+
93+
grpc_future.add_done_callback.side_effect = _invoke_callback
6994
return mock_client
7095

7196

@@ -117,8 +142,9 @@ def _add_events_and_flush(self, events):
117142
)
118143
self._add_events(sender, "", events)
119144
sender.flush()
145+
sender.complete_all_pending_futures()
120146

121-
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
147+
requests = [c[0][0] for c in mock_client.WriteScalar.future.call_args_list]
122148
self.assertLen(requests, 1)
123149
self.assertLen(requests[0].runs, 1)
124150
return requests[0].runs[0]
@@ -205,18 +231,22 @@ def test_v2_summary(self):
205231
)
206232
self.assertProtoEquals(run_proto, expected_run_proto)
207233

208-
def test_propagates_experiment_deletion(self):
234+
def test_async_propagates_experiment_deletion(self):
235+
# Setup: A client which return a grpc_error when asked to write.
209236
event = event_pb2.Event(step=1)
210237
event.summary.value.add(tag="foo", simple_value=1.0)
211238

212-
mock_client = _create_mock_client()
239+
error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
240+
mock_client = _create_mock_client(grpc_exception=error)
213241
sender = _create_scalar_request_sender("123", mock_client)
242+
# Execute: sender.add_event on all the events.
214243
self._add_events(sender, "run", _apply_compat([event]))
215-
216-
error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
217-
mock_client.WriteScalar.side_effect = error
244+
# Expect: Exception is raised, but it is the transforemd
245+
# uploader exception, not the original grpc exception.
218246
with self.assertRaises(uploader_errors.ExperimentNotFoundError):
219247
sender.flush()
248+
sender.complete_all_pending_futures()
249+
220250

221251
def test_no_budget_for_base_request(self):
222252
mock_client = _create_mock_client()
@@ -263,7 +293,8 @@ def test_break_at_run_boundary(self):
263293
self._add_events(sender, long_run_1, _apply_compat([event_1]))
264294
self._add_events(sender, long_run_2, _apply_compat([event_2]))
265295
sender.flush()
266-
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
296+
sender.complete_all_pending_futures()
297+
requests = [c[0][0] for c in mock_client.WriteScalar.future.call_args_list]
267298

268299
for request in requests:
269300
_clear_wall_times(request)
@@ -306,7 +337,8 @@ def test_break_at_tag_boundary(self):
306337
)
307338
self._add_events(sender, "train", _apply_compat([event]))
308339
sender.flush()
309-
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
340+
sender.complete_all_pending_futures()
341+
requests = [c[0][0] for c in mock_client.WriteScalar.future.call_args_list]
310342
for request in requests:
311343
_clear_wall_times(request)
312344

@@ -352,7 +384,8 @@ def test_break_at_scalar_point_boundary(self):
352384
)
353385
self._add_events(sender, "train", _apply_compat(events))
354386
sender.flush()
355-
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
387+
sender.complete_all_pending_futures()
388+
requests = [c[0][0] for c in mock_client.WriteScalar.future.call_args_list]
356389
for request in requests:
357390
_clear_wall_times(request)
358391

@@ -407,7 +440,8 @@ def mock_add_point(byte_budget_manager_self, point):
407440
self._add_events(sender, "train", _apply_compat([event_1]))
408441
self._add_events(sender, "test", _apply_compat([event_2]))
409442
sender.flush()
410-
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
443+
sender.complete_all_pending_futures()
444+
requests = [c[0][0] for c in mock_client.WriteScalar.future.call_args_list]
411445
for request in requests:
412446
_clear_wall_times(request)
413447

0 commit comments

Comments
 (0)