_auth.py 7.5 KB

"""
Implements auth methods
"""
  4. from ._compat import text_type, PY2
  5. from .constants import CLIENT
  6. from .err import OperationalError
  7. from .util import byte2int, int2byte
  8. try:
  9. from cryptography.hazmat.backends import default_backend
  10. from cryptography.hazmat.primitives import serialization, hashes
  11. from cryptography.hazmat.primitives.asymmetric import padding
  12. _have_cryptography = True
  13. except ImportError:
  14. _have_cryptography = False
  15. from functools import partial
  16. import hashlib
  17. import io
  18. import struct
  19. import warnings
  20. DEBUG = False
  21. SCRAMBLE_LENGTH = 20
  22. sha1_new = partial(hashlib.new, 'sha1')
  23. # mysql_native_password
  24. # https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
  25. def scramble_native_password(password, message):
  26. """Scramble used for mysql_native_password"""
  27. if not password:
  28. return b''
  29. stage1 = sha1_new(password).digest()
  30. stage2 = sha1_new(stage1).digest()
  31. s = sha1_new()
  32. s.update(message[:SCRAMBLE_LENGTH])
  33. s.update(stage2)
  34. result = s.digest()
  35. return _my_crypt(result, stage1)
  36. def _my_crypt(message1, message2):
  37. result = bytearray(message1)
  38. if PY2:
  39. message2 = bytearray(message2)
  40. for i in range(len(result)):
  41. result[i] ^= message2[i]
  42. return bytes(result)
  43. # old_passwords support ported from libmysql/password.c
  44. # https://dev.mysql.com/doc/internals/en/old-password-authentication.html
  45. SCRAMBLE_LENGTH_323 = 8
  46. class RandStruct_323(object):
  47. def __init__(self, seed1, seed2):
  48. self.max_value = 0x3FFFFFFF
  49. self.seed1 = seed1 % self.max_value
  50. self.seed2 = seed2 % self.max_value
  51. def my_rnd(self):
  52. self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value
  53. self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value
  54. return float(self.seed1) / float(self.max_value)
  55. def scramble_old_password(password, message):
  56. """Scramble for old_password"""
  57. warnings.warn("old password (for MySQL <4.1) is used. Upgrade your password with newer auth method.\n"
  58. "old password support will be removed in future PyMySQL version")
  59. hash_pass = _hash_password_323(password)
  60. hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323])
  61. hash_pass_n = struct.unpack(">LL", hash_pass)
  62. hash_message_n = struct.unpack(">LL", hash_message)
  63. rand_st = RandStruct_323(
  64. hash_pass_n[0] ^ hash_message_n[0], hash_pass_n[1] ^ hash_message_n[1]
  65. )
  66. outbuf = io.BytesIO()
  67. for _ in range(min(SCRAMBLE_LENGTH_323, len(message))):
  68. outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64))
  69. extra = int2byte(int(rand_st.my_rnd() * 31))
  70. out = outbuf.getvalue()
  71. outbuf = io.BytesIO()
  72. for c in out:
  73. outbuf.write(int2byte(byte2int(c) ^ byte2int(extra)))
  74. return outbuf.getvalue()
  75. def _hash_password_323(password):
  76. nr = 1345345333
  77. add = 7
  78. nr2 = 0x12345671
  79. # x in py3 is numbers, p27 is chars
  80. for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]:
  81. nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
  82. nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
  83. add = (add + c) & 0xFFFFFFFF
  84. r1 = nr & ((1 << 31) - 1) # kill sign bits
  85. r2 = nr2 & ((1 << 31) - 1)
  86. return struct.pack(">LL", r1, r2)
  87. # sha256_password
  88. def _roundtrip(conn, send_data):
  89. conn.write_packet(send_data)
  90. pkt = conn._read_packet()
  91. pkt.check_error()
  92. return pkt
  93. def _xor_password(password, salt):
  94. password_bytes = bytearray(password)
  95. salt = bytearray(salt) # for PY2 compat.
  96. salt_len = len(salt)
  97. for i in range(len(password_bytes)):
  98. password_bytes[i] ^= salt[i % salt_len]
  99. return bytes(password_bytes)
  100. def sha2_rsa_encrypt(password, salt, public_key):
  101. """Encrypt password with salt and public_key.
  102. Used for sha256_password and caching_sha2_password.
  103. """
  104. if not _have_cryptography:
  105. raise RuntimeError("cryptography is required for sha256_password or caching_sha2_password")
  106. message = _xor_password(password + b'\0', salt)
  107. rsa_key = serialization.load_pem_public_key(public_key, default_backend())
  108. return rsa_key.encrypt(
  109. message,
  110. padding.OAEP(
  111. mgf=padding.MGF1(algorithm=hashes.SHA1()),
  112. algorithm=hashes.SHA1(),
  113. label=None,
  114. ),
  115. )
  116. def sha256_password_auth(conn, pkt):
  117. if conn._secure:
  118. if DEBUG:
  119. print("sha256: Sending plain password")
  120. data = conn.password + b'\0'
  121. return _roundtrip(conn, data)
  122. if pkt.is_auth_switch_request():
  123. conn.salt = pkt.read_all()
  124. if not conn.server_public_key and conn.password:
  125. # Request server public key
  126. if DEBUG:
  127. print("sha256: Requesting server public key")
  128. pkt = _roundtrip(conn, b'\1')
  129. if pkt.is_extra_auth_data():
  130. conn.server_public_key = pkt._data[1:]
  131. if DEBUG:
  132. print("Received public key:\n", conn.server_public_key.decode('ascii'))
  133. if conn.password:
  134. if not conn.server_public_key:
  135. raise OperationalError("Couldn't receive server's public key")
  136. data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
  137. else:
  138. data = b''
  139. return _roundtrip(conn, data)
  140. def scramble_caching_sha2(password, nonce):
  141. # (bytes, bytes) -> bytes
  142. """Scramble algorithm used in cached_sha2_password fast path.
  143. XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce))
  144. """
  145. if not password:
  146. return b''
  147. p1 = hashlib.sha256(password).digest()
  148. p2 = hashlib.sha256(p1).digest()
  149. p3 = hashlib.sha256(p2 + nonce).digest()
  150. res = bytearray(p1)
  151. if PY2:
  152. p3 = bytearray(p3)
  153. for i in range(len(p3)):
  154. res[i] ^= p3[i]
  155. return bytes(res)
  156. def caching_sha2_password_auth(conn, pkt):
  157. # No password fast path
  158. if not conn.password:
  159. return _roundtrip(conn, b'')
  160. if pkt.is_auth_switch_request():
  161. # Try from fast auth
  162. if DEBUG:
  163. print("caching sha2: Trying fast path")
  164. conn.salt = pkt.read_all()
  165. scrambled = scramble_caching_sha2(conn.password, conn.salt)
  166. pkt = _roundtrip(conn, scrambled)
  167. # else: fast auth is tried in initial handshake
  168. if not pkt.is_extra_auth_data():
  169. raise OperationalError(
  170. "caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1]
  171. )
  172. # magic numbers:
  173. # 2 - request public key
  174. # 3 - fast auth succeeded
  175. # 4 - need full auth
  176. pkt.advance(1)
  177. n = pkt.read_uint8()
  178. if n == 3:
  179. if DEBUG:
  180. print("caching sha2: succeeded by fast path.")
  181. pkt = conn._read_packet()
  182. pkt.check_error() # pkt must be OK packet
  183. return pkt
  184. if n != 4:
  185. raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n)
  186. if DEBUG:
  187. print("caching sha2: Trying full auth...")
  188. if conn._secure:
  189. if DEBUG:
  190. print("caching sha2: Sending plain password via secure connection")
  191. return _roundtrip(conn, conn.password + b'\0')
  192. if not conn.server_public_key:
  193. pkt = _roundtrip(conn, b'\x02') # Request public key
  194. if not pkt.is_extra_auth_data():
  195. raise OperationalError(
  196. "caching sha2: Unknown packet for public key: %s" % pkt._data[:1]
  197. )
  198. conn.server_public_key = pkt._data[1:]
  199. if DEBUG:
  200. print(conn.server_public_key.decode('ascii'))
  201. data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
  202. pkt = _roundtrip(conn, data)