1
- """Tests for SQL thread"""
1
+ """
2
+ Test for sqlThread
3
+ """
2
4
3
5
import os
4
- import tempfile
5
- import threading
6
6
import unittest
7
+ from ..helper_sql import sqlStoredProcedure , sql_ready , sqlExecute , SqlBulkExecute , sqlQuery , sqlExecuteScript
8
+ from ..class_sqlThread import (sqlThread , UpgradeDB )
9
+ from ..addresses import encodeAddress
10
+ import threading
7
11
8
- from pybitmessage .helper_sql import sqlQuery , sql_ready , sqlStoredProcedure
9
- from pybitmessage .class_sqlThread import sqlThread
10
- from pybitmessage .addresses import encodeAddress
11
12
13
+ def filter_table_column (schema , column ):
14
+ """
15
+ Filter column from schema
16
+ """
17
+ for x in schema :
18
+ for y in x :
19
+ if y == column :
20
+ yield y
12
21
13
22
14
23
class TestSqlThread (unittest .TestCase ):
15
- """Test case for SQL thread"""
24
+ """
25
+ Test case for SQLThread
26
+ """
27
+
28
+ # query file path
29
+ __name__ = None
30
+ root_path = os .path .dirname (os .path .dirname (__file__ ))
16
31
17
32
@classmethod
18
33
def setUpClass (cls ):
19
- # Start SQL thread
34
+ """
35
+ Start SQL thread
36
+ """
20
37
sqlLookup = sqlThread ()
21
38
sqlLookup .daemon = True
22
39
sqlLookup .start ()
23
40
sql_ready .wait ()
24
41
42
+ @classmethod
43
+ def setUp (cls ):
44
+ """
45
+ Drop all tables before each test case start
46
+ """
47
+ tables = list (sqlQuery ("select name from sqlite_master where type is 'table'" ))
48
+ with SqlBulkExecute () as sql :
49
+ for q in tables :
50
+ sql .execute ("drop table if exists %s" % q )
51
+
52
+ @classmethod
53
+ def tearDown (cls ):
54
+ pass
55
+
25
56
@classmethod
26
57
def tearDownClass (cls ):
58
+ """
59
+ Join the thread
60
+ """
27
61
sqlStoredProcedure ('exit' )
28
62
for thread in threading .enumerate ():
29
63
if thread .name == "SQL" :
30
64
thread .join ()
31
65
66
+ def initialise_database (self , file ):
67
+ """
68
+ Initialise DB
69
+ """
70
+ with open (os .path .join (self .root_path , "tests/sql/{}.sql" .format (file )), 'r' ) as sql_as_string :
71
+ sql_as_string = sql_as_string .read ()
72
+
73
+ sqlExecuteScript (sql_as_string )
74
+
75
+ def version (self ):
76
+ """
77
+ Run SQL Scripts, Initialize DB with respect to versioning
78
+ and Upgrade DB schema for all versions
79
+ """
80
+ def wrapper (* args ):
81
+ """
82
+ Run SQL and mocking DB for versions
83
+ """
84
+ # import pdb; pdb.set_trace()
85
+ self = args [0 ]
86
+ func_name = func .__name__
87
+ version = func_name .rsplit ('_' , 1 )[- 1 ]
88
+
89
+ # Update versions DB mocking
90
+ self .initialise_database ("init_version_{}" .format (version ))
91
+
92
+ if int (version ) == 9 :
93
+ sqlThread ().create_function ()
94
+
95
+ # Test versions
96
+ upgrade_db = UpgradeDB ()
97
+ upgrade_db ._upgrade_one_level_sql_statement (int (version )) # pylint: disable= W0212, protected-access
98
+ return func (* args ) # <-- use (self, ...)
99
+ func = self
100
+ return wrapper
101
+
32
102
def test_create_function (self ):
33
- """Check the result of enaddr function"""
103
+ """
104
+ Test create_function and asserting the result
105
+ """
106
+
107
+ # call create function
34
108
encoded_str = encodeAddress (4 , 1 , "21122112211221122112" )
35
- query = sqlQuery ('SELECT enaddr(4, 1, "21122112211221122112")' )
36
- self .assertEqual (
37
- query [0 ][- 1 ], encoded_str , "test case fail for create_function" )
109
+
110
+ # Initialise Database
111
+ self .initialise_database ("create_function" )
112
+
113
+ sqlExecute ('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''' )
114
+ # call function in query
115
+
116
+ sqlExecute ('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash));''' )
117
+
118
+ # Assertion
119
+ query = sqlQuery ('''select * from testhash;''' )
120
+ self .assertEqual (query [0 ][- 1 ], encoded_str , "test case fail for create_function" )
121
+ sqlExecute ('''DROP TABLE testhash''' )
122
+
123
+ @version
124
+ def test_sql_thread_version_2 (self ):
125
+ """
126
+ Test with version 2
127
+ """
128
+
129
+ # Assertion
130
+ res = sqlQuery (''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='inventory_backup' ''' )
131
+ self .assertNotEqual (res [0 ][0 ], 1 , "Table inventory_backup not deleted in versioning 2" )
132
+
133
+ @version
134
+ def test_sql_thread_version_3 (self ):
135
+ """
136
+ Test with version 1
137
+ Version 1 and 3 are same so will skip 3
138
+ """
139
+
140
+ # Assertion after versioning
141
+ res = sqlQuery ('''PRAGMA table_info('inventory');''' )
142
+ result = list (filter_table_column (res , "tag" ))
143
+ res = [tup for tup in res if any (i in tup for i in ["tag" ])]
144
+ self .assertEqual (result , ['tag' ], "Data not migrated for version 1" )
145
+ self .assertEqual (res , [(5 , 'tag' , 'blob' , 0 , "''" , 0 )], "Data not migrated for version 1" )
146
+
147
+ @version
148
+ def test_sql_thread_version_4 (self ):
149
+ """
150
+ Test with version 4
151
+ """
152
+
153
+ # Assertion
154
+ res = sqlQuery ('''select * from inventory where objecttype = 'pubkey';''' )
155
+ self .assertNotEqual (len (res ), 1 , "Table inventory not deleted in versioning 4" )
156
+
157
+ @version
158
+ def test_sql_thread_version_5 (self ):
159
+ """
160
+ Test with version 5
161
+ """
162
+
163
+ # Assertion
164
+ res = sqlQuery (''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='knownnodes' ''' )
165
+
166
+ self .assertNotEqual (res [0 ][0 ], 1 , "Table knownnodes not deleted in versioning 5" )
167
+ res = sqlQuery (''' SELECT count(name) FROM sqlite_master
168
+ WHERE type='table' AND name='objectprocessorqueue'; ''' )
169
+ self .assertNotEqual (len (res ), 0 , "Table objectprocessorqueue not created in versioning 5" )
170
+
171
+ @version
172
+ def test_sql_thread_version_6 (self ):
173
+ """
174
+ Test with version 6
175
+ """
176
+
177
+ # Assertion
178
+
179
+ inventory = sqlQuery ('''PRAGMA table_info('inventory');''' )
180
+ inventory = list (filter_table_column (inventory , "expirestime" ))
181
+ self .assertEqual (inventory , ['expirestime' ], "Data not migrated for version 6" )
182
+
183
+ objectprocessorqueue = sqlQuery ('''PRAGMA table_info('inventory');''' )
184
+ objectprocessorqueue = list (filter_table_column (objectprocessorqueue , "objecttype" ))
185
+ self .assertEqual (objectprocessorqueue , ['objecttype' ], "Data not migrated for version 6" )
186
+
187
+ @version
188
+ def test_sql_thread_version_7 (self ):
189
+ """
190
+ Test with version 7
191
+ """
192
+
193
+ # Assertion
194
+ pubkeys = sqlQuery ('''SELECT * FROM pubkeys ''' )
195
+ self .assertEqual (pubkeys , [], "Data not migrated for version 7" )
196
+
197
+ inventory = sqlQuery ('''SELECT * FROM inventory ''' )
198
+ self .assertEqual (inventory , [], "Data not migrated for version 7" )
199
+
200
+ sent = sqlQuery ('''SELECT status FROM sent ''' )
201
+ self .assertEqual (sent , [('msgqueued' ,), ('msgqueued' ,)], "Data not migrated for version 7" )
202
+
203
+ @version
204
+ def test_sql_thread_version_8 (self ):
205
+ """
206
+ Test with version 8
207
+ """
208
+
209
+ # Assertion
210
+ res = sqlQuery ('''PRAGMA table_info('inbox');''' )
211
+ result = list (filter_table_column (res , "sighash" ))
212
+ self .assertEqual (result , ['sighash' ], "Data not migrated for version 8" )
213
+
214
+ @version
215
+ def test_sql_thread_version_9 (self ):
216
+ """
217
+ Test with version 9
218
+ """
219
+
220
+ # Assertion
221
+ res = sqlQuery (''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''' )
222
+ self .assertNotEqual (res [0 ][0 ], 1 , "Table pubkeys_backup not deleted" )
223
+
224
+ res = sqlQuery ('''PRAGMA table_info('pubkeys');''' )
225
+ # res = res.fetchall()
226
+ result = list (filter_table_column (res , "address" ))
227
+ self .assertEqual (result , ['address' ], "Data not migrated for version 9" )
228
+
229
+ @version
230
+ def test_sql_thread_version_10 (self ):
231
+ """
232
+ Test with version 10
233
+ """
234
+
235
+ # Assertion
236
+ res = sqlQuery (''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='old_addressbook' ''' )
237
+ self .assertNotEqual (res [0 ][0 ], 1 , "Table old_addressbook not deleted" )
238
+ self .assertEqual (len (res ), 1 , "Table old_addressbook not deleted" )
239
+
240
+ res = sqlQuery ('''PRAGMA table_info('addressbook');''' )
241
+ result = list (filter_table_column (res , "address" ))
242
+ self .assertEqual (result , ['address' ], "Data not migrated for version 10" )
0 commit comments