|
15 | 15 | NotSupportedError, ProgrammingError)
|
16 | 16 |
|
17 | 17 |
|
18 |
| -PY2 = sys.version_info[0] == 2 |
19 |
| -if PY2: |
20 |
| - text_type = unicode |
21 |
| -else: |
22 |
| - text_type = str |
23 |
| - |
24 |
| - |
25 | 18 | #: Regular expression for :meth:`Cursor.executemany`.
|
26 | 19 | #: executemany only supports simple bulk insert.
|
27 | 20 | #: You can use it to load large dataset.
|
@@ -95,31 +88,28 @@ def __exit__(self, *exc_info):
|
95 | 88 | del exc_info
|
96 | 89 | self.close()
|
97 | 90 |
|
98 |
| - def _ensure_bytes(self, x, encoding=None): |
99 |
| - if isinstance(x, text_type): |
100 |
| - x = x.encode(encoding) |
101 |
| - elif isinstance(x, (tuple, list)): |
102 |
| - x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x) |
103 |
| - return x |
104 |
| - |
105 | 91 | def _escape_args(self, args, conn):
|
106 |
| - ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding) |
| 92 | + encoding = conn.encoding |
| 93 | + literal = conn.literal |
| 94 | + |
| 95 | + def ensure_bytes(x): |
| 96 | + if isinstance(x, unicode): |
| 97 | + return x.encode(encoding) |
| 98 | + elif isinstance(x, tuple): |
| 99 | + return tuple(map(ensure_bytes, x)) |
| 100 | + elif isinstance(x, list): |
| 101 | + return list(map(ensure_bytes, x)) |
| 102 | + return x |
107 | 103 |
|
108 | 104 | if isinstance(args, (tuple, list)):
|
109 |
| - if PY2: |
110 |
| - args = tuple(map(ensure_bytes, args)) |
111 |
| - return tuple(conn.literal(arg) for arg in args) |
| 105 | + return tuple(literal(ensure_bytes(arg)) for arg in args) |
112 | 106 | elif isinstance(args, dict):
|
113 |
| - if PY2: |
114 |
| - args = dict((ensure_bytes(key), ensure_bytes(val)) for |
115 |
| - (key, val) in args.items()) |
116 |
| - return dict((key, conn.literal(val)) for (key, val) in args.items()) |
| 107 | + return {ensure_bytes(key): literal(ensure_bytes(val)) |
| 108 | + for (key, val) in args.items()} |
117 | 109 | else:
|
118 | 110 | # If it's not a dictionary let's try escaping it anyways.
|
119 | 111 | # Worst case it will throw a Value error
|
120 |
| - if PY2: |
121 |
| - args = ensure_bytes(args) |
122 |
| - return conn.literal(args) |
| 112 | + return literal(ensure_bytes(args)) |
123 | 113 |
|
124 | 114 | def _check_executed(self):
|
125 | 115 | if not self._executed:
|
@@ -186,31 +176,20 @@ def execute(self, query, args=None):
|
186 | 176 | pass
|
187 | 177 | db = self._get_db()
|
188 | 178 |
|
189 |
| - # NOTE: |
190 |
| - # Python 2: query should be bytes when executing %. |
191 |
| - # All unicode in args should be encoded to bytes on Python 2. |
192 |
| - # Python 3: query should be str (unicode) when executing %. |
193 |
| - # All bytes in args should be decoded with ascii and surrogateescape on Python 3. |
194 |
| - # db.literal(obj) always returns str. |
195 |
| - |
196 |
| - if PY2 and isinstance(query, unicode): |
| 179 | + if isinstance(query, unicode): |
197 | 180 | query = query.encode(db.encoding)
|
198 | 181 |
|
199 | 182 | if args is not None:
|
200 | 183 | if isinstance(args, dict):
|
201 | 184 | args = dict((key, db.literal(item)) for key, item in args.items())
|
202 | 185 | else:
|
203 | 186 | args = tuple(map(db.literal, args))
|
204 |
| - if not PY2 and isinstance(query, (bytes, bytearray)): |
205 |
| - query = query.decode(db.encoding) |
206 | 187 | try:
|
207 | 188 | query = query % args
|
208 | 189 | except TypeError as m:
|
209 | 190 | raise ProgrammingError(str(m))
|
210 | 191 |
|
211 |
| - if isinstance(query, unicode): |
212 |
| - query = query.encode(db.encoding, 'surrogateescape') |
213 |
| - |
| 192 | + assert isinstance(query, (bytes, bytearray)) |
214 | 193 | res = self._query(query)
|
215 | 194 | return res
|
216 | 195 |
|
@@ -247,29 +226,19 @@ def executemany(self, query, args):
|
247 | 226 | def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
|
248 | 227 | conn = self._get_db()
|
249 | 228 | escape = self._escape_args
|
250 |
| - if isinstance(prefix, text_type): |
| 229 | + if isinstance(prefix, unicode): |
251 | 230 | prefix = prefix.encode(encoding)
|
252 |
| - if PY2 and isinstance(values, text_type): |
| 231 | + if isinstance(values, unicode): |
253 | 232 | values = values.encode(encoding)
|
254 |
| - if isinstance(postfix, text_type): |
| 233 | + if isinstance(postfix, unicode): |
255 | 234 | postfix = postfix.encode(encoding)
|
256 | 235 | sql = bytearray(prefix)
|
257 | 236 | args = iter(args)
|
258 | 237 | v = values % escape(next(args), conn)
|
259 |
| - if isinstance(v, text_type): |
260 |
| - if PY2: |
261 |
| - v = v.encode(encoding) |
262 |
| - else: |
263 |
| - v = v.encode(encoding, 'surrogateescape') |
264 | 238 | sql += v
|
265 | 239 | rows = 0
|
266 | 240 | for arg in args:
|
267 | 241 | v = values % escape(arg, conn)
|
268 |
| - if isinstance(v, text_type): |
269 |
| - if PY2: |
270 |
| - v = v.encode(encoding) |
271 |
| - else: |
272 |
| - v = v.encode(encoding, 'surrogateescape') |
273 | 242 | if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
|
274 | 243 | rows += self.execute(sql + postfix)
|
275 | 244 | sql = bytearray(prefix)
|
@@ -308,22 +277,19 @@ def callproc(self, procname, args=()):
|
308 | 277 | to advance through all result sets; otherwise you may get
|
309 | 278 | disconnected.
|
310 | 279 | """
|
311 |
| - |
312 | 280 | db = self._get_db()
|
| 281 | + if isinstance(procname, unicode): |
| 282 | + procname = procname.encode(db.encoding) |
313 | 283 | if args:
|
314 |
| - fmt = '@_{0}_%d=%s'.format(procname) |
315 |
| - q = 'SET %s' % ','.join(fmt % (index, db.literal(arg)) |
316 |
| - for index, arg in enumerate(args)) |
317 |
| - if isinstance(q, unicode): |
318 |
| - q = q.encode(db.encoding, 'surrogateescape') |
| 284 | + fmt = b'@_' + procname + b'_%d=%s' |
| 285 | + q = b'SET %s' % b','.join(fmt % (index, db.literal(arg)) |
| 286 | + for index, arg in enumerate(args)) |
319 | 287 | self._query(q)
|
320 | 288 | self.nextset()
|
321 | 289 |
|
322 |
| - q = "CALL %s(%s)" % (procname, |
323 |
| - ','.join(['@_%s_%d' % (procname, i) |
324 |
| - for i in range(len(args))])) |
325 |
| - if isinstance(q, unicode): |
326 |
| - q = q.encode(db.encoding, 'surrogateescape') |
| 290 | + q = b"CALL %s(%s)" % (procname, |
| 291 | + b','.join([b'@_%s_%d' % (procname, i) |
| 292 | + for i in range(len(args))])) |
327 | 293 | self._query(q)
|
328 | 294 | return args
|
329 | 295 |
|
|
0 commit comments