cursors.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. # -*- coding: utf-8 -*-
  2. from __future__ import print_function, absolute_import
  3. from functools import partial
  4. import re
  5. import warnings
  6. from ._compat import range_type, text_type, PY2
  7. from . import err
  8. #: Regular expression for :meth:`Cursor.executemany`.
  9. #: executemany only suports simple bulk insert.
  10. #: You can use it to load large dataset.
  11. RE_INSERT_VALUES = re.compile(
  12. r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" +
  13. r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
  14. r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
  15. re.IGNORECASE | re.DOTALL)
  16. class Cursor(object):
  17. """
  18. This is the object you use to interact with the database.
  19. Do not create an instance of a Cursor yourself. Call
  20. connections.Connection.cursor().
  21. See `Cursor <https://www.python.org/dev/peps/pep-0249/#cursor-objects>`_ in
  22. the specification.
  23. """
  24. #: Max statement size which :meth:`executemany` generates.
  25. #:
  26. #: Max size of allowed statement is max_allowed_packet - packet_header_size.
  27. #: Default value of max_allowed_packet is 1048576.
  28. max_stmt_length = 1024000
  29. _defer_warnings = False
  30. def __init__(self, connection):
  31. self.connection = connection
  32. self.description = None
  33. self.rownumber = 0
  34. self.rowcount = -1
  35. self.arraysize = 1
  36. self._executed = None
  37. self._result = None
  38. self._rows = None
  39. self._warnings_handled = False
  40. def close(self):
  41. """
  42. Closing a cursor just exhausts all remaining data.
  43. """
  44. conn = self.connection
  45. if conn is None:
  46. return
  47. try:
  48. while self.nextset():
  49. pass
  50. finally:
  51. self.connection = None
  52. def __enter__(self):
  53. return self
  54. def __exit__(self, *exc_info):
  55. del exc_info
  56. self.close()
  57. def _get_db(self):
  58. if not self.connection:
  59. raise err.ProgrammingError("Cursor closed")
  60. return self.connection
  61. def _check_executed(self):
  62. if not self._executed:
  63. raise err.ProgrammingError("execute() first")
  64. def _conv_row(self, row):
  65. return row
  66. def setinputsizes(self, *args):
  67. """Does nothing, required by DB API."""
  68. def setoutputsizes(self, *args):
  69. """Does nothing, required by DB API."""
  70. def _nextset(self, unbuffered=False):
  71. """Get the next query set"""
  72. conn = self._get_db()
  73. current_result = self._result
  74. # for unbuffered queries warnings are only available once whole result has been read
  75. if unbuffered:
  76. self._show_warnings()
  77. if current_result is None or current_result is not conn._result:
  78. return None
  79. if not current_result.has_next:
  80. return None
  81. self._result = None
  82. self._clear_result()
  83. conn.next_result(unbuffered=unbuffered)
  84. self._do_get_result()
  85. return True
  86. def nextset(self):
  87. return self._nextset(False)
  88. def _ensure_bytes(self, x, encoding=None):
  89. if isinstance(x, text_type):
  90. x = x.encode(encoding)
  91. elif isinstance(x, (tuple, list)):
  92. x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
  93. return x
  94. def _escape_args(self, args, conn):
  95. ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
  96. if isinstance(args, (tuple, list)):
  97. if PY2:
  98. args = tuple(map(ensure_bytes, args))
  99. return tuple(conn.literal(arg) for arg in args)
  100. elif isinstance(args, dict):
  101. if PY2:
  102. args = {ensure_bytes(key): ensure_bytes(val) for
  103. (key, val) in args.items()}
  104. return {key: conn.literal(val) for (key, val) in args.items()}
  105. else:
  106. # If it's not a dictionary let's try escaping it anyways.
  107. # Worst case it will throw a Value error
  108. if PY2:
  109. args = ensure_bytes(args)
  110. return conn.escape(args)
  111. def mogrify(self, query, args=None):
  112. """
  113. Returns the exact string that is sent to the database by calling the
  114. execute() method.
  115. This method follows the extension to the DB API 2.0 followed by Psycopg.
  116. """
  117. conn = self._get_db()
  118. if PY2: # Use bytes on Python 2 always
  119. query = self._ensure_bytes(query, encoding=conn.encoding)
  120. if args is not None:
  121. query = query % self._escape_args(args, conn)
  122. return query
  123. def execute(self, query, args=None):
  124. """Execute a query
  125. :param str query: Query to execute.
  126. :param args: parameters used with query. (optional)
  127. :type args: tuple, list or dict
  128. :return: Number of affected rows
  129. :rtype: int
  130. If args is a list or tuple, %s can be used as a placeholder in the query.
  131. If args is a dict, %(name)s can be used as a placeholder in the query.
  132. """
  133. while self.nextset():
  134. pass
  135. query = self.mogrify(query, args)
  136. result = self._query(query)
  137. self._executed = query
  138. return result
  139. def executemany(self, query, args):
  140. # type: (str, list) -> int
  141. """Run several data against one query
  142. :param query: query to execute on server
  143. :param args: Sequence of sequences or mappings. It is used as parameter.
  144. :return: Number of rows affected, if any.
  145. This method improves performance on multiple-row INSERT and
  146. REPLACE. Otherwise it is equivalent to looping over args with
  147. execute().
  148. """
  149. if not args:
  150. return
  151. m = RE_INSERT_VALUES.match(query)
  152. if m:
  153. q_prefix = m.group(1) % ()
  154. q_values = m.group(2).rstrip()
  155. q_postfix = m.group(3) or ''
  156. assert q_values[0] == '(' and q_values[-1] == ')'
  157. return self._do_execute_many(q_prefix, q_values, q_postfix, args,
  158. self.max_stmt_length,
  159. self._get_db().encoding)
  160. self.rowcount = sum(self.execute(query, arg) for arg in args)
  161. return self.rowcount
  162. def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
  163. conn = self._get_db()
  164. escape = self._escape_args
  165. if isinstance(prefix, text_type):
  166. prefix = prefix.encode(encoding)
  167. if PY2 and isinstance(values, text_type):
  168. values = values.encode(encoding)
  169. if isinstance(postfix, text_type):
  170. postfix = postfix.encode(encoding)
  171. sql = bytearray(prefix)
  172. args = iter(args)
  173. v = values % escape(next(args), conn)
  174. if isinstance(v, text_type):
  175. if PY2:
  176. v = v.encode(encoding)
  177. else:
  178. v = v.encode(encoding, 'surrogateescape')
  179. sql += v
  180. rows = 0
  181. for arg in args:
  182. v = values % escape(arg, conn)
  183. if isinstance(v, text_type):
  184. if PY2:
  185. v = v.encode(encoding)
  186. else:
  187. v = v.encode(encoding, 'surrogateescape')
  188. if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
  189. rows += self.execute(sql + postfix)
  190. sql = bytearray(prefix)
  191. else:
  192. sql += b','
  193. sql += v
  194. rows += self.execute(sql + postfix)
  195. self.rowcount = rows
  196. return rows
  197. def callproc(self, procname, args=()):
  198. """Execute stored procedure procname with args
  199. procname -- string, name of procedure to execute on server
  200. args -- Sequence of parameters to use with procedure
  201. Returns the original args.
  202. Compatibility warning: PEP-249 specifies that any modified
  203. parameters must be returned. This is currently impossible
  204. as they are only available by storing them in a server
  205. variable and then retrieved by a query. Since stored
  206. procedures return zero or more result sets, there is no
  207. reliable way to get at OUT or INOUT parameters via callproc.
  208. The server variables are named @_procname_n, where procname
  209. is the parameter above and n is the position of the parameter
  210. (from zero). Once all result sets generated by the procedure
  211. have been fetched, you can issue a SELECT @_procname_0, ...
  212. query using .execute() to get any OUT or INOUT values.
  213. Compatibility warning: The act of calling a stored procedure
  214. itself creates an empty result set. This appears after any
  215. result sets generated by the procedure. This is non-standard
  216. behavior with respect to the DB-API. Be sure to use nextset()
  217. to advance through all result sets; otherwise you may get
  218. disconnected.
  219. """
  220. conn = self._get_db()
  221. if args:
  222. fmt = '@_{0}_%d=%s'.format(procname)
  223. self._query('SET %s' % ','.join(fmt % (index, conn.escape(arg))
  224. for index, arg in enumerate(args)))
  225. self.nextset()
  226. q = "CALL %s(%s)" % (procname,
  227. ','.join(['@_%s_%d' % (procname, i)
  228. for i in range_type(len(args))]))
  229. self._query(q)
  230. self._executed = q
  231. return args
  232. def fetchone(self):
  233. """Fetch the next row"""
  234. self._check_executed()
  235. if self._rows is None or self.rownumber >= len(self._rows):
  236. return None
  237. result = self._rows[self.rownumber]
  238. self.rownumber += 1
  239. return result
  240. def fetchmany(self, size=None):
  241. """Fetch several rows"""
  242. self._check_executed()
  243. if self._rows is None:
  244. return ()
  245. end = self.rownumber + (size or self.arraysize)
  246. result = self._rows[self.rownumber:end]
  247. self.rownumber = min(end, len(self._rows))
  248. return result
  249. def fetchall(self):
  250. """Fetch all the rows"""
  251. self._check_executed()
  252. if self._rows is None:
  253. return ()
  254. if self.rownumber:
  255. result = self._rows[self.rownumber:]
  256. else:
  257. result = self._rows
  258. self.rownumber = len(self._rows)
  259. return result
  260. def scroll(self, value, mode='relative'):
  261. self._check_executed()
  262. if mode == 'relative':
  263. r = self.rownumber + value
  264. elif mode == 'absolute':
  265. r = value
  266. else:
  267. raise err.ProgrammingError("unknown scroll mode %s" % mode)
  268. if not (0 <= r < len(self._rows)):
  269. raise IndexError("out of range")
  270. self.rownumber = r
  271. def _query(self, q):
  272. conn = self._get_db()
  273. self._last_executed = q
  274. self._clear_result()
  275. conn.query(q)
  276. self._do_get_result()
  277. return self.rowcount
  278. def _clear_result(self):
  279. self.rownumber = 0
  280. self._result = None
  281. self.rowcount = 0
  282. self.description = None
  283. self.lastrowid = None
  284. self._rows = None
  285. def _do_get_result(self):
  286. conn = self._get_db()
  287. self._result = result = conn._result
  288. self.rowcount = result.affected_rows
  289. self.description = result.description
  290. self.lastrowid = result.insert_id
  291. self._rows = result.rows
  292. self._warnings_handled = False
  293. if not self._defer_warnings:
  294. self._show_warnings()
  295. def _show_warnings(self):
  296. if self._warnings_handled:
  297. return
  298. self._warnings_handled = True
  299. if self._result and (self._result.has_next or not self._result.warning_count):
  300. return
  301. ws = self._get_db().show_warnings()
  302. if ws is None:
  303. return
  304. for w in ws:
  305. msg = w[-1]
  306. if PY2:
  307. if isinstance(msg, unicode):
  308. msg = msg.encode('utf-8', 'replace')
  309. warnings.warn(err.Warning(*w[1:3]), stacklevel=4)
  310. def __iter__(self):
  311. return iter(self.fetchone, None)
  312. Warning = err.Warning
  313. Error = err.Error
  314. InterfaceError = err.InterfaceError
  315. DatabaseError = err.DatabaseError
  316. DataError = err.DataError
  317. OperationalError = err.OperationalError
  318. IntegrityError = err.IntegrityError
  319. InternalError = err.InternalError
  320. ProgrammingError = err.ProgrammingError
  321. NotSupportedError = err.NotSupportedError
  322. class DictCursorMixin(object):
  323. # You can override this to use OrderedDict or other dict-like types.
  324. dict_type = dict
  325. def _do_get_result(self):
  326. super(DictCursorMixin, self)._do_get_result()
  327. fields = []
  328. if self.description:
  329. for f in self._result.fields:
  330. name = f.name
  331. if name in fields:
  332. name = f.table_name + '.' + name
  333. fields.append(name)
  334. self._fields = fields
  335. if fields and self._rows:
  336. self._rows = [self._conv_row(r) for r in self._rows]
  337. def _conv_row(self, row):
  338. if row is None:
  339. return None
  340. return self.dict_type(zip(self._fields, row))
  341. class DictCursor(DictCursorMixin, Cursor):
  342. """A cursor which returns results as a dictionary"""
  343. class SSCursor(Cursor):
  344. """
  345. Unbuffered Cursor, mainly useful for queries that return a lot of data,
  346. or for connections to remote servers over a slow network.
  347. Instead of copying every row of data into a buffer, this will fetch
  348. rows as needed. The upside of this is the client uses much less memory,
  349. and rows are returned much faster when traveling over a slow network
  350. or if the result set is very big.
  351. There are limitations, though. The MySQL protocol doesn't support
  352. returning the total number of rows, so the only way to tell how many rows
  353. there are is to iterate over every row returned. Also, it currently isn't
  354. possible to scroll backwards, as only the current row is held in memory.
  355. """
  356. _defer_warnings = True
  357. def _conv_row(self, row):
  358. return row
  359. def close(self):
  360. conn = self.connection
  361. if conn is None:
  362. return
  363. if self._result is not None and self._result is conn._result:
  364. self._result._finish_unbuffered_query()
  365. try:
  366. while self.nextset():
  367. pass
  368. finally:
  369. self.connection = None
  370. __del__ = close
  371. def _query(self, q):
  372. conn = self._get_db()
  373. self._last_executed = q
  374. self._clear_result()
  375. conn.query(q, unbuffered=True)
  376. self._do_get_result()
  377. return self.rowcount
  378. def nextset(self):
  379. return self._nextset(unbuffered=True)
  380. def read_next(self):
  381. """Read next row"""
  382. return self._conv_row(self._result._read_rowdata_packet_unbuffered())
  383. def fetchone(self):
  384. """Fetch next row"""
  385. self._check_executed()
  386. row = self.read_next()
  387. if row is None:
  388. self._show_warnings()
  389. return None
  390. self.rownumber += 1
  391. return row
  392. def fetchall(self):
  393. """
  394. Fetch all, as per MySQLdb. Pretty useless for large queries, as
  395. it is buffered. See fetchall_unbuffered(), if you want an unbuffered
  396. generator version of this method.
  397. """
  398. return list(self.fetchall_unbuffered())
  399. def fetchall_unbuffered(self):
  400. """
  401. Fetch all, implemented as a generator, which isn't to standard,
  402. however, it doesn't make sense to return everything in a list, as that
  403. would use ridiculous memory for large result sets.
  404. """
  405. return iter(self.fetchone, None)
  406. def __iter__(self):
  407. return self.fetchall_unbuffered()
  408. def fetchmany(self, size=None):
  409. """Fetch many"""
  410. self._check_executed()
  411. if size is None:
  412. size = self.arraysize
  413. rows = []
  414. for i in range_type(size):
  415. row = self.read_next()
  416. if row is None:
  417. self._show_warnings()
  418. break
  419. rows.append(row)
  420. self.rownumber += 1
  421. return rows
  422. def scroll(self, value, mode='relative'):
  423. self._check_executed()
  424. if mode == 'relative':
  425. if value < 0:
  426. raise err.NotSupportedError(
  427. "Backwards scrolling not supported by this cursor")
  428. for _ in range_type(value):
  429. self.read_next()
  430. self.rownumber += value
  431. elif mode == 'absolute':
  432. if value < self.rownumber:
  433. raise err.NotSupportedError(
  434. "Backwards scrolling not supported by this cursor")
  435. end = value - self.rownumber
  436. for _ in range_type(end):
  437. self.read_next()
  438. self.rownumber = value
  439. else:
  440. raise err.ProgrammingError("unknown scroll mode %s" % mode)
  441. class SSDictCursor(DictCursorMixin, SSCursor):
  442. """An unbuffered cursor, which returns results as a dictionary"""