1
- """
2
- Test for sqlThread
3
- """
1
+ """Tests for SQL thread"""
4
2
5
3
import os
6
- import unittest
7
- from ..helper_sql import sqlStoredProcedure , sql_ready , sqlExecute , SqlBulkExecute , sqlQuery , sqlExecuteScript
8
- from ..class_sqlThread import (sqlThread , UpgradeDB )
9
- from ..addresses import encodeAddress
4
+ import tempfile
10
5
import threading
6
+ import unittest
11
7
8
+ from pybitmessage .helper_sql import sqlQuery , sql_ready , sqlStoredProcedure
9
+ from pybitmessage .class_sqlThread import sqlThread
10
+ from pybitmessage .addresses import encodeAddress
12
11
13
- def filter_table_column (schema , column ):
14
- """
15
- Filter column from schema
16
- """
17
- for x in schema :
18
- for y in x :
19
- if y == column :
20
- yield y
21
12
22
13
23
14
class TestSqlThread (unittest .TestCase ):
24
- """
25
- Test case for SQLThread
26
- """
27
-
28
- # query file path
29
- __name__ = None
30
- root_path = os .path .dirname (os .path .dirname (__file__ ))
15
+ """Test case for SQL thread"""
31
16
32
17
@classmethod
33
18
def setUpClass (cls ):
34
- """
35
- Start SQL thread
36
- """
19
+ # Start SQL thread
37
20
sqlLookup = sqlThread ()
38
21
sqlLookup .daemon = True
39
22
sqlLookup .start ()
40
23
sql_ready .wait ()
41
24
42
- @classmethod
43
- def setUp (cls ):
44
- """
45
- Drop all tables before each test case start
46
- """
47
- tables = list (sqlQuery ("select name from sqlite_master where type is 'table'" ))
48
- with SqlBulkExecute () as sql :
49
- for q in tables :
50
- sql .execute ("drop table if exists %s" % q )
51
-
52
- @classmethod
53
- def tearDown (cls ):
54
- pass
55
-
56
25
@classmethod
57
26
def tearDownClass (cls ):
58
- """
59
- Join the thread
60
- """
61
27
sqlStoredProcedure ('exit' )
62
28
for thread in threading .enumerate ():
63
29
if thread .name == "SQL" :
64
30
thread .join ()
65
31
66
- def initialise_database (self , file ):
67
- """
68
- Initialise DB
69
- """
70
- with open (os .path .join (self .root_path , "tests/sql/{}.sql" .format (file )), 'r' ) as sql_as_string :
71
- sql_as_string = sql_as_string .read ()
72
-
73
- sqlExecuteScript (sql_as_string )
74
-
75
- def version (self ):
76
- """
77
- Run SQL Scripts, Initialize DB with respect to versioning
78
- and Upgrade DB schema for all versions
79
- """
80
- def wrapper (* args ):
81
- """
82
- Run SQL and mocking DB for versions
83
- """
84
- # import pdb; pdb.set_trace()
85
- self = args [0 ]
86
- func_name = func .__name__
87
- version = func_name .rsplit ('_' , 1 )[- 1 ]
88
-
89
- # Update versions DB mocking
90
- self .initialise_database ("init_version_{}" .format (version ))
91
-
92
- if int (version ) == 9 :
93
- sqlThread ().create_function ()
94
-
95
- # Test versions
96
- upgrade_db = UpgradeDB ()
97
- upgrade_db ._upgrade_one_level_sql_statement (int (version )) # pylint: disable= W0212, protected-access
98
- return func (* args ) # <-- use (self, ...)
99
- func = self
100
- return wrapper
101
-
102
32
def test_create_function (self ):
103
- """
104
- Test create_function and asserting the result
105
- """
106
-
107
- # call create function
33
+ """Check the result of enaddr function"""
108
34
encoded_str = encodeAddress (4 , 1 , "21122112211221122112" )
109
-
110
- # Initialise Database
111
- self .initialise_database ("create_function" )
112
-
113
- sqlExecute ('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''' )
114
- # call function in query
115
-
116
- sqlExecute ('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash));''' )
117
-
118
- # Assertion
119
- query = sqlQuery ('''select * from testhash;''' )
120
- self .assertEqual (query [0 ][- 1 ], encoded_str , "test case fail for create_function" )
121
- sqlExecute ('''DROP TABLE testhash''' )
122
-
123
- @version
124
- def test_sql_thread_version_2 (self ):
125
- """
126
- Test with version 2
127
- """
128
-
129
- # Assertion
130
- res = sqlQuery (''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='inventory_backup' ''' )
131
- self .assertNotEqual (res [0 ][0 ], 1 , "Table inventory_backup not deleted in versioning 2" )
132
-
133
- @version
134
- def test_sql_thread_version_3 (self ):
135
- """
136
- Test with version 1
137
- Version 1 and 3 are same so will skip 3
138
- """
139
-
140
- # Assertion after versioning
141
- res = sqlQuery ('''PRAGMA table_info('inventory');''' )
142
- result = list (filter_table_column (res , "tag" ))
143
- res = [tup for tup in res if any (i in tup for i in ["tag" ])]
144
- self .assertEqual (result , ['tag' ], "Data not migrated for version 1" )
145
- self .assertEqual (res , [(5 , 'tag' , 'blob' , 0 , "''" , 0 )], "Data not migrated for version 1" )
146
-
147
- @version
148
- def test_sql_thread_version_4 (self ):
149
- """
150
- Test with version 4
151
- """
152
-
153
- # Assertion
154
- res = sqlQuery ('''select * from inventory where objecttype = 'pubkey';''' )
155
- self .assertNotEqual (len (res ), 1 , "Table inventory not deleted in versioning 4" )
156
-
157
- @version
158
- def test_sql_thread_version_5 (self ):
159
- """
160
- Test with version 5
161
- """
162
-
163
- # Assertion
164
- res = sqlQuery (''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='knownnodes' ''' )
165
-
166
- self .assertNotEqual (res [0 ][0 ], 1 , "Table knownnodes not deleted in versioning 5" )
167
- res = sqlQuery (''' SELECT count(name) FROM sqlite_master
168
- WHERE type='table' AND name='objectprocessorqueue'; ''' )
169
- self .assertNotEqual (len (res ), 0 , "Table objectprocessorqueue not created in versioning 5" )
170
-
171
- @version
172
- def test_sql_thread_version_6 (self ):
173
- """
174
- Test with version 6
175
- """
176
-
177
- # Assertion
178
-
179
- inventory = sqlQuery ('''PRAGMA table_info('inventory');''' )
180
- inventory = list (filter_table_column (inventory , "expirestime" ))
181
- self .assertEqual (inventory , ['expirestime' ], "Data not migrated for version 6" )
182
-
183
- objectprocessorqueue = sqlQuery ('''PRAGMA table_info('inventory');''' )
184
- objectprocessorqueue = list (filter_table_column (objectprocessorqueue , "objecttype" ))
185
- self .assertEqual (objectprocessorqueue , ['objecttype' ], "Data not migrated for version 6" )
186
-
187
- @version
188
- def test_sql_thread_version_7 (self ):
189
- """
190
- Test with version 7
191
- """
192
-
193
- # Assertion
194
- pubkeys = sqlQuery ('''SELECT * FROM pubkeys ''' )
195
- self .assertEqual (pubkeys , [], "Data not migrated for version 7" )
196
-
197
- inventory = sqlQuery ('''SELECT * FROM inventory ''' )
198
- self .assertEqual (inventory , [], "Data not migrated for version 7" )
199
-
200
- sent = sqlQuery ('''SELECT status FROM sent ''' )
201
- self .assertEqual (sent , [('msgqueued' ,), ('msgqueued' ,)], "Data not migrated for version 7" )
202
-
203
- @version
204
- def test_sql_thread_version_8 (self ):
205
- """
206
- Test with version 8
207
- """
208
-
209
- # Assertion
210
- res = sqlQuery ('''PRAGMA table_info('inbox');''' )
211
- result = list (filter_table_column (res , "sighash" ))
212
- self .assertEqual (result , ['sighash' ], "Data not migrated for version 8" )
213
-
214
- @version
215
- def test_sql_thread_version_9 (self ):
216
- """
217
- Test with version 9
218
- """
219
-
220
- # Assertion
221
- res = sqlQuery (''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''' )
222
- self .assertNotEqual (res [0 ][0 ], 1 , "Table pubkeys_backup not deleted" )
223
-
224
- res = sqlQuery ('''PRAGMA table_info('pubkeys');''' )
225
- # res = res.fetchall()
226
- result = list (filter_table_column (res , "address" ))
227
- self .assertEqual (result , ['address' ], "Data not migrated for version 9" )
228
-
229
- @version
230
- def test_sql_thread_version_10 (self ):
231
- """
232
- Test with version 10
233
- """
234
-
235
- # Assertion
236
- res = sqlQuery (''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='old_addressbook' ''' )
237
- self .assertNotEqual (res [0 ][0 ], 1 , "Table old_addressbook not deleted" )
238
- self .assertEqual (len (res ), 1 , "Table old_addressbook not deleted" )
239
-
240
- res = sqlQuery ('''PRAGMA table_info('addressbook');''' )
241
- result = list (filter_table_column (res , "address" ))
242
- self .assertEqual (result , ['address' ], "Data not migrated for version 10" )
35
+ query = sqlQuery ('SELECT enaddr(4, 1, "21122112211221122112")' )
36
+ self .assertEqual (
37
+ query [0 ][- 1 ], encoded_str , "test case fail for create_function" )
0 commit comments