diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py new file mode 100644 index 000000000..26dd66977 --- /dev/null +++ b/firebase_admin/_sseclient.py @@ -0,0 +1,206 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SSEClient module to handle streaming of realtime changes on the database +to the firebase-admin-sdk +""" + +import re +import time +import warnings +import six +import requests + + +# Technically, we should support streams that mix line endings. This regex, +# however, assumes that a system will provide consistent line endings. +end_of_field = re.compile(r'\r\n\r\n|\r\r|\n\n') + + +class KeepAuthSession(requests.Session): + """A session that does not drop Authentication on redirects between domains""" + def rebuild_auth(self, prepared_request, response): + pass + + +class SSEClient(object): + """SSE Client Class""" + + def __init__(self, url, session, last_id=None, retry=3000, **kwargs): + """Initialize the SSEClient + Args: + url: the url to connect to + session: the requests.session() + last_id: optional id + retry: the interval in ms + **kwargs: extra kwargs will be sent to requests.get + """ + self.should_connect = True + self.url = url + self.last_id = last_id + self.retry = retry + self.session = session + self.requests_kwargs = kwargs + + headers = self.requests_kwargs.get('headers', {}) + # The SSE spec requires making requests with Cache-Control: nocache + headers['Cache-Control'] = 'no-cache' + # The 'Accept' header is not required, but explicit > implicit + headers['Accept'] = 'text/event-stream' + + self.requests_kwargs['headers'] = headers + + # Keep data here as it streams in + self.buf = u'' + + self._connect() + + def close(self): + """Close the SSE Client instance""" + # TODO: check if AttributeError is needed to catch here + self.should_connect = False + self.retry = 0 + self.resp.close() + # self.resp.raw._fp.fp.raw._sock.shutdown(socket.SHUT_RDWR) + # self.resp.raw._fp.fp.raw._sock.close() + + + def _connect(self): + """connects to the server using requests""" + if self.should_connect: + success = False + while not success: + if self.last_id: + self.requests_kwargs['headers']['Last-Event-ID'] = self.last_id + # Use session if set. Otherwise fall back to requests module. + self.requester = self.session or requests + self.resp = self.requester.get(self.url, stream=True, **self.requests_kwargs) + + self.resp_iterator = self.resp.iter_content(decode_unicode=True) + + # TODO: Ensure we're handling redirects. Might also stick the 'origin' + # attribute on Events like the Javascript spec requires. + self.resp.raise_for_status() + success = True + else: + raise StopIteration() + + def _event_complete(self): + """Checks if the event is completed by matching regular expression + + Returns: + boolean: True if the regex matched meaning end of event, else False + """ + return re.search(end_of_field, self.buf) is not None + + def __iter__(self): + return self + + def __next__(self): + while not self._event_complete(): + try: + nextchar = next(self.resp_iterator) + self.buf += nextchar + except (StopIteration, requests.RequestException): + time.sleep(self.retry / 1000.0) + self._connect() + + + # The SSE spec only supports resuming from a whole message, so + # if we have half a message we should throw it out. + head, sep, tail = self.buf.rpartition('\n') + self.buf = head + sep + continue + + split = re.split(end_of_field, self.buf) + head = split[0] + tail = "".join(split[1:]) + + self.buf = tail + msg = Event.parse(head) + + if msg.data == "credential is no longer valid": + self._connect() + return None + + if msg.data == 'null': + return None + + # If the server requests a specific retry delay, we need to honor it. + if msg.retry: + self.retry = msg.retry + + # last_id should only be set if included in the message. It's not + # forgotten if a message omits it. + if msg.event_id: + self.last_id = msg.event_id + + return msg + + if six.PY2: + next = __next__ + + +class Event(object): + """Event class to handle the events fired by SSE""" + + sse_line_pattern = re.compile('(?P[^:]*):?( ?(?P.*))?') + + def __init__(self, data='', event='message', event_id=None, retry=None): + self.data = data + self.event = event + self.event_id = event_id + self.retry = retry + + @classmethod + def parse(cls, raw): + """Given a possibly-multiline string representing an SSE message, parse it + and return a Event object. + + Args: + raw: the raw data to parse + + Returns: + Event: newly intialized Event() object with the parameters initialized + """ + msg = cls() + for line in raw.split('\n'): + match = cls.sse_line_pattern.match(line) + if match is None: + # Malformed line. Discard but warn. + warnings.warn('Invalid SSE line: "%s"' % line, SyntaxWarning) + continue + + name = match.groupdict()['name'] + value = match.groupdict()['value'] + if name == '': + # line began with a ":", so is a comment. Ignore + continue + + if name == 'data': + # If we already have some data, then join to it with a newline. + # Else this is it. + if msg.data: + msg.data = '%s\n%s' % (msg.data, value) + else: + msg.data = value + elif name == 'event': + msg.event = value + elif name == 'id': + msg.event_id = value + elif name == 'retry': + msg.retry = int(value) + + return msg + + def __str__(self): + return self.data diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 43ca7c495..508d7db4b 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -23,6 +23,7 @@ import collections import json import sys +import threading import requests import six @@ -31,8 +32,14 @@ import firebase_admin from firebase_admin import _http_client from firebase_admin import _utils +from firebase_admin import _sseclient +try: + from urllib.parse import urlencode +except ImportError: + from urllib import urlencode + _DB_ATTRIBUTE = '_database' _INVALID_PATH_CHARACTERS = '[].?#$' _RESERVED_FILTERS = ('$key', '$value', '$priority') @@ -75,6 +82,47 @@ def _parse_path(path): return [seg for seg in path.split('/') if seg] +class ListenerRegistration(object): + """Class that handles the streaming of data node changes from server""" + def __init__(self, url, stream_handler): + """Initialize a new ListenerRegistration object with given parameters + + Args: + url: the data node url to listen for changes + stream_handler: the callback function to fire in case of event + """ + self.url = url + self.stream_handler = stream_handler + self.sse = None + self.thread = None + self.start() + + def start(self): + """Start the streaming by spawning a thread""" + self.sse = _sseclient.SSEClient( + self.url, + session=_sseclient.KeepAuthSession() + ) + self.thread = threading.Thread(target=self.start_stream) + self.thread.start() + return self + + def start_stream(self): + """Streaming function for the spawned thread to run""" + for msg in self.sse: + # iterate the sse client's generator + if msg: + msg_data = json.loads(msg.data) + msg_data["event"] = msg.event + self.stream_handler(msg_data) + + def close(self): + """Terminates SSE server connection and joins the thread""" + self.sse.running = False + self.sse.close() + self.thread.join() + + class Reference(object): """Reference represents a node in the Firebase realtime database.""" @@ -107,6 +155,22 @@ def parent(self): return Reference(client=self._client, segments=self._segments[:-1]) return None + def listen(self, stream_handler): + """Function to setup the streaming of data from server data node changes + + Args: + stream_handler: A function to callback in the event of data node change detected + + Returns: + object: Returns a ListenerRegistration object which handles the stream + """ + parameters = {} + # reset path and build_query for next query + request_ref = '{}{}.json?{}'.format( + self._client.base_url, self._pathurl, urlencode(parameters) + ) + return ListenerRegistration(request_ref, stream_handler) + def child(self, path): """Returns a Reference to the specified child node. diff --git a/tests/test_db.py b/tests/test_db.py index 346977a11..3e3044c59 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -523,7 +523,7 @@ def test_range_query(self): assert recorder[0].headers['User-Agent'] == db._USER_AGENT -class TestDatabseInitialization(object): +class TestDatabaseInitialization(object): """Test cases for database initialization.""" def teardown_method(self): diff --git a/tests/test_sseclient.py b/tests/test_sseclient.py new file mode 100644 index 000000000..422fef172 --- /dev/null +++ b/tests/test_sseclient.py @@ -0,0 +1,69 @@ +"""Tests for firebase_admin.sseclient.""" +import json +import six +import requests + +from firebase_admin import _sseclient +from tests.testutils import MockAdapter + + +class MockSSEClient(MockAdapter): + def __init__(self, payload): + status = 200 + recorder = [] + MockAdapter.__init__(self, payload, status, recorder) + + def send(self, request, **kwargs): + resp = requests.models.Response() + resp.url = request.url + resp.status_code = self._status + resp.raw = six.BytesIO(self._data.encode()) + resp.encoding = "utf-8" + return resp + + +class TestSSEClient(object): + """Test cases for the SSEClient""" + + test_url = "https://test.firebaseio.com" + + + def init_sse(self): + payload = 'event: put\ndata: {"path":"/","data":"testevent"}\n\n' + + adapter = MockSSEClient(payload) + session = _sseclient.KeepAuthSession() + session.mount(self.test_url, adapter) + + sseclient = _sseclient.SSEClient(url=self.test_url, session=session) + return sseclient + + + def test_init_sseclient(self): + sseclient = self.init_sse() + + assert sseclient.url == self.test_url + assert sseclient.session != None + + def test_event(self): + sseclient = self.init_sse() + msg = next(sseclient) + event = json.loads(msg.data) + assert event["data"] == "testevent" + assert event["path"] == "/" + + +class TestEvent(object): + """Test cases for Events""" + + def test_normal(self): + data = 'event: put\ndata: {"path":"/","data":"testdata"}' + event = _sseclient.Event.parse(data) + assert event.event == "put" + assert event.data == '{"path":"/","data":"testdata"}' + + def test_invalid(self): + data = 'event: invalid_event' + event = _sseclient.Event.parse(data) + assert event.event == "invalid_event" + assert event.data == ''