protocol.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # Python implementation of low level MySQL client-server protocol
  2. # http://dev.mysql.com/doc/internals/en/client-server-protocol.html
  3. from __future__ import print_function
  4. from .charset import MBLENGTH
  5. from ._compat import PY2, range_type
  6. from .constants import FIELD_TYPE, SERVER_STATUS
  7. from . import err
  8. from .util import byte2int
  9. import struct
  10. import sys
# Module-wide debug switch: when True, several code paths in this module print
# diagnostics (and MysqlPacket.read dumps the packet on a short read).
DEBUG = False

# Header byte values of a MySQL length-encoded integer, as interpreted by
# MysqlPacket.read_length_encoded_integer. NULL_COLUMN and UNSIGNED_CHAR_COLUMN
# share the value 251: it is the NULL marker and also the exclusive upper
# bound for values stored directly in the single header byte.
NULL_COLUMN = 251            # 0xFB: NULL, no further bytes
UNSIGNED_CHAR_COLUMN = 251   # header values below this are the integer itself
UNSIGNED_SHORT_COLUMN = 252  # 0xFC: followed by a 2-byte little-endian integer
UNSIGNED_INT24_COLUMN = 253  # 0xFD: followed by a 3-byte little-endian integer
UNSIGNED_INT64_COLUMN = 254  # 0xFE: followed by an 8-byte little-endian integer
  17. def dump_packet(data): # pragma: no cover
  18. def printable(data):
  19. if 32 <= byte2int(data) < 127:
  20. if isinstance(data, int):
  21. return chr(data)
  22. return data
  23. return '.'
  24. try:
  25. print("packet length:", len(data))
  26. for i in range(1, 7):
  27. f = sys._getframe(i)
  28. print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
  29. print("-" * 66)
  30. except ValueError:
  31. pass
  32. dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)]
  33. for d in dump_data:
  34. print(' '.join("{:02X}".format(byte2int(x)) for x in d) +
  35. ' ' * (16 - len(d)) + ' ' * 2 +
  36. ''.join(printable(x) for x in d))
  37. print("-" * 66)
  38. print()
  39. class MysqlPacket(object):
  40. """Representation of a MySQL response packet.
  41. Provides an interface for reading/parsing the packet results.
  42. """
  43. __slots__ = ('_position', '_data')
  44. def __init__(self, data, encoding):
  45. self._position = 0
  46. self._data = data
  47. def get_all_data(self):
  48. return self._data
  49. def read(self, size):
  50. """Read the first 'size' bytes in packet and advance cursor past them."""
  51. result = self._data[self._position:(self._position+size)]
  52. if len(result) != size:
  53. error = ('Result length not requested length:\n'
  54. 'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
  55. % (size, len(result), self._position, len(self._data)))
  56. if DEBUG:
  57. print(error)
  58. self.dump()
  59. raise AssertionError(error)
  60. self._position += size
  61. return result
  62. def read_all(self):
  63. """Read all remaining data in the packet.
  64. (Subsequent read() will return errors.)
  65. """
  66. result = self._data[self._position:]
  67. self._position = None # ensure no subsequent read()
  68. return result
  69. def advance(self, length):
  70. """Advance the cursor in data buffer 'length' bytes."""
  71. new_position = self._position + length
  72. if new_position < 0 or new_position > len(self._data):
  73. raise Exception('Invalid advance amount (%s) for cursor. '
  74. 'Position=%s' % (length, new_position))
  75. self._position = new_position
  76. def rewind(self, position=0):
  77. """Set the position of the data buffer cursor to 'position'."""
  78. if position < 0 or position > len(self._data):
  79. raise Exception("Invalid position to rewind cursor to: %s." % position)
  80. self._position = position
  81. def get_bytes(self, position, length=1):
  82. """Get 'length' bytes starting at 'position'.
  83. Position is start of payload (first four packet header bytes are not
  84. included) starting at index '0'.
  85. No error checking is done. If requesting outside end of buffer
  86. an empty string (or string shorter than 'length') may be returned!
  87. """
  88. return self._data[position:(position+length)]
  89. if PY2:
  90. def read_uint8(self):
  91. result = ord(self._data[self._position])
  92. self._position += 1
  93. return result
  94. else:
  95. def read_uint8(self):
  96. result = self._data[self._position]
  97. self._position += 1
  98. return result
  99. def read_uint16(self):
  100. result = struct.unpack_from('<H', self._data, self._position)[0]
  101. self._position += 2
  102. return result
  103. def read_uint24(self):
  104. low, high = struct.unpack_from('<HB', self._data, self._position)
  105. self._position += 3
  106. return low + (high << 16)
  107. def read_uint32(self):
  108. result = struct.unpack_from('<I', self._data, self._position)[0]
  109. self._position += 4
  110. return result
  111. def read_uint64(self):
  112. result = struct.unpack_from('<Q', self._data, self._position)[0]
  113. self._position += 8
  114. return result
  115. def read_string(self):
  116. end_pos = self._data.find(b'\0', self._position)
  117. if end_pos < 0:
  118. return None
  119. result = self._data[self._position:end_pos]
  120. self._position = end_pos + 1
  121. return result
  122. def read_length_encoded_integer(self):
  123. """Read a 'Length Coded Binary' number from the data buffer.
  124. Length coded numbers can be anywhere from 1 to 9 bytes depending
  125. on the value of the first byte.
  126. """
  127. c = self.read_uint8()
  128. if c == NULL_COLUMN:
  129. return None
  130. if c < UNSIGNED_CHAR_COLUMN:
  131. return c
  132. elif c == UNSIGNED_SHORT_COLUMN:
  133. return self.read_uint16()
  134. elif c == UNSIGNED_INT24_COLUMN:
  135. return self.read_uint24()
  136. elif c == UNSIGNED_INT64_COLUMN:
  137. return self.read_uint64()
  138. def read_length_coded_string(self):
  139. """Read a 'Length Coded String' from the data buffer.
  140. A 'Length Coded String' consists first of a length coded
  141. (unsigned, positive) integer represented in 1-9 bytes followed by
  142. that many bytes of binary data. (For example "cat" would be "3cat".)
  143. """
  144. length = self.read_length_encoded_integer()
  145. if length is None:
  146. return None
  147. return self.read(length)
  148. def read_struct(self, fmt):
  149. s = struct.Struct(fmt)
  150. result = s.unpack_from(self._data, self._position)
  151. self._position += s.size
  152. return result
  153. def is_ok_packet(self):
  154. # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
  155. return self._data[0:1] == b'\0' and len(self._data) >= 7
  156. def is_eof_packet(self):
  157. # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
  158. # Caution: \xFE may be LengthEncodedInteger.
  159. # If \xFE is LengthEncodedInteger header, 8bytes followed.
  160. return self._data[0:1] == b'\xfe' and len(self._data) < 9
  161. def is_auth_switch_request(self):
  162. # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
  163. return self._data[0:1] == b'\xfe'
  164. def is_extra_auth_data(self):
  165. # https://dev.mysql.com/doc/internals/en/successful-authentication.html
  166. return self._data[0:1] == b'\x01'
  167. def is_resultset_packet(self):
  168. field_count = ord(self._data[0:1])
  169. return 1 <= field_count <= 250
  170. def is_load_local_packet(self):
  171. return self._data[0:1] == b'\xfb'
  172. def is_error_packet(self):
  173. return self._data[0:1] == b'\xff'
  174. def check_error(self):
  175. if self.is_error_packet():
  176. self.rewind()
  177. self.advance(1) # field_count == error (we already know that)
  178. errno = self.read_uint16()
  179. if DEBUG: print("errno =", errno)
  180. err.raise_mysql_exception(self._data)
  181. def dump(self):
  182. dump_packet(self._data)
  183. class FieldDescriptorPacket(MysqlPacket):
  184. """A MysqlPacket that represents a specific column's metadata in the result.
  185. Parsing is automatically done and the results are exported via public
  186. attributes on the class such as: db, table_name, name, length, type_code.
  187. """
  188. def __init__(self, data, encoding):
  189. MysqlPacket.__init__(self, data, encoding)
  190. self._parse_field_descriptor(encoding)
  191. def _parse_field_descriptor(self, encoding):
  192. """Parse the 'Field Descriptor' (Metadata) packet.
  193. This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
  194. """
  195. self.catalog = self.read_length_coded_string()
  196. self.db = self.read_length_coded_string()
  197. self.table_name = self.read_length_coded_string().decode(encoding)
  198. self.org_table = self.read_length_coded_string().decode(encoding)
  199. self.name = self.read_length_coded_string().decode(encoding)
  200. self.org_name = self.read_length_coded_string().decode(encoding)
  201. self.charsetnr, self.length, self.type_code, self.flags, self.scale = (
  202. self.read_struct('<xHIBHBxx'))
  203. # 'default' is a length coded binary and is still in the buffer?
  204. # not used for normal result sets...
  205. def description(self):
  206. """Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
  207. return (
  208. self.name,
  209. self.type_code,
  210. None, # TODO: display_length; should this be self.length?
  211. self.get_column_length(), # 'internal_size'
  212. self.get_column_length(), # 'precision' # TODO: why!?!?
  213. self.scale,
  214. self.flags % 2 == 0)
  215. def get_column_length(self):
  216. if self.type_code == FIELD_TYPE.VAR_STRING:
  217. mblen = MBLENGTH.get(self.charsetnr, 1)
  218. return self.length // mblen
  219. return self.length
  220. def __str__(self):
  221. return ('%s %r.%r.%r, type=%s, flags=%x'
  222. % (self.__class__, self.db, self.table_name, self.name,
  223. self.type_code, self.flags))
  224. class OKPacketWrapper(object):
  225. """
  226. OK Packet Wrapper. It uses an existing packet object, and wraps
  227. around it, exposing useful variables while still providing access
  228. to the original packet objects variables and methods.
  229. """
  230. def __init__(self, from_packet):
  231. if not from_packet.is_ok_packet():
  232. raise ValueError('Cannot create ' + str(self.__class__.__name__) +
  233. ' object from invalid packet type')
  234. self.packet = from_packet
  235. self.packet.advance(1)
  236. self.affected_rows = self.packet.read_length_encoded_integer()
  237. self.insert_id = self.packet.read_length_encoded_integer()
  238. self.server_status, self.warning_count = self.read_struct('<HH')
  239. self.message = self.packet.read_all()
  240. self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS
  241. def __getattr__(self, key):
  242. return getattr(self.packet, key)
  243. class EOFPacketWrapper(object):
  244. """
  245. EOF Packet Wrapper. It uses an existing packet object, and wraps
  246. around it, exposing useful variables while still providing access
  247. to the original packet objects variables and methods.
  248. """
  249. def __init__(self, from_packet):
  250. if not from_packet.is_eof_packet():
  251. raise ValueError(
  252. "Cannot create '{0}' object from invalid packet type".format(
  253. self.__class__))
  254. self.packet = from_packet
  255. self.warning_count, self.server_status = self.packet.read_struct('<xhh')
  256. if DEBUG: print("server_status=", self.server_status)
  257. self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS
  258. def __getattr__(self, key):
  259. return getattr(self.packet, key)
  260. class LoadLocalPacketWrapper(object):
  261. """
  262. Load Local Packet Wrapper. It uses an existing packet object, and wraps
  263. around it, exposing useful variables while still providing access
  264. to the original packet objects variables and methods.
  265. """
  266. def __init__(self, from_packet):
  267. if not from_packet.is_load_local_packet():
  268. raise ValueError(
  269. "Cannot create '{0}' object from invalid packet type".format(
  270. self.__class__))
  271. self.packet = from_packet
  272. self.filename = self.packet.get_all_data()[1:]
  273. if DEBUG: print("filename=", self.filename)
  274. def __getattr__(self, key):
  275. return getattr(self.packet, key)