Skip to content

Commit c04cc05

Browse files
committed
reset test sqlthread
1 parent 4d6170a commit c04cc05

File tree

1 file changed

+217
-12
lines changed

1 file changed

+217
-12
lines changed

src/tests/test_sqlthread.py

Lines changed: 217 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,242 @@
1-
"""Tests for SQL thread"""
1+
"""
2+
Test for sqlThread
3+
"""
24

35
import os
4-
import tempfile
5-
import threading
66
import unittest
7+
from ..helper_sql import sqlStoredProcedure, sql_ready, sqlExecute, SqlBulkExecute, sqlQuery, sqlExecuteScript
8+
from ..class_sqlThread import (sqlThread, UpgradeDB)
9+
from ..addresses import encodeAddress
10+
import threading
711

8-
from pybitmessage.helper_sql import sqlQuery, sql_ready, sqlStoredProcedure
9-
from pybitmessage.class_sqlThread import sqlThread
10-
from pybitmessage.addresses import encodeAddress
1112

13+
def filter_table_column(schema, column):
14+
"""
15+
Filter column from schema
16+
"""
17+
for x in schema:
18+
for y in x:
19+
if y == column:
20+
yield y
1221

1322

1423
class TestSqlThread(unittest.TestCase):
15-
"""Test case for SQL thread"""
24+
"""
25+
Test case for SQLThread
26+
"""
27+
28+
# query file path
29+
__name__ = None
30+
root_path = os.path.dirname(os.path.dirname(__file__))
1631

1732
@classmethod
1833
def setUpClass(cls):
19-
# Start SQL thread
34+
"""
35+
Start SQL thread
36+
"""
2037
sqlLookup = sqlThread()
2138
sqlLookup.daemon = True
2239
sqlLookup.start()
2340
sql_ready.wait()
2441

42+
@classmethod
43+
def setUp(cls):
44+
"""
45+
Drop all tables before each test case start
46+
"""
47+
tables = list(sqlQuery("select name from sqlite_master where type is 'table'"))
48+
with SqlBulkExecute() as sql:
49+
for q in tables:
50+
sql.execute("drop table if exists %s" % q)
51+
52+
@classmethod
53+
def tearDown(cls):
54+
pass
55+
2556
@classmethod
2657
def tearDownClass(cls):
58+
"""
59+
Join the thread
60+
"""
2761
sqlStoredProcedure('exit')
2862
for thread in threading.enumerate():
2963
if thread.name == "SQL":
3064
thread.join()
3165

66+
def initialise_database(self, file):
67+
"""
68+
Initialise DB
69+
"""
70+
with open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file)), 'r') as sql_as_string:
71+
sql_as_string = sql_as_string.read()
72+
73+
sqlExecuteScript(sql_as_string)
74+
75+
def version(self):
76+
"""
77+
Run SQL Scripts, Initialize DB with respect to versioning
78+
and Upgrade DB schema for all versions
79+
"""
80+
def wrapper(*args):
81+
"""
82+
Run SQL and mocking DB for versions
83+
"""
84+
# import pdb; pdb.set_trace()
85+
self = args[0]
86+
func_name = func.__name__
87+
version = func_name.rsplit('_', 1)[-1]
88+
89+
# Update versions DB mocking
90+
self.initialise_database("init_version_{}".format(version))
91+
92+
if int(version) == 9:
93+
sqlThread().create_function()
94+
95+
# Test versions
96+
upgrade_db = UpgradeDB()
97+
upgrade_db._upgrade_one_level_sql_statement(int(version)) # pylint: disable= W0212, protected-access
98+
return func(*args) # <-- use (self, ...)
99+
func = self
100+
return wrapper
101+
32102
def test_create_function(self):
33-
"""Check the result of enaddr function"""
103+
"""
104+
Test create_function and asserting the result
105+
"""
106+
107+
# call create function
34108
encoded_str = encodeAddress(4, 1, "21122112211221122112")
35-
query = sqlQuery('SELECT enaddr(4, 1, "21122112211221122112")')
36-
self.assertEqual(
37-
query[0][-1], encoded_str, "test case fail for create_function")
109+
110+
# Initialise Database
111+
self.initialise_database("create_function")
112+
113+
sqlExecute('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''')
114+
# call function in query
115+
116+
sqlExecute('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash));''')
117+
118+
# Assertion
119+
query = sqlQuery('''select * from testhash;''')
120+
self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function")
121+
sqlExecute('''DROP TABLE testhash''')
122+
123+
@version
124+
def test_sql_thread_version_2(self):
125+
"""
126+
Test with version 2
127+
"""
128+
129+
# Assertion
130+
res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='inventory_backup' ''')
131+
self.assertNotEqual(res[0][0], 1, "Table inventory_backup not deleted in versioning 2")
132+
133+
@version
134+
def test_sql_thread_version_3(self):
135+
"""
136+
Test with version 1
137+
Version 1 and 3 are same so will skip 3
138+
"""
139+
140+
# Assertion after versioning
141+
res = sqlQuery('''PRAGMA table_info('inventory');''')
142+
result = list(filter_table_column(res, "tag"))
143+
res = [tup for tup in res if any(i in tup for i in ["tag"])]
144+
self.assertEqual(result, ['tag'], "Data not migrated for version 1")
145+
self.assertEqual(res, [(5, 'tag', 'blob', 0, "''", 0)], "Data not migrated for version 1")
146+
147+
@version
148+
def test_sql_thread_version_4(self):
149+
"""
150+
Test with version 4
151+
"""
152+
153+
# Assertion
154+
res = sqlQuery('''select * from inventory where objecttype = 'pubkey';''')
155+
self.assertNotEqual(len(res), 1, "Table inventory not deleted in versioning 4")
156+
157+
@version
158+
def test_sql_thread_version_5(self):
159+
"""
160+
Test with version 5
161+
"""
162+
163+
# Assertion
164+
res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='knownnodes' ''')
165+
166+
self.assertNotEqual(res[0][0], 1, "Table knownnodes not deleted in versioning 5")
167+
res = sqlQuery(''' SELECT count(name) FROM sqlite_master
168+
WHERE type='table' AND name='objectprocessorqueue'; ''')
169+
self.assertNotEqual(len(res), 0, "Table objectprocessorqueue not created in versioning 5")
170+
171+
@version
172+
def test_sql_thread_version_6(self):
173+
"""
174+
Test with version 6
175+
"""
176+
177+
# Assertion
178+
179+
inventory = sqlQuery('''PRAGMA table_info('inventory');''')
180+
inventory = list(filter_table_column(inventory, "expirestime"))
181+
self.assertEqual(inventory, ['expirestime'], "Data not migrated for version 6")
182+
183+
objectprocessorqueue = sqlQuery('''PRAGMA table_info('inventory');''')
184+
objectprocessorqueue = list(filter_table_column(objectprocessorqueue, "objecttype"))
185+
self.assertEqual(objectprocessorqueue, ['objecttype'], "Data not migrated for version 6")
186+
187+
@version
188+
def test_sql_thread_version_7(self):
189+
"""
190+
Test with version 7
191+
"""
192+
193+
# Assertion
194+
pubkeys = sqlQuery('''SELECT * FROM pubkeys ''')
195+
self.assertEqual(pubkeys, [], "Data not migrated for version 7")
196+
197+
inventory = sqlQuery('''SELECT * FROM inventory ''')
198+
self.assertEqual(inventory, [], "Data not migrated for version 7")
199+
200+
sent = sqlQuery('''SELECT status FROM sent ''')
201+
self.assertEqual(sent, [('msgqueued',), ('msgqueued',)], "Data not migrated for version 7")
202+
203+
@version
204+
def test_sql_thread_version_8(self):
205+
"""
206+
Test with version 8
207+
"""
208+
209+
# Assertion
210+
res = sqlQuery('''PRAGMA table_info('inbox');''')
211+
result = list(filter_table_column(res, "sighash"))
212+
self.assertEqual(result, ['sighash'], "Data not migrated for version 8")
213+
214+
@version
215+
def test_sql_thread_version_9(self):
216+
"""
217+
Test with version 9
218+
"""
219+
220+
# Assertion
221+
res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''')
222+
self.assertNotEqual(res[0][0], 1, "Table pubkeys_backup not deleted")
223+
224+
res = sqlQuery('''PRAGMA table_info('pubkeys');''')
225+
# res = res.fetchall()
226+
result = list(filter_table_column(res, "address"))
227+
self.assertEqual(result, ['address'], "Data not migrated for version 9")
228+
229+
@version
230+
def test_sql_thread_version_10(self):
231+
"""
232+
Test with version 10
233+
"""
234+
235+
# Assertion
236+
res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='old_addressbook' ''')
237+
self.assertNotEqual(res[0][0], 1, "Table old_addressbook not deleted")
238+
self.assertEqual(len(res), 1, "Table old_addressbook not deleted")
239+
240+
res = sqlQuery('''PRAGMA table_info('addressbook');''')
241+
result = list(filter_table_column(res, "address"))
242+
self.assertEqual(result, ['address'], "Data not migrated for version 10")

0 commit comments

Comments
 (0)