|
27 | 27 |
|
28 | 28 | def func_returntext(): |
29 | 29 | return "foo" |
| 30 | +def func_returntextwithnull(): |
| 31 | + return "1\x002" |
30 | 32 | def func_returnunicode(): |
31 | 33 | return "bar" |
32 | 34 | def func_returnint(): |
@@ -137,11 +139,21 @@ def step(self, val): |
137 | 139 | def finalize(self): |
138 | 140 | return self.val |
139 | 141 |
|
| 142 | +class AggrText: |
| 143 | + def __init__(self): |
| 144 | + self.txt = "" |
| 145 | + def step(self, txt): |
| 146 | + self.txt = self.txt + txt |
| 147 | + def finalize(self): |
| 148 | + return self.txt |
| 149 | + |
| 150 | + |
140 | 151 | class FunctionTests(unittest.TestCase): |
141 | 152 | def setUp(self): |
142 | 153 | self.con = sqlite.connect(":memory:") |
143 | 154 |
|
144 | 155 | self.con.create_function("returntext", 0, func_returntext) |
| 156 | + self.con.create_function("returntextwithnull", 0, func_returntextwithnull) |
145 | 157 | self.con.create_function("returnunicode", 0, func_returnunicode) |
146 | 158 | self.con.create_function("returnint", 0, func_returnint) |
147 | 159 | self.con.create_function("returnfloat", 0, func_returnfloat) |
@@ -185,6 +197,12 @@ def CheckFuncReturnText(self): |
185 | 197 | self.assertEqual(type(val), str) |
186 | 198 | self.assertEqual(val, "foo") |
187 | 199 |
|
| 200 | + def CheckFuncReturnTextWithNullChar(self): |
| 201 | + cur = self.con.cursor() |
| 202 | + res = cur.execute("select returntextwithnull()").fetchone()[0] |
| 203 | + self.assertEqual(type(res), str) |
| 204 | + self.assertEqual(res, "1\x002") |
| 205 | + |
188 | 206 | def CheckFuncReturnUnicode(self): |
189 | 207 | cur = self.con.cursor() |
190 | 208 | cur.execute("select returnunicode()") |
@@ -343,6 +361,7 @@ def setUp(self): |
343 | 361 | self.con.create_aggregate("checkType", 2, AggrCheckType) |
344 | 362 | self.con.create_aggregate("checkTypes", -1, AggrCheckTypes) |
345 | 363 | self.con.create_aggregate("mysum", 1, AggrSum) |
| 364 | + self.con.create_aggregate("aggtxt", 1, AggrText) |
346 | 365 |
|
347 | 366 | def tearDown(self): |
348 | 367 | #self.cur.close() |
@@ -431,6 +450,15 @@ def CheckAggrCheckAggrSum(self): |
431 | 450 | val = cur.fetchone()[0] |
432 | 451 | self.assertEqual(val, 60) |
433 | 452 |
|
| 453 | + def CheckAggrText(self): |
| 454 | + cur = self.con.cursor() |
| 455 | + for txt in ["foo", "1\x002"]: |
| 456 | + with self.subTest(txt=txt): |
| 457 | + cur.execute("select aggtxt(?) from test", (txt,)) |
| 458 | + val = cur.fetchone()[0] |
| 459 | + self.assertEqual(val, txt) |
| 460 | + |
| 461 | + |
434 | 462 | class AuthorizerTests(unittest.TestCase): |
435 | 463 | @staticmethod |
436 | 464 | def authorizer_cb(action, arg1, arg2, dbname, source): |
|
0 commit comments