Skip to content

Commit a420692

Browse files
committed
made suggested changes
1 parent 31777f6 commit a420692

File tree

3 files changed

+81
-37
lines changed

3 files changed

+81
-37
lines changed

frontera/contrib/backends/remote/codecs/json.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,33 @@
99
from w3lib.util import to_unicode, to_bytes
1010

1111

12-
def _convert(obj):
12+
def _encode_recursively(obj):
1313
"""
14-
recursively convert an object to unicode and store its type
14+
recursively encodes an object to unicode and store its type
1515
"""
1616
if isinstance(obj, bytes):
1717
return 'bytes', to_unicode(obj)
1818
elif isinstance(obj, dict):
19-
return 'dict', [(_convert(k), _convert(v)) for k, v in six.iteritems(obj)]
19+
return 'dict', [(_encode_recursively(k), _encode_recursively(v)) for k, v in six.iteritems(obj)]
2020
elif isinstance(obj, (list, tuple)):
21-
return type(obj).__name__, [_convert(item) for item in obj]
21+
return type(obj).__name__, [_encode_recursively(item) for item in obj]
2222
return 'other', obj
2323

2424

25-
def _reconvert(obj):
25+
def _decode_recursively(obj):
2626
"""
27-
reconvert an object changed by method `convert` back to
27+
decode an object changed by method `_encode_recursively` back to
2828
its original form
2929
"""
3030
obj_type = obj[0]
3131
obj_value = obj[1]
3232
if obj_type == 'bytes':
3333
return to_bytes(obj_value)
3434
elif obj_type == 'dict':
35-
return dict([(_reconvert(k), _reconvert(v)) for k, v in obj_value])
35+
return dict([(_decode_recursively(k), _decode_recursively(v)) for k, v in obj_value])
3636
elif obj_type in ['list', 'tuple']:
3737
_type = list if obj_type == 'list' else tuple
38-
return _type([_reconvert(item) for item in obj_value])
38+
return _type([_decode_recursively(item) for item in obj_value])
3939
return obj_value
4040

4141

@@ -76,8 +76,8 @@ def __init__(self, request_model, *a, **kw):
7676
super(Encoder, self).__init__(request_model, *a, **kw)
7777

7878
def encode(self, obj):
79-
converted = _convert(obj)
80-
return super(Encoder, self).encode(converted)
79+
encoded = _encode_recursively(obj)
80+
return super(Encoder, self).encode(encoded)
8181

8282
def encode_add_seeds(self, seeds):
8383
return self.encode({
@@ -151,7 +151,7 @@ def _request_from_object(self, obj):
151151
meta=obj['meta'])
152152

153153
def decode(self, message):
154-
message = _reconvert(super(Decoder, self).decode(message))
154+
message = _decode_recursively(super(Decoder, self).decode(message))
155155
if message['type'] == 'links_extracted':
156156
request = self._request_from_object(message['r'])
157157
links = [self._request_from_object(link) for link in message['links']]
@@ -177,7 +177,7 @@ def decode(self, message):
177177
return TypeError('Unknown message type')
178178

179179
def decode_request(self, message):
180-
obj = _reconvert(super(Decoder, self).decode(message))
180+
obj = _decode_recursively(super(Decoder, self).decode(message))
181181
return self._request_model(url=obj['url'],
182182
method=obj['method'],
183183
headers=obj['headers'],

frontera/contrib/backends/remote/codecs/msgpack.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,29 @@
66
from msgpack import packb, unpackb
77

88
from frontera.core.codec import BaseDecoder, BaseEncoder
9+
import six
910
from w3lib.util import to_native_str
1011

1112

1213
def _prepare_request_message(request):
13-
return [request.url, request.method, request.headers, request.cookies, request.meta]
14+
def serialize(obj):
15+
"""Recursively walk object's hierarchy."""
16+
if isinstance(obj, (bool, six.integer_types, float, six.binary_type, six.text_type)):
17+
return obj
18+
elif isinstance(obj, dict):
19+
obj = obj.copy()
20+
for key in obj:
21+
obj[key] = serialize(obj[key])
22+
return obj
23+
elif isinstance(obj, list):
24+
return [serialize(item) for item in obj]
25+
elif isinstance(obj, tuple):
26+
return tuple(serialize([item for item in obj]))
27+
elif hasattr(obj, '__dict__'):
28+
return serialize(obj.__dict__)
29+
else:
30+
return None
31+
return [request.url, request.method, request.headers, request.cookies, serialize(request.meta)]
1432

1533

1634
def _prepare_response_message(response, send_body):
@@ -22,29 +40,29 @@ def __init__(self, request_model, *a, **kw):
2240
self.send_body = True if 'send_body' in kw and kw['send_body'] else False
2341

2442
def encode_add_seeds(self, seeds):
25-
return packb([b'as', [_prepare_request_message(seed) for seed in seeds]], use_bin_type=True, encoding="utf-8")
43+
return packb([b'as', [_prepare_request_message(seed) for seed in seeds]], use_bin_type=True)
2644

2745
def encode_page_crawled(self, response):
28-
return packb([b'pc', _prepare_response_message(response, self.send_body)], use_bin_type=True, encoding="utf-8")
46+
return packb([b'pc', _prepare_response_message(response, self.send_body)], use_bin_type=True)
2947

3048
def encode_links_extracted(self, request, links):
3149
return packb([b'le', _prepare_request_message(request), [_prepare_request_message(link) for link in links]],
32-
use_bin_type=True, encoding="utf-8")
50+
use_bin_type=True)
3351

3452
def encode_request_error(self, request, error):
35-
return packb([b're', _prepare_request_message(request), str(error)], use_bin_type=True, encoding="utf-8")
53+
return packb([b're', _prepare_request_message(request), str(error)], use_bin_type=True)
3654

3755
def encode_request(self, request):
38-
return packb(_prepare_request_message(request), use_bin_type=True, encoding="utf-8")
56+
return packb(_prepare_request_message(request), use_bin_type=True)
3957

4058
def encode_update_score(self, request, score, schedule):
41-
return packb([b'us', _prepare_request_message(request), score, schedule], use_bin_type=True, encoding="utf-8")
59+
return packb([b'us', _prepare_request_message(request), score, schedule], use_bin_type=True)
4260

4361
def encode_new_job_id(self, job_id):
44-
return packb([b'njid', int(job_id)], use_bin_type=True, encoding="utf-8")
62+
return packb([b'njid', int(job_id)], use_bin_type=True)
4563

4664
def encode_offset(self, partition_id, offset):
47-
return packb([b'of', int(partition_id), int(offset)], use_bin_type=True, encoding="utf-8")
65+
return packb([b'of', int(partition_id), int(offset)], use_bin_type=True)
4866

4967

5068
class Decoder(BaseDecoder):
@@ -68,7 +86,7 @@ def _request_from_object(self, obj):
6886
meta=obj[4])
6987

7088
def decode(self, buffer):
71-
obj = unpackb(buffer, encoding="utf-8")
89+
obj = unpackb(buffer, encoding='utf-8')
7290
if obj[0] == b'pc':
7391
return ('page_crawled',
7492
self._response_from_object(obj[1]))
@@ -89,4 +107,4 @@ def decode(self, buffer):
89107
return TypeError('Unknown message type')
90108

91109
def decode_request(self, buffer):
92-
return self._request_from_object(unpackb(buffer, encoding="utf-8"))
110+
return self._request_from_object(unpackb(buffer, encoding='utf-8'))

tests/test_codecs.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,38 @@
55
import unittest
66
from frontera.contrib.backends.remote.codecs.json import (Encoder as JsonEncoder,
77
Decoder as JsonDecoder,
8-
_convert, _reconvert)
8+
_encode_recursively, _decode_recursively)
99
from frontera.contrib.backends.remote.codecs.msgpack import Encoder as MsgPackEncoder, Decoder as MsgPackDecoder
1010
from frontera.core.models import Request, Response
1111
import pytest
1212

1313

14+
def _compare_dicts(dict1, dict2):
15+
"""
16+
Compares two dicts
17+
:return: True if both dicts are equal else False
18+
"""
19+
if dict1 == None or dict2 == None:
20+
return False
21+
22+
if type(dict1) is not dict or type(dict2) is not dict:
23+
return False
24+
25+
shared_keys = set(dict2.keys()) & set(dict2.keys())
26+
27+
if not (len(shared_keys) == len(dict1.keys()) and len(shared_keys) == len(dict2.keys())):
28+
return False
29+
30+
dicts_are_equal = True
31+
for key in dict1.keys():
32+
if type(dict1[key]) is dict:
33+
dicts_are_equal = dicts_are_equal and _compare_dicts(dict1[key], dict2[key])
34+
else:
35+
dicts_are_equal = dicts_are_equal and (dict1[key] == dict2[key]) and (type(dict1[key]) == type(dict2[key]))
36+
37+
return dicts_are_equal
38+
39+
1440
@pytest.mark.parametrize(
1541
('encoder', 'decoder'), [
1642
(MsgPackEncoder, MsgPackDecoder),
@@ -19,8 +45,8 @@
1945
)
2046
def test_codec(encoder, decoder):
2147
def check_request(req1, req2):
22-
assert req1.url == req2.url and req1.meta == req2.meta and req1.headers == req2.headers \
23-
and req1.method == req2.method
48+
assert req1.url == req2.url and _compare_dicts(req1.meta, req2.meta) == True and \
49+
_compare_dicts(req1.headers, req2.headers) == True and req1.method == req2.method
2450

2551
enc = encoder(Request, send_body=True)
2652
dec = decoder(Request, Response)
@@ -87,12 +113,12 @@ def check_request(req1, req2):
87113
check_request(o, req)
88114

89115

90-
class TestConvertReconvertJson(unittest.TestCase):
116+
class TestEncodeDecodeJson(unittest.TestCase):
91117
"""
92-
Test for testing methods `_convert` and `_reconvert` used in json codec
118+
Test for testing methods `_encode_recursively` and `_decode_recursively` used in json codec
93119
"""
94120

95-
def test_convert_encode_decode_reconvert(self):
121+
def test_encode_decode_json_recursively(self):
96122
_int = 1
97123
_bytes = b'bytes'
98124
_unicode = u'unicode'
@@ -112,11 +138,11 @@ def test_convert_encode_decode_reconvert(self):
112138
encoder = json.JSONEncoder()
113139
decoder = json.JSONDecoder()
114140
for original_msg in msgs:
115-
converted_msg = _convert(original_msg)
116-
encoded_msg = encoder.encode(converted_msg)
117-
decoded_msg = decoder.decode(encoded_msg)
118-
reconverted_msg = _reconvert(decoded_msg)
119-
if isinstance(reconverted_msg, dict):
120-
self.assertDictEqual(reconverted_msg, original_msg)
121-
elif isinstance(reconverted_msg, (list, tuple)):
122-
self.assertSequenceEqual(reconverted_msg, original_msg)
141+
encoded_msg_1 = _encode_recursively(original_msg)
142+
encoded_msg_2 = encoder.encode(encoded_msg_1)
143+
decoded_msg_2 = decoder.decode(encoded_msg_2)
144+
decoded_msg_1 = _decode_recursively(decoded_msg_2)
145+
if isinstance(decoded_msg_1, dict):
146+
self.assertDictEqual(decoded_msg_1, original_msg)
147+
elif isinstance(decoded_msg_1, (list, tuple)):
148+
self.assertSequenceEqual(decoded_msg_1, original_msg)

0 commit comments

Comments
 (0)