Skip to content

Commit 465f35e

Browse files
cis-muzahidkdcis
authored andcommitted
refactor sqlthread and add create test cases
1 parent 70dcb94 commit 465f35e

9 files changed

+563
-549
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def run(self):
5151
if __name__ == "__main__":
5252
here = os.path.abspath(os.path.dirname(__file__))
5353
with open(os.path.join(here, 'README.md')) as f:
54-
README = f.read()
54+
README. = f.read()
5555

5656
with open(os.path.join(here, 'requirements.txt'), 'r') as f:
5757
requirements = list(f.readlines())

src/class_sqlThread.py

Lines changed: 341 additions & 519 deletions
Large diffs are not rendered by default.

src/sql/init_version_5.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ DROP TABLE knownnodes;
88
CREATE TABLE `objectprocessorqueue` (
99
`objecttype` text,
1010
`data` blob,
11-
UNIQUE(objecttype, data) ON CONFLICT REPLACE
11+
UNIQUE(objecttype, data) ON CONFLICT REPLACE
1212
) ;

src/sql/initialize_schema.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,6 @@ CREATE TABLE `objectprocessorqueue` (
9898
`data` blob,
9999
UNIQUE(objecttype, data) ON CONFLICT REPLACE
100100
) ;
101+
102+
103+

src/tests/core.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from bmconfigparser import config
2525
from helper_msgcoding import MsgEncode, MsgDecode
26-
from helper_sql import sqlQuery
26+
from helper_sql import sqlQuery, sqlExecute
2727
from network import asyncore_pollchoose as asyncore, knownnodes
2828
from network.bmproto import BMProto
2929
from network.connectionpool import BMConnectionPool
@@ -409,6 +409,22 @@ def test_adding_two_same_case_sensitive_addresses(self):
409409
self.delete_address_from_addressbook(address1)
410410
self.delete_address_from_addressbook(address2)
411411

412+
def test_sqlscripts(self):
413+
""" Test sql statements"""
414+
415+
sqlExecute('create table if not exists testtbl (id integer)')
416+
tables = list(sqlQuery("select name from sqlite_master where type is 'table'"))
417+
res = [item for item in tables if 'testtbl' in item]
418+
self.assertEqual(res[0][0], 'testtbl')
419+
420+
queryreturn = sqlExecute("INSERT INTO testtbl VALUES(101);")
421+
self.assertEqual(queryreturn, 1)
422+
423+
queryreturn = sqlQuery('''SELECT * FROM testtbl''')
424+
self.assertEqual(queryreturn[0][0], 101)
425+
426+
sqlQuery("DROP TABLE testtbl")
427+
412428

413429
def run():
414430
"""Starts all tests defined in this module"""

src/tests/sql/__init__.py

Whitespace-only changes.

src/tests/sql/init_version_5.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
INSERT INTO `objectprocessorqueue` VALUES ('hash', 1);
2+

src/tests/sql/init_version_6.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
INSERT INTO `inventory` VALUES ('hash', 1, 1, 1,'test','test');
2+

src/tests/test_sqlthread.py

Lines changed: 198 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,214 @@
22
# flake8: noqa:E402
33
import os
44
import tempfile
5-
import threading
65
import unittest
76

8-
from .common import skip_python3
9-
10-
skip_python3()
11-
127
os.environ['BITMESSAGE_HOME'] = tempfile.gettempdir()
138

14-
from pybitmessage.helper_sql import (
15-
sqlQuery, sql_ready, sqlStoredProcedure) # noqa:E402
16-
from pybitmessage.class_sqlThread import sqlThread # noqa:E402
9+
from pybitmessage.class_sqlThread import TestDB # noqa:E402
1710
from pybitmessage.addresses import encodeAddress # noqa:E402
1811

1912

20-
class TestSqlThread(unittest.TestCase):
21-
"""Test case for SQL thread"""
13+
def filter_table_column(schema, column):
14+
"""
15+
Filter column from schema
16+
"""
17+
for x in schema:
18+
for y in x:
19+
if y == column:
20+
yield y
21+
22+
23+
class TestSqlBase(object): # pylint: disable=E1101, too-few-public-methods, E1004, W0232
24+
""" Base for test case """
25+
26+
__name__ = None
27+
root_path = os.path.dirname(os.path.dirname(__file__))
28+
test_db = None
29+
30+
def _setup_db(self): # pylint: disable=W0622, redefined-builtin
31+
"""
32+
Drop all tables before each test case start
33+
"""
34+
self.test_db = TestDB()
35+
self.test_db.create_sql_function()
36+
self.test_db.initialize_schema()
37+
38+
def initialise_database(self, test_db_cur, file): # pylint: disable=W0622, redefined-builtin
39+
"""
40+
Initialise DB
41+
"""
42+
43+
with open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file)), 'r') as sql_as_string:
44+
sql_as_string = sql_as_string.read()
45+
46+
test_db_cur.cur.executescript(sql_as_string)
47+
2248

23-
@classmethod
24-
def setUpClass(cls):
25-
# Start SQL thread
26-
sqlLookup = sqlThread()
27-
sqlLookup.daemon = True
28-
sqlLookup.start()
29-
sql_ready.wait()
49+
class TestFnBitmessageDB(TestSqlBase, unittest.TestCase): # pylint: disable=protected-access
50+
""" Test case for Sql function"""
3051

31-
@classmethod
32-
def tearDownClass(cls):
33-
sqlStoredProcedure('exit')
34-
for thread in threading.enumerate():
35-
if thread.name == "SQL":
36-
thread.join()
52+
def setUp(self):
53+
"""
54+
setup for test case
55+
"""
56+
self._setup_db()
3757

3858
def test_create_function(self):
3959
"""Check the result of enaddr function"""
40-
encoded_str = encodeAddress(4, 1, "21122112211221122112")
60+
st = "21122112211221122112".encode()
61+
encoded_str = encodeAddress(4, 1, st)
62+
63+
item = '''SELECT enaddr(4, 1, ?);'''
64+
parameters = (st, )
65+
self.test_db.cur.execute(item, parameters)
66+
query = self.test_db.cur.fetchall()
67+
self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function")
68+
69+
70+
class TestUpgradeBitmessageDB(TestSqlBase, unittest.TestCase): # pylint: disable=protected-access
71+
"""Test case for SQL versions"""
72+
73+
def setUp(self):
74+
"""
75+
Setup DB schema before start.
76+
And applying default schema for version test.
77+
"""
78+
self._setup_db()
79+
self.test_db.cur.execute('''INSERT INTO settings VALUES('version','2')''')
80+
81+
def version(self):
82+
"""
83+
Run SQL Scripts, Initialize DB with respect to versioning
84+
and Upgrade DB schema for all versions
85+
"""
86+
def wrapper(*args):
87+
"""
88+
Run SQL and mocking DB for versions
89+
"""
90+
self = args[0]
91+
func_name = func.__name__
92+
version = func_name.rsplit('_', 1)[-1]
93+
94+
self.test_db._upgrade_one_level_sql_statement(int(version)) # pylint: disable= W0212, protected-access
95+
96+
# Update versions DB mocking
97+
self.initialise_database(self.test_db, "init_version_{}".format(version))
98+
99+
return func(*args) # <-- use (self, ...)
100+
func = self
101+
return wrapper
102+
103+
@version
104+
def test_bm_db_version_2(self):
105+
"""
106+
Test with version 2
107+
"""
108+
109+
res = self.test_db.cur.execute(''' SELECT count(name) FROM sqlite_master
110+
WHERE type='table' AND name='inventory_backup' ''')
111+
self.assertNotEqual(res, 1, "Table inventory_backup not deleted in versioning 2")
112+
113+
@version
114+
def test_bm_db_version_3(self):
115+
"""
116+
Test with version 1
117+
Version 1 and 3 are same so will skip 3
118+
"""
119+
120+
res = self.test_db.cur.execute('''PRAGMA table_info('inventory');''')
121+
result = list(filter_table_column(res, "tag"))
122+
self.assertEqual(result, ['tag'], "Data not migrated for version 3")
123+
124+
@version
125+
def test_bm_db_version_4(self):
126+
"""
127+
Test with version 4
128+
"""
129+
130+
self.test_db.cur.execute("select * from pubkeys where addressversion = '1';")
131+
res = self.test_db.cur.fetchall()
132+
self.assertEqual(len(res), 1, "Table inventory not deleted in versioning 4")
133+
134+
@version
135+
def test_bm_db_version_5(self):
136+
"""
137+
Test with version 5
138+
"""
139+
140+
self.test_db.cur.execute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='knownnodes' ''') # noqa
141+
res = self.test_db.cur.fetchall()
142+
self.assertNotEqual(res[0][0], 1, "Table knownnodes not deleted in versioning 5")
143+
self.test_db.cur.execute(''' SELECT count(name) FROM sqlite_master
144+
WHERE type='table' AND name='objectprocessorqueue'; ''')
145+
res = self.test_db.cur.fetchall()
146+
self.assertNotEqual(len(res), 0, "Table objectprocessorqueue not created in versioning 5")
147+
self.test_db.cur.execute(''' SELECT * FROM objectprocessorqueue where objecttype='hash' ; ''')
148+
res = self.test_db.cur.fetchall()
149+
self.assertNotEqual(len(res), 0, "Table objectprocessorqueue not created in versioning 5")
150+
151+
@version
152+
def test_bm_db_version_6(self):
153+
"""
154+
Test with version 6
155+
"""
156+
157+
self.test_db.cur.execute('''PRAGMA table_info('inventory');''')
158+
inventory = self.test_db.cur.fetchall()
159+
inventory = list(filter_table_column(inventory, "expirestime"))
160+
self.assertEqual(inventory, ['expirestime'], "Data not migrated for version 6")
161+
162+
self.test_db.cur.execute('''PRAGMA table_info('objectprocessorqueue');''')
163+
objectprocessorqueue = self.test_db.cur.fetchall()
164+
objectprocessorqueue = list(filter_table_column(objectprocessorqueue, "objecttype"))
165+
self.assertEqual(objectprocessorqueue, ['objecttype'], "Data not migrated for version 6")
166+
167+
@version
168+
def test_bm_db_version_7(self):
169+
"""
170+
Test with version 7
171+
"""
172+
173+
self.test_db.cur.execute('''SELECT * FROM pubkeys ''')
174+
pubkeys = self.test_db.cur.fetchall()
175+
self.assertEqual(pubkeys, [], "Data not migrated for version 7")
176+
177+
self.test_db.cur.execute('''SELECT * FROM inventory ''')
178+
inventory = self.test_db.cur.fetchall()
179+
self.assertEqual(inventory, [], "Data not migrated for version 7")
180+
181+
self.test_db.cur.execute('''SELECT status FROM sent ''')
182+
sent = self.test_db.cur.fetchall()
183+
self.assertEqual(sent, [('msgqueued',), ('msgqueued',)], "Data not migrated for version 7")
184+
185+
@version
186+
def test_bm_db_version_8(self):
187+
"""
188+
Test with version 8
189+
"""
190+
191+
self.test_db.cur.execute('''PRAGMA table_info('inbox');''')
192+
res = self.test_db.cur.fetchall()
193+
result = list(filter_table_column(res, "sighash"))
194+
self.assertEqual(result, ['sighash'], "Data not migrated for version 8")
195+
196+
@version
197+
def test_bm_db_version_9(self):
198+
"""
199+
Test with version 9
200+
"""
201+
202+
self.test_db.cur.execute("SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup'") # noqa
203+
res = self.test_db.cur.fetchall()
204+
self.assertNotEqual(res[0][0], 1, "Table pubkeys_backup not deleted")
205+
206+
@version
207+
def test_bm_db_version_10(self):
208+
"""
209+
Test with version 10
210+
"""
41211

42-
query = sqlQuery('SELECT enaddr(4, 1, "21122112211221122112")')
43-
self.assertEqual(
44-
query[0][-1], encoded_str, "test case fail for create_function")
212+
label = "test"
213+
self.test_db.cur.execute("SELECT * FROM addressbook WHERE label='test' ") # noqa
214+
res = self.test_db.cur.fetchall()
215+
self.assertEqual(res[0][0], label, "Data not migrated for version 10")

0 commit comments

Comments
 (0)