Skip to content

Commit 4d6170a

Browse files
committed
chaanges for rebase
1 parent 769db2a commit 4d6170a

File tree

1 file changed

+12
-217
lines changed

1 file changed

+12
-217
lines changed

src/tests/test_sqlthread.py

Lines changed: 12 additions & 217 deletions
Original file line numberDiff line numberDiff line change
@@ -1,242 +1,37 @@
1-
"""
2-
Test for sqlThread
3-
"""
1+
"""Tests for SQL thread"""
42

53
import os
6-
import unittest
7-
from ..helper_sql import sqlStoredProcedure, sql_ready, sqlExecute, SqlBulkExecute, sqlQuery, sqlExecuteScript
8-
from ..class_sqlThread import (sqlThread, UpgradeDB)
9-
from ..addresses import encodeAddress
4+
import tempfile
105
import threading
6+
import unittest
117

8+
from pybitmessage.helper_sql import sqlQuery, sql_ready, sqlStoredProcedure
9+
from pybitmessage.class_sqlThread import sqlThread
10+
from pybitmessage.addresses import encodeAddress
1211

13-
def filter_table_column(schema, column):
14-
"""
15-
Filter column from schema
16-
"""
17-
for x in schema:
18-
for y in x:
19-
if y == column:
20-
yield y
2112

2213

2314
class TestSqlThread(unittest.TestCase):
24-
"""
25-
Test case for SQLThread
26-
"""
27-
28-
# query file path
29-
__name__ = None
30-
root_path = os.path.dirname(os.path.dirname(__file__))
15+
"""Test case for SQL thread"""
3116

3217
@classmethod
3318
def setUpClass(cls):
34-
"""
35-
Start SQL thread
36-
"""
19+
# Start SQL thread
3720
sqlLookup = sqlThread()
3821
sqlLookup.daemon = True
3922
sqlLookup.start()
4023
sql_ready.wait()
4124

42-
@classmethod
43-
def setUp(cls):
44-
"""
45-
Drop all tables before each test case start
46-
"""
47-
tables = list(sqlQuery("select name from sqlite_master where type is 'table'"))
48-
with SqlBulkExecute() as sql:
49-
for q in tables:
50-
sql.execute("drop table if exists %s" % q)
51-
52-
@classmethod
53-
def tearDown(cls):
54-
pass
55-
5625
@classmethod
5726
def tearDownClass(cls):
58-
"""
59-
Join the thread
60-
"""
6127
sqlStoredProcedure('exit')
6228
for thread in threading.enumerate():
6329
if thread.name == "SQL":
6430
thread.join()
6531

66-
def initialise_database(self, file):
67-
"""
68-
Initialise DB
69-
"""
70-
with open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file)), 'r') as sql_as_string:
71-
sql_as_string = sql_as_string.read()
72-
73-
sqlExecuteScript(sql_as_string)
74-
75-
def version(self):
76-
"""
77-
Run SQL Scripts, Initialize DB with respect to versioning
78-
and Upgrade DB schema for all versions
79-
"""
80-
def wrapper(*args):
81-
"""
82-
Run SQL and mocking DB for versions
83-
"""
84-
# import pdb; pdb.set_trace()
85-
self = args[0]
86-
func_name = func.__name__
87-
version = func_name.rsplit('_', 1)[-1]
88-
89-
# Update versions DB mocking
90-
self.initialise_database("init_version_{}".format(version))
91-
92-
if int(version) == 9:
93-
sqlThread().create_function()
94-
95-
# Test versions
96-
upgrade_db = UpgradeDB()
97-
upgrade_db._upgrade_one_level_sql_statement(int(version)) # pylint: disable= W0212, protected-access
98-
return func(*args) # <-- use (self, ...)
99-
func = self
100-
return wrapper
101-
10232
def test_create_function(self):
103-
"""
104-
Test create_function and asserting the result
105-
"""
106-
107-
# call create function
33+
"""Check the result of enaddr function"""
10834
encoded_str = encodeAddress(4, 1, "21122112211221122112")
109-
110-
# Initialise Database
111-
self.initialise_database("create_function")
112-
113-
sqlExecute('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''')
114-
# call function in query
115-
116-
sqlExecute('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash));''')
117-
118-
# Assertion
119-
query = sqlQuery('''select * from testhash;''')
120-
self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function")
121-
sqlExecute('''DROP TABLE testhash''')
122-
123-
@version
124-
def test_sql_thread_version_2(self):
125-
"""
126-
Test with version 2
127-
"""
128-
129-
# Assertion
130-
res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='inventory_backup' ''')
131-
self.assertNotEqual(res[0][0], 1, "Table inventory_backup not deleted in versioning 2")
132-
133-
@version
134-
def test_sql_thread_version_3(self):
135-
"""
136-
Test with version 1
137-
Version 1 and 3 are same so will skip 3
138-
"""
139-
140-
# Assertion after versioning
141-
res = sqlQuery('''PRAGMA table_info('inventory');''')
142-
result = list(filter_table_column(res, "tag"))
143-
res = [tup for tup in res if any(i in tup for i in ["tag"])]
144-
self.assertEqual(result, ['tag'], "Data not migrated for version 1")
145-
self.assertEqual(res, [(5, 'tag', 'blob', 0, "''", 0)], "Data not migrated for version 1")
146-
147-
@version
148-
def test_sql_thread_version_4(self):
149-
"""
150-
Test with version 4
151-
"""
152-
153-
# Assertion
154-
res = sqlQuery('''select * from inventory where objecttype = 'pubkey';''')
155-
self.assertNotEqual(len(res), 1, "Table inventory not deleted in versioning 4")
156-
157-
@version
158-
def test_sql_thread_version_5(self):
159-
"""
160-
Test with version 5
161-
"""
162-
163-
# Assertion
164-
res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='knownnodes' ''')
165-
166-
self.assertNotEqual(res[0][0], 1, "Table knownnodes not deleted in versioning 5")
167-
res = sqlQuery(''' SELECT count(name) FROM sqlite_master
168-
WHERE type='table' AND name='objectprocessorqueue'; ''')
169-
self.assertNotEqual(len(res), 0, "Table objectprocessorqueue not created in versioning 5")
170-
171-
@version
172-
def test_sql_thread_version_6(self):
173-
"""
174-
Test with version 6
175-
"""
176-
177-
# Assertion
178-
179-
inventory = sqlQuery('''PRAGMA table_info('inventory');''')
180-
inventory = list(filter_table_column(inventory, "expirestime"))
181-
self.assertEqual(inventory, ['expirestime'], "Data not migrated for version 6")
182-
183-
objectprocessorqueue = sqlQuery('''PRAGMA table_info('inventory');''')
184-
objectprocessorqueue = list(filter_table_column(objectprocessorqueue, "objecttype"))
185-
self.assertEqual(objectprocessorqueue, ['objecttype'], "Data not migrated for version 6")
186-
187-
@version
188-
def test_sql_thread_version_7(self):
189-
"""
190-
Test with version 7
191-
"""
192-
193-
# Assertion
194-
pubkeys = sqlQuery('''SELECT * FROM pubkeys ''')
195-
self.assertEqual(pubkeys, [], "Data not migrated for version 7")
196-
197-
inventory = sqlQuery('''SELECT * FROM inventory ''')
198-
self.assertEqual(inventory, [], "Data not migrated for version 7")
199-
200-
sent = sqlQuery('''SELECT status FROM sent ''')
201-
self.assertEqual(sent, [('msgqueued',), ('msgqueued',)], "Data not migrated for version 7")
202-
203-
@version
204-
def test_sql_thread_version_8(self):
205-
"""
206-
Test with version 8
207-
"""
208-
209-
# Assertion
210-
res = sqlQuery('''PRAGMA table_info('inbox');''')
211-
result = list(filter_table_column(res, "sighash"))
212-
self.assertEqual(result, ['sighash'], "Data not migrated for version 8")
213-
214-
@version
215-
def test_sql_thread_version_9(self):
216-
"""
217-
Test with version 9
218-
"""
219-
220-
# Assertion
221-
res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''')
222-
self.assertNotEqual(res[0][0], 1, "Table pubkeys_backup not deleted")
223-
224-
res = sqlQuery('''PRAGMA table_info('pubkeys');''')
225-
# res = res.fetchall()
226-
result = list(filter_table_column(res, "address"))
227-
self.assertEqual(result, ['address'], "Data not migrated for version 9")
228-
229-
@version
230-
def test_sql_thread_version_10(self):
231-
"""
232-
Test with version 10
233-
"""
234-
235-
# Assertion
236-
res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='old_addressbook' ''')
237-
self.assertNotEqual(res[0][0], 1, "Table old_addressbook not deleted")
238-
self.assertEqual(len(res), 1, "Table old_addressbook not deleted")
239-
240-
res = sqlQuery('''PRAGMA table_info('addressbook');''')
241-
result = list(filter_table_column(res, "address"))
242-
self.assertEqual(result, ['address'], "Data not migrated for version 10")
35+
query = sqlQuery('SELECT enaddr(4, 1, "21122112211221122112")')
36+
self.assertEqual(
37+
query[0][-1], encoded_str, "test case fail for create_function")

0 commit comments

Comments
 (0)