Skip to content

Commit 3f1190d

Browse files
the-c0d3rhiranya911
authored andcommitted
[WIP] Adding stream listener to stream changes in child nodes (#183)
* Added SSEClient library * Added Streaming functionality to db.py * Added ignore directive 'protected-access' for db.py in lint.sh * Fixed typo * Renamed file to internal module * Reverted lint.sh, added ignore directive in db.py * Changed import module name to internal name * Fixed pylint protected-access by not calling protected member * Added test_sseclient.py * python2,3 compatibility, fixed encoding issue * Added tests for Event() class * removed build_headers() function * Removed Event().dump() test code * Changed import to match code style * init sseclient to Stream().start() and removed sleep * Added apache license 2.0 header * changed import style, removed for loop in event in test_sseclient.TestEvent * Removed self.running from firebase_admin._sseclient * Renamed Stream class to ListenerRegistration, stream() to listen(), removed 'stream_id' and added more documentation
1 parent 3ff6a07 commit 3f1190d

File tree

4 files changed

+340
-1
lines changed

4 files changed

+340
-1
lines changed

firebase_admin/_sseclient.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
"""SSEClient module to handle streaming of realtime changes on the database
14+
to the firebase-admin-sdk
15+
"""
16+
17+
import re
18+
import time
19+
import warnings
20+
import six
21+
import requests
22+
23+
24+
# Technically, we should support streams that mix line endings. This regex,
25+
# however, assumes that a system will provide consistent line endings.
26+
end_of_field = re.compile(r'\r\n\r\n|\r\r|\n\n')
27+
28+
29+
class KeepAuthSession(requests.Session):
30+
"""A session that does not drop Authentication on redirects between domains"""
31+
def rebuild_auth(self, prepared_request, response):
32+
pass
33+
34+
35+
class SSEClient(object):
36+
"""SSE Client Class"""
37+
38+
def __init__(self, url, session, last_id=None, retry=3000, **kwargs):
39+
"""Initialize the SSEClient
40+
Args:
41+
url: the url to connect to
42+
session: the requests.session()
43+
last_id: optional id
44+
retry: the interval in ms
45+
**kwargs: extra kwargs will be sent to requests.get
46+
"""
47+
self.should_connect = True
48+
self.url = url
49+
self.last_id = last_id
50+
self.retry = retry
51+
self.session = session
52+
self.requests_kwargs = kwargs
53+
54+
headers = self.requests_kwargs.get('headers', {})
55+
# The SSE spec requires making requests with Cache-Control: nocache
56+
headers['Cache-Control'] = 'no-cache'
57+
# The 'Accept' header is not required, but explicit > implicit
58+
headers['Accept'] = 'text/event-stream'
59+
60+
self.requests_kwargs['headers'] = headers
61+
62+
# Keep data here as it streams in
63+
self.buf = u''
64+
65+
self._connect()
66+
67+
def close(self):
68+
"""Close the SSE Client instance"""
69+
# TODO: check if AttributeError is needed to catch here
70+
self.should_connect = False
71+
self.retry = 0
72+
self.resp.close()
73+
# self.resp.raw._fp.fp.raw._sock.shutdown(socket.SHUT_RDWR)
74+
# self.resp.raw._fp.fp.raw._sock.close()
75+
76+
77+
def _connect(self):
78+
"""connects to the server using requests"""
79+
if self.should_connect:
80+
success = False
81+
while not success:
82+
if self.last_id:
83+
self.requests_kwargs['headers']['Last-Event-ID'] = self.last_id
84+
# Use session if set. Otherwise fall back to requests module.
85+
self.requester = self.session or requests
86+
self.resp = self.requester.get(self.url, stream=True, **self.requests_kwargs)
87+
88+
self.resp_iterator = self.resp.iter_content(decode_unicode=True)
89+
90+
# TODO: Ensure we're handling redirects. Might also stick the 'origin'
91+
# attribute on Events like the Javascript spec requires.
92+
self.resp.raise_for_status()
93+
success = True
94+
else:
95+
raise StopIteration()
96+
97+
def _event_complete(self):
98+
"""Checks if the event is completed by matching regular expression
99+
100+
Returns:
101+
boolean: True if the regex matched meaning end of event, else False
102+
"""
103+
return re.search(end_of_field, self.buf) is not None
104+
105+
def __iter__(self):
106+
return self
107+
108+
def __next__(self):
109+
while not self._event_complete():
110+
try:
111+
nextchar = next(self.resp_iterator)
112+
self.buf += nextchar
113+
except (StopIteration, requests.RequestException):
114+
time.sleep(self.retry / 1000.0)
115+
self._connect()
116+
117+
118+
# The SSE spec only supports resuming from a whole message, so
119+
# if we have half a message we should throw it out.
120+
head, sep, tail = self.buf.rpartition('\n')
121+
self.buf = head + sep
122+
continue
123+
124+
split = re.split(end_of_field, self.buf)
125+
head = split[0]
126+
tail = "".join(split[1:])
127+
128+
self.buf = tail
129+
msg = Event.parse(head)
130+
131+
if msg.data == "credential is no longer valid":
132+
self._connect()
133+
return None
134+
135+
if msg.data == 'null':
136+
return None
137+
138+
# If the server requests a specific retry delay, we need to honor it.
139+
if msg.retry:
140+
self.retry = msg.retry
141+
142+
# last_id should only be set if included in the message. It's not
143+
# forgotten if a message omits it.
144+
if msg.event_id:
145+
self.last_id = msg.event_id
146+
147+
return msg
148+
149+
if six.PY2:
150+
next = __next__
151+
152+
153+
class Event(object):
154+
"""Event class to handle the events fired by SSE"""
155+
156+
sse_line_pattern = re.compile('(?P<name>[^:]*):?( ?(?P<value>.*))?')
157+
158+
def __init__(self, data='', event='message', event_id=None, retry=None):
159+
self.data = data
160+
self.event = event
161+
self.event_id = event_id
162+
self.retry = retry
163+
164+
@classmethod
165+
def parse(cls, raw):
166+
"""Given a possibly-multiline string representing an SSE message, parse it
167+
and return a Event object.
168+
169+
Args:
170+
raw: the raw data to parse
171+
172+
Returns:
173+
Event: newly intialized Event() object with the parameters initialized
174+
"""
175+
msg = cls()
176+
for line in raw.split('\n'):
177+
match = cls.sse_line_pattern.match(line)
178+
if match is None:
179+
# Malformed line. Discard but warn.
180+
warnings.warn('Invalid SSE line: "%s"' % line, SyntaxWarning)
181+
continue
182+
183+
name = match.groupdict()['name']
184+
value = match.groupdict()['value']
185+
if name == '':
186+
# line began with a ":", so is a comment. Ignore
187+
continue
188+
189+
if name == 'data':
190+
# If we already have some data, then join to it with a newline.
191+
# Else this is it.
192+
if msg.data:
193+
msg.data = '%s\n%s' % (msg.data, value)
194+
else:
195+
msg.data = value
196+
elif name == 'event':
197+
msg.event = value
198+
elif name == 'id':
199+
msg.event_id = value
200+
elif name == 'retry':
201+
msg.retry = int(value)
202+
203+
return msg
204+
205+
def __str__(self):
206+
return self.data

firebase_admin/db.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import collections
2424
import json
2525
import sys
26+
import threading
2627

2728
import requests
2829
import six
@@ -31,8 +32,14 @@
3132
import firebase_admin
3233
from firebase_admin import _http_client
3334
from firebase_admin import _utils
35+
from firebase_admin import _sseclient
3436

3537

38+
try:
39+
from urllib.parse import urlencode
40+
except ImportError:
41+
from urllib import urlencode
42+
3643
_DB_ATTRIBUTE = '_database'
3744
_INVALID_PATH_CHARACTERS = '[].?#$'
3845
_RESERVED_FILTERS = ('$key', '$value', '$priority')
@@ -75,6 +82,47 @@ def _parse_path(path):
7582
return [seg for seg in path.split('/') if seg]
7683

7784

85+
class ListenerRegistration(object):
86+
"""Class that handles the streaming of data node changes from server"""
87+
def __init__(self, url, stream_handler):
88+
"""Initialize a new ListenerRegistration object with given parameters
89+
90+
Args:
91+
url: the data node url to listen for changes
92+
stream_handler: the callback function to fire in case of event
93+
"""
94+
self.url = url
95+
self.stream_handler = stream_handler
96+
self.sse = None
97+
self.thread = None
98+
self.start()
99+
100+
def start(self):
101+
"""Start the streaming by spawning a thread"""
102+
self.sse = _sseclient.SSEClient(
103+
self.url,
104+
session=_sseclient.KeepAuthSession()
105+
)
106+
self.thread = threading.Thread(target=self.start_stream)
107+
self.thread.start()
108+
return self
109+
110+
def start_stream(self):
111+
"""Streaming function for the spawned thread to run"""
112+
for msg in self.sse:
113+
# iterate the sse client's generator
114+
if msg:
115+
msg_data = json.loads(msg.data)
116+
msg_data["event"] = msg.event
117+
self.stream_handler(msg_data)
118+
119+
def close(self):
120+
"""Terminates SSE server connection and joins the thread"""
121+
self.sse.running = False
122+
self.sse.close()
123+
self.thread.join()
124+
125+
78126
class Reference(object):
79127
"""Reference represents a node in the Firebase realtime database."""
80128

@@ -107,6 +155,22 @@ def parent(self):
107155
return Reference(client=self._client, segments=self._segments[:-1])
108156
return None
109157

158+
def listen(self, stream_handler):
159+
"""Function to setup the streaming of data from server data node changes
160+
161+
Args:
162+
stream_handler: A function to callback in the event of data node change detected
163+
164+
Returns:
165+
object: Returns a ListenerRegistration object which handles the stream
166+
"""
167+
parameters = {}
168+
# reset path and build_query for next query
169+
request_ref = '{}{}.json?{}'.format(
170+
self._client.base_url, self._pathurl, urlencode(parameters)
171+
)
172+
return ListenerRegistration(request_ref, stream_handler)
173+
110174
def child(self, path):
111175
"""Returns a Reference to the specified child node.
112176

tests/test_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ def test_range_query(self):
523523
assert recorder[0].headers['User-Agent'] == db._USER_AGENT
524524

525525

526-
class TestDatabseInitialization(object):
526+
class TestDatabaseInitialization(object):
527527
"""Test cases for database initialization."""
528528

529529
def teardown_method(self):

tests/test_sseclient.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Tests for firebase_admin.sseclient."""
2+
import json
3+
import six
4+
import requests
5+
6+
from firebase_admin import _sseclient
7+
from tests.testutils import MockAdapter
8+
9+
10+
class MockSSEClient(MockAdapter):
11+
def __init__(self, payload):
12+
status = 200
13+
recorder = []
14+
MockAdapter.__init__(self, payload, status, recorder)
15+
16+
def send(self, request, **kwargs):
17+
resp = requests.models.Response()
18+
resp.url = request.url
19+
resp.status_code = self._status
20+
resp.raw = six.BytesIO(self._data.encode())
21+
resp.encoding = "utf-8"
22+
return resp
23+
24+
25+
class TestSSEClient(object):
26+
"""Test cases for the SSEClient"""
27+
28+
test_url = "https://test.firebaseio.com"
29+
30+
31+
def init_sse(self):
32+
payload = 'event: put\ndata: {"path":"/","data":"testevent"}\n\n'
33+
34+
adapter = MockSSEClient(payload)
35+
session = _sseclient.KeepAuthSession()
36+
session.mount(self.test_url, adapter)
37+
38+
sseclient = _sseclient.SSEClient(url=self.test_url, session=session)
39+
return sseclient
40+
41+
42+
def test_init_sseclient(self):
43+
sseclient = self.init_sse()
44+
45+
assert sseclient.url == self.test_url
46+
assert sseclient.session != None
47+
48+
def test_event(self):
49+
sseclient = self.init_sse()
50+
msg = next(sseclient)
51+
event = json.loads(msg.data)
52+
assert event["data"] == "testevent"
53+
assert event["path"] == "/"
54+
55+
56+
class TestEvent(object):
57+
"""Test cases for Events"""
58+
59+
def test_normal(self):
60+
data = 'event: put\ndata: {"path":"/","data":"testdata"}'
61+
event = _sseclient.Event.parse(data)
62+
assert event.event == "put"
63+
assert event.data == '{"path":"/","data":"testdata"}'
64+
65+
def test_invalid(self):
66+
data = 'event: invalid_event'
67+
event = _sseclient.Event.parse(data)
68+
assert event.event == "invalid_event"
69+
assert event.data == ''

0 commit comments

Comments
 (0)