From c280b773fbe93681bb496acab1ef26c66b9e1aff Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 21 Jan 2016 11:26:27 +0100 Subject: [PATCH] use session.deserialize to unpack message for rate limiting rather than hardcoding json.loads Messages should **never** be deserialized by any means other than the Session API. --- notebook/base/zmqhandlers.py | 19 +++++++++++++------ notebook/services/kernels/handlers.py | 7 ++++--- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/notebook/base/zmqhandlers.py b/notebook/base/zmqhandlers.py index 6a7b5bdcbe..1cfa8ecb34 100644 --- a/notebook/base/zmqhandlers.py +++ b/notebook/base/zmqhandlers.py @@ -218,16 +218,23 @@ def send_error(self, *args, **kwargs): self.stream.close() - def _reserialize_reply(self, msg_list, channel=None): + def _reserialize_reply(self, msg_or_list, channel=None): """Reserialize a reply message using JSON. - This takes the msg list from the ZMQ socket, deserializes it using - self.session and then serializes the result using JSON. This method - should be used by self._on_zmq_reply to build messages that can + msg_or_list can be an already-deserialized msg dict or the zmq buffer list. + If it is the zmq list, it will be deserialized with self.session. + + This takes the msg list from the ZMQ socket and serializes the result for the websocket. + This method should be used by self._on_zmq_reply to build messages that can be sent back to the browser. + """ - idents, msg_list = self.session.feed_identities(msg_list) - msg = self.session.deserialize(msg_list) + if isinstance(msg_or_list, dict): + # already unpacked + msg = msg_or_list + else: + idents, msg_list = self.session.feed_identities(msg_or_list) + msg = self.session.deserialize(msg_list) if channel: msg['channel'] = channel if msg['buffers']: diff --git a/notebook/services/kernels/handlers.py b/notebook/services/kernels/handlers.py index fa7a9283b8..aa5b319791 100644 --- a/notebook/services/kernels/handlers.py +++ b/notebook/services/kernels/handlers.py @@ -269,9 +269,10 @@ def on_message(self, msg): def _on_zmq_reply(self, stream, msg_list): idents, fed_msg_list = self.session.feed_identities(msg_list) + msg = self.session.deserialize(fed_msg_list) + parent = msg['parent_header'] def write_stderr(error_message): self.log.warn(error_message) - parent = json.loads(fed_msg_list[2]) msg = self.session.msg("stream", content={"text": error_message, "name": "stderr"}, parent=parent @@ -280,7 +281,7 @@ def write_stderr(error_message): self.write_message(json.dumps(msg, default=date_default)) channel = getattr(stream, 'channel', None) - msg_type = json.loads(fed_msg_list[1])['msg_type'] + msg_type = msg['header']['msg_type'] if channel == 'iopub' and msg_type not in {'status', 'comm_open', 'execute_input'}: # Remove the counts queued for removal. @@ -345,7 +346,7 @@ def write_stderr(error_message): # If either of the limit flags are set, do not send the message. if self._iopub_msgs_exceeded or self._iopub_data_exceeded: return - super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg_list) + super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg) def on_close(self):