本站源代码
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

1287 行
31KB

  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "bytes"
  11. "crypto/tls"
  12. "database/sql/driver"
  13. "encoding/binary"
  14. "errors"
  15. "fmt"
  16. "io"
  17. "math"
  18. "time"
  19. )
  20. // Packets documentation:
  21. // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
  22. // Read packet to buffer 'data'
  23. func (mc *mysqlConn) readPacket() ([]byte, error) {
  24. var prevData []byte
  25. for {
  26. // read packet header
  27. data, err := mc.buf.readNext(4)
  28. if err != nil {
  29. if cerr := mc.canceled.Value(); cerr != nil {
  30. return nil, cerr
  31. }
  32. errLog.Print(err)
  33. mc.Close()
  34. return nil, ErrInvalidConn
  35. }
  36. // packet length [24 bit]
  37. pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
  38. // check packet sync [8 bit]
  39. if data[3] != mc.sequence {
  40. if data[3] > mc.sequence {
  41. return nil, ErrPktSyncMul
  42. }
  43. return nil, ErrPktSync
  44. }
  45. mc.sequence++
  46. // packets with length 0 terminate a previous packet which is a
  47. // multiple of (2^24)−1 bytes long
  48. if pktLen == 0 {
  49. // there was no previous packet
  50. if prevData == nil {
  51. errLog.Print(ErrMalformPkt)
  52. mc.Close()
  53. return nil, ErrInvalidConn
  54. }
  55. return prevData, nil
  56. }
  57. // read packet body [pktLen bytes]
  58. data, err = mc.buf.readNext(pktLen)
  59. if err != nil {
  60. if cerr := mc.canceled.Value(); cerr != nil {
  61. return nil, cerr
  62. }
  63. errLog.Print(err)
  64. mc.Close()
  65. return nil, ErrInvalidConn
  66. }
  67. // return data if this was the last packet
  68. if pktLen < maxPacketSize {
  69. // zero allocations for non-split packets
  70. if prevData == nil {
  71. return data, nil
  72. }
  73. return append(prevData, data...), nil
  74. }
  75. prevData = append(prevData, data...)
  76. }
  77. }
  78. // Write packet buffer 'data'
  79. func (mc *mysqlConn) writePacket(data []byte) error {
  80. pktLen := len(data) - 4
  81. if pktLen > mc.maxAllowedPacket {
  82. return ErrPktTooLarge
  83. }
  84. for {
  85. var size int
  86. if pktLen >= maxPacketSize {
  87. data[0] = 0xff
  88. data[1] = 0xff
  89. data[2] = 0xff
  90. size = maxPacketSize
  91. } else {
  92. data[0] = byte(pktLen)
  93. data[1] = byte(pktLen >> 8)
  94. data[2] = byte(pktLen >> 16)
  95. size = pktLen
  96. }
  97. data[3] = mc.sequence
  98. // Write packet
  99. if mc.writeTimeout > 0 {
  100. if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
  101. return err
  102. }
  103. }
  104. n, err := mc.netConn.Write(data[:4+size])
  105. if err == nil && n == 4+size {
  106. mc.sequence++
  107. if size != maxPacketSize {
  108. return nil
  109. }
  110. pktLen -= size
  111. data = data[size:]
  112. continue
  113. }
  114. // Handle error
  115. if err == nil { // n != len(data)
  116. mc.cleanup()
  117. errLog.Print(ErrMalformPkt)
  118. } else {
  119. if cerr := mc.canceled.Value(); cerr != nil {
  120. return cerr
  121. }
  122. if n == 0 && pktLen == len(data)-4 {
  123. // only for the first loop iteration when nothing was written yet
  124. return errBadConnNoWrite
  125. }
  126. mc.cleanup()
  127. errLog.Print(err)
  128. }
  129. return ErrInvalidConn
  130. }
  131. }
  132. /******************************************************************************
  133. * Initialization Process *
  134. ******************************************************************************/
  135. // Handshake Initialization Packet
  136. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
  137. func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
  138. data, err = mc.readPacket()
  139. if err != nil {
  140. // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
  141. // in connection initialization we don't risk retrying non-idempotent actions.
  142. if err == ErrInvalidConn {
  143. return nil, "", driver.ErrBadConn
  144. }
  145. return
  146. }
  147. if data[0] == iERR {
  148. return nil, "", mc.handleErrorPacket(data)
  149. }
  150. // protocol version [1 byte]
  151. if data[0] < minProtocolVersion {
  152. return nil, "", fmt.Errorf(
  153. "unsupported protocol version %d. Version %d or higher is required",
  154. data[0],
  155. minProtocolVersion,
  156. )
  157. }
  158. // server version [null terminated string]
  159. // connection id [4 bytes]
  160. pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
  161. // first part of the password cipher [8 bytes]
  162. authData := data[pos : pos+8]
  163. // (filler) always 0x00 [1 byte]
  164. pos += 8 + 1
  165. // capability flags (lower 2 bytes) [2 bytes]
  166. mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
  167. if mc.flags&clientProtocol41 == 0 {
  168. return nil, "", ErrOldProtocol
  169. }
  170. if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
  171. return nil, "", ErrNoTLS
  172. }
  173. pos += 2
  174. if len(data) > pos {
  175. // character set [1 byte]
  176. // status flags [2 bytes]
  177. // capability flags (upper 2 bytes) [2 bytes]
  178. // length of auth-plugin-data [1 byte]
  179. // reserved (all [00]) [10 bytes]
  180. pos += 1 + 2 + 2 + 1 + 10
  181. // second part of the password cipher [mininum 13 bytes],
  182. // where len=MAX(13, length of auth-plugin-data - 8)
  183. //
  184. // The web documentation is ambiguous about the length. However,
  185. // according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
  186. // the 13th byte is "\0 byte, terminating the second part of
  187. // a scramble". So the second part of the password cipher is
  188. // a NULL terminated string that's at least 13 bytes with the
  189. // last byte being NULL.
  190. //
  191. // The official Python library uses the fixed length 12
  192. // which seems to work but technically could have a hidden bug.
  193. authData = append(authData, data[pos:pos+12]...)
  194. pos += 13
  195. // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
  196. // \NUL otherwise
  197. if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
  198. plugin = string(data[pos : pos+end])
  199. } else {
  200. plugin = string(data[pos:])
  201. }
  202. // make a memory safe copy of the cipher slice
  203. var b [20]byte
  204. copy(b[:], authData)
  205. return b[:], plugin, nil
  206. }
  207. // make a memory safe copy of the cipher slice
  208. var b [8]byte
  209. copy(b[:], authData)
  210. return b[:], plugin, nil
  211. }
  212. // Client Authentication Packet
  213. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
  214. func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
  215. // Adjust client flags based on server support
  216. clientFlags := clientProtocol41 |
  217. clientSecureConn |
  218. clientLongPassword |
  219. clientTransactions |
  220. clientLocalFiles |
  221. clientPluginAuth |
  222. clientMultiResults |
  223. mc.flags&clientLongFlag
  224. if mc.cfg.ClientFoundRows {
  225. clientFlags |= clientFoundRows
  226. }
  227. // To enable TLS / SSL
  228. if mc.cfg.tls != nil {
  229. clientFlags |= clientSSL
  230. }
  231. if mc.cfg.MultiStatements {
  232. clientFlags |= clientMultiStatements
  233. }
  234. // encode length of the auth plugin data
  235. var authRespLEIBuf [9]byte
  236. authRespLen := len(authResp)
  237. authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
  238. if len(authRespLEI) > 1 {
  239. // if the length can not be written in 1 byte, it must be written as a
  240. // length encoded integer
  241. clientFlags |= clientPluginAuthLenEncClientData
  242. }
  243. pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
  244. // To specify a db name
  245. if n := len(mc.cfg.DBName); n > 0 {
  246. clientFlags |= clientConnectWithDB
  247. pktLen += n + 1
  248. }
  249. // Calculate packet length and get buffer with that size
  250. data := mc.buf.takeSmallBuffer(pktLen + 4)
  251. if data == nil {
  252. // cannot take the buffer. Something must be wrong with the connection
  253. errLog.Print(ErrBusyBuffer)
  254. return errBadConnNoWrite
  255. }
  256. // ClientFlags [32 bit]
  257. data[4] = byte(clientFlags)
  258. data[5] = byte(clientFlags >> 8)
  259. data[6] = byte(clientFlags >> 16)
  260. data[7] = byte(clientFlags >> 24)
  261. // MaxPacketSize [32 bit] (none)
  262. data[8] = 0x00
  263. data[9] = 0x00
  264. data[10] = 0x00
  265. data[11] = 0x00
  266. // Charset [1 byte]
  267. var found bool
  268. data[12], found = collations[mc.cfg.Collation]
  269. if !found {
  270. // Note possibility for false negatives:
  271. // could be triggered although the collation is valid if the
  272. // collations map does not contain entries the server supports.
  273. return errors.New("unknown collation")
  274. }
  275. // SSL Connection Request Packet
  276. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
  277. if mc.cfg.tls != nil {
  278. // Send TLS / SSL request packet
  279. if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
  280. return err
  281. }
  282. // Switch to TLS
  283. tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
  284. if err := tlsConn.Handshake(); err != nil {
  285. return err
  286. }
  287. mc.netConn = tlsConn
  288. mc.buf.nc = tlsConn
  289. }
  290. // Filler [23 bytes] (all 0x00)
  291. pos := 13
  292. for ; pos < 13+23; pos++ {
  293. data[pos] = 0
  294. }
  295. // User [null terminated string]
  296. if len(mc.cfg.User) > 0 {
  297. pos += copy(data[pos:], mc.cfg.User)
  298. }
  299. data[pos] = 0x00
  300. pos++
  301. // Auth Data [length encoded integer]
  302. pos += copy(data[pos:], authRespLEI)
  303. pos += copy(data[pos:], authResp)
  304. // Databasename [null terminated string]
  305. if len(mc.cfg.DBName) > 0 {
  306. pos += copy(data[pos:], mc.cfg.DBName)
  307. data[pos] = 0x00
  308. pos++
  309. }
  310. pos += copy(data[pos:], plugin)
  311. data[pos] = 0x00
  312. pos++
  313. // Send Auth packet
  314. return mc.writePacket(data[:pos])
  315. }
  316. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
  317. func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
  318. pktLen := 4 + len(authData)
  319. data := mc.buf.takeSmallBuffer(pktLen)
  320. if data == nil {
  321. // cannot take the buffer. Something must be wrong with the connection
  322. errLog.Print(ErrBusyBuffer)
  323. return errBadConnNoWrite
  324. }
  325. // Add the auth data [EOF]
  326. copy(data[4:], authData)
  327. return mc.writePacket(data)
  328. }
  329. /******************************************************************************
  330. * Command Packets *
  331. ******************************************************************************/
  332. func (mc *mysqlConn) writeCommandPacket(command byte) error {
  333. // Reset Packet Sequence
  334. mc.sequence = 0
  335. data := mc.buf.takeSmallBuffer(4 + 1)
  336. if data == nil {
  337. // cannot take the buffer. Something must be wrong with the connection
  338. errLog.Print(ErrBusyBuffer)
  339. return errBadConnNoWrite
  340. }
  341. // Add command byte
  342. data[4] = command
  343. // Send CMD packet
  344. return mc.writePacket(data)
  345. }
  346. func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
  347. // Reset Packet Sequence
  348. mc.sequence = 0
  349. pktLen := 1 + len(arg)
  350. data := mc.buf.takeBuffer(pktLen + 4)
  351. if data == nil {
  352. // cannot take the buffer. Something must be wrong with the connection
  353. errLog.Print(ErrBusyBuffer)
  354. return errBadConnNoWrite
  355. }
  356. // Add command byte
  357. data[4] = command
  358. // Add arg
  359. copy(data[5:], arg)
  360. // Send CMD packet
  361. return mc.writePacket(data)
  362. }
  363. func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
  364. // Reset Packet Sequence
  365. mc.sequence = 0
  366. data := mc.buf.takeSmallBuffer(4 + 1 + 4)
  367. if data == nil {
  368. // cannot take the buffer. Something must be wrong with the connection
  369. errLog.Print(ErrBusyBuffer)
  370. return errBadConnNoWrite
  371. }
  372. // Add command byte
  373. data[4] = command
  374. // Add arg [32 bit]
  375. data[5] = byte(arg)
  376. data[6] = byte(arg >> 8)
  377. data[7] = byte(arg >> 16)
  378. data[8] = byte(arg >> 24)
  379. // Send CMD packet
  380. return mc.writePacket(data)
  381. }
  382. /******************************************************************************
  383. * Result Packets *
  384. ******************************************************************************/
  385. func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
  386. data, err := mc.readPacket()
  387. if err != nil {
  388. return nil, "", err
  389. }
  390. // packet indicator
  391. switch data[0] {
  392. case iOK:
  393. return nil, "", mc.handleOkPacket(data)
  394. case iAuthMoreData:
  395. return data[1:], "", err
  396. case iEOF:
  397. if len(data) == 1 {
  398. // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
  399. return nil, "mysql_old_password", nil
  400. }
  401. pluginEndIndex := bytes.IndexByte(data, 0x00)
  402. if pluginEndIndex < 0 {
  403. return nil, "", ErrMalformPkt
  404. }
  405. plugin := string(data[1:pluginEndIndex])
  406. authData := data[pluginEndIndex+1:]
  407. return authData, plugin, nil
  408. default: // Error otherwise
  409. return nil, "", mc.handleErrorPacket(data)
  410. }
  411. }
  412. // Returns error if Packet is not an 'Result OK'-Packet
  413. func (mc *mysqlConn) readResultOK() error {
  414. data, err := mc.readPacket()
  415. if err != nil {
  416. return err
  417. }
  418. if data[0] == iOK {
  419. return mc.handleOkPacket(data)
  420. }
  421. return mc.handleErrorPacket(data)
  422. }
  423. // Result Set Header Packet
  424. // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
  425. func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
  426. data, err := mc.readPacket()
  427. if err == nil {
  428. switch data[0] {
  429. case iOK:
  430. return 0, mc.handleOkPacket(data)
  431. case iERR:
  432. return 0, mc.handleErrorPacket(data)
  433. case iLocalInFile:
  434. return 0, mc.handleInFileRequest(string(data[1:]))
  435. }
  436. // column count
  437. num, _, n := readLengthEncodedInteger(data)
  438. if n-len(data) == 0 {
  439. return int(num), nil
  440. }
  441. return 0, ErrMalformPkt
  442. }
  443. return 0, err
  444. }
  445. // Error Packet
  446. // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
  447. func (mc *mysqlConn) handleErrorPacket(data []byte) error {
  448. if data[0] != iERR {
  449. return ErrMalformPkt
  450. }
  451. // 0xff [1 byte]
  452. // Error Number [16 bit uint]
  453. errno := binary.LittleEndian.Uint16(data[1:3])
  454. // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
  455. // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
  456. if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
  457. // Oops; we are connected to a read-only connection, and won't be able
  458. // to issue any write statements. Since RejectReadOnly is configured,
  459. // we throw away this connection hoping this one would have write
  460. // permission. This is specifically for a possible race condition
  461. // during failover (e.g. on AWS Aurora). See README.md for more.
  462. //
  463. // We explicitly close the connection before returning
  464. // driver.ErrBadConn to ensure that `database/sql` purges this
  465. // connection and initiates a new one for next statement next time.
  466. mc.Close()
  467. return driver.ErrBadConn
  468. }
  469. pos := 3
  470. // SQL State [optional: # + 5bytes string]
  471. if data[3] == 0x23 {
  472. //sqlstate := string(data[4 : 4+5])
  473. pos = 9
  474. }
  475. // Error Message [string]
  476. return &MySQLError{
  477. Number: errno,
  478. Message: string(data[pos:]),
  479. }
  480. }
  481. func readStatus(b []byte) statusFlag {
  482. return statusFlag(b[0]) | statusFlag(b[1])<<8
  483. }
  484. // Ok Packet
  485. // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
  486. func (mc *mysqlConn) handleOkPacket(data []byte) error {
  487. var n, m int
  488. // 0x00 [1 byte]
  489. // Affected rows [Length Coded Binary]
  490. mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
  491. // Insert id [Length Coded Binary]
  492. mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
  493. // server_status [2 bytes]
  494. mc.status = readStatus(data[1+n+m : 1+n+m+2])
  495. if mc.status&statusMoreResultsExists != 0 {
  496. return nil
  497. }
  498. // warning count [2 bytes]
  499. return nil
  500. }
  501. // Read Packets as Field Packets until EOF-Packet or an Error appears
  502. // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
  503. func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
  504. columns := make([]mysqlField, count)
  505. for i := 0; ; i++ {
  506. data, err := mc.readPacket()
  507. if err != nil {
  508. return nil, err
  509. }
  510. // EOF Packet
  511. if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
  512. if i == count {
  513. return columns, nil
  514. }
  515. return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
  516. }
  517. // Catalog
  518. pos, err := skipLengthEncodedString(data)
  519. if err != nil {
  520. return nil, err
  521. }
  522. // Database [len coded string]
  523. n, err := skipLengthEncodedString(data[pos:])
  524. if err != nil {
  525. return nil, err
  526. }
  527. pos += n
  528. // Table [len coded string]
  529. if mc.cfg.ColumnsWithAlias {
  530. tableName, _, n, err := readLengthEncodedString(data[pos:])
  531. if err != nil {
  532. return nil, err
  533. }
  534. pos += n
  535. columns[i].tableName = string(tableName)
  536. } else {
  537. n, err = skipLengthEncodedString(data[pos:])
  538. if err != nil {
  539. return nil, err
  540. }
  541. pos += n
  542. }
  543. // Original table [len coded string]
  544. n, err = skipLengthEncodedString(data[pos:])
  545. if err != nil {
  546. return nil, err
  547. }
  548. pos += n
  549. // Name [len coded string]
  550. name, _, n, err := readLengthEncodedString(data[pos:])
  551. if err != nil {
  552. return nil, err
  553. }
  554. columns[i].name = string(name)
  555. pos += n
  556. // Original name [len coded string]
  557. n, err = skipLengthEncodedString(data[pos:])
  558. if err != nil {
  559. return nil, err
  560. }
  561. pos += n
  562. // Filler [uint8]
  563. pos++
  564. // Charset [charset, collation uint8]
  565. columns[i].charSet = data[pos]
  566. pos += 2
  567. // Length [uint32]
  568. columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
  569. pos += 4
  570. // Field type [uint8]
  571. columns[i].fieldType = fieldType(data[pos])
  572. pos++
  573. // Flags [uint16]
  574. columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
  575. pos += 2
  576. // Decimals [uint8]
  577. columns[i].decimals = data[pos]
  578. //pos++
  579. // Default value [len coded binary]
  580. //if pos < len(data) {
  581. // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
  582. //}
  583. }
  584. }
  585. // Read Packets as Field Packets until EOF-Packet or an Error appears
  586. // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
  587. func (rows *textRows) readRow(dest []driver.Value) error {
  588. mc := rows.mc
  589. if rows.rs.done {
  590. return io.EOF
  591. }
  592. data, err := mc.readPacket()
  593. if err != nil {
  594. return err
  595. }
  596. // EOF Packet
  597. if data[0] == iEOF && len(data) == 5 {
  598. // server_status [2 bytes]
  599. rows.mc.status = readStatus(data[3:])
  600. rows.rs.done = true
  601. if !rows.HasNextResultSet() {
  602. rows.mc = nil
  603. }
  604. return io.EOF
  605. }
  606. if data[0] == iERR {
  607. rows.mc = nil
  608. return mc.handleErrorPacket(data)
  609. }
  610. // RowSet Packet
  611. var n int
  612. var isNull bool
  613. pos := 0
  614. for i := range dest {
  615. // Read bytes and convert to string
  616. dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
  617. pos += n
  618. if err == nil {
  619. if !isNull {
  620. if !mc.parseTime {
  621. continue
  622. } else {
  623. switch rows.rs.columns[i].fieldType {
  624. case fieldTypeTimestamp, fieldTypeDateTime,
  625. fieldTypeDate, fieldTypeNewDate:
  626. dest[i], err = parseDateTime(
  627. string(dest[i].([]byte)),
  628. mc.cfg.Loc,
  629. )
  630. if err == nil {
  631. continue
  632. }
  633. default:
  634. continue
  635. }
  636. }
  637. } else {
  638. dest[i] = nil
  639. continue
  640. }
  641. }
  642. return err // err != nil
  643. }
  644. return nil
  645. }
  646. // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
  647. func (mc *mysqlConn) readUntilEOF() error {
  648. for {
  649. data, err := mc.readPacket()
  650. if err != nil {
  651. return err
  652. }
  653. switch data[0] {
  654. case iERR:
  655. return mc.handleErrorPacket(data)
  656. case iEOF:
  657. if len(data) == 5 {
  658. mc.status = readStatus(data[3:])
  659. }
  660. return nil
  661. }
  662. }
  663. }
  664. /******************************************************************************
  665. * Prepared Statements *
  666. ******************************************************************************/
  667. // Prepare Result Packets
  668. // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
  669. func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
  670. data, err := stmt.mc.readPacket()
  671. if err == nil {
  672. // packet indicator [1 byte]
  673. if data[0] != iOK {
  674. return 0, stmt.mc.handleErrorPacket(data)
  675. }
  676. // statement id [4 bytes]
  677. stmt.id = binary.LittleEndian.Uint32(data[1:5])
  678. // Column count [16 bit uint]
  679. columnCount := binary.LittleEndian.Uint16(data[5:7])
  680. // Param count [16 bit uint]
  681. stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
  682. // Reserved [8 bit]
  683. // Warning count [16 bit uint]
  684. return columnCount, nil
  685. }
  686. return 0, err
  687. }
// writeCommandLongData sends arg for parameter paramID of stmt as one or
// more COM_STMT_SEND_LONG_DATA packets.
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
//
// Each packet payload (command, statement id, parameter id, then a slice
// of arg) is at most maxAllowedPacket-1 bytes. The sequence number is reset
// to 0 before every packet and once more after the last one. An empty arg
// sends nothing. The first write error is returned.
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
	maxLen := stmt.mc.maxAllowedPacket - 1
	pktLen := maxLen

	// After the header (bytes 0-3) follows before the data:
	// 1 byte command
	// 4 bytes stmtID
	// 2 bytes paramID
	const dataOffset = 1 + 4 + 2

	// Cannot use the write buffer since
	// a) the buffer is too small
	// b) it is in use
	data := make([]byte, 4+1+4+2+len(arg))

	copy(data[4+dataOffset:], arg)

	// argLen is the number of arg bytes still to send; each iteration
	// consumes pktLen-dataOffset of them.
	for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
		// Last (short) chunk: shrink the packet to what remains.
		if dataOffset+argLen < maxLen {
			pktLen = dataOffset + argLen
		}

		stmt.mc.sequence = 0
		// Add command byte [1 byte]
		data[4] = comStmtSendLongData

		// Add stmtID [32 bit]
		data[5] = byte(stmt.id)
		data[6] = byte(stmt.id >> 8)
		data[7] = byte(stmt.id >> 16)
		data[8] = byte(stmt.id >> 24)

		// Add paramID [16 bit]
		data[9] = byte(paramID)
		data[10] = byte(paramID >> 8)

		// Send CMD packet
		err := stmt.mc.writePacket(data[:4+pktLen])
		if err == nil {
			// Advance so that the next chunk's header and command prefix
			// (bytes 0..10) overlay bytes that were already sent, and its
			// data (from byte 11) starts at the first unsent arg byte.
			data = data[pktLen-dataOffset:]
			continue
		}
		return err

	}

	// Reset Packet Sequence
	stmt.mc.sequence = 0
	return nil
}
// writeExecutePacket sends a COM_STMT_EXECUTE packet for stmt with args
// bound as parameters.
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
//
// The packet contains:
//   - the statement id,
//   - flags 0 (no cursor) and iteration count 1,
//   - when there are args: a NULL bitmap, the new-params-bound flag, two
//     type bytes per arg, and the encoded values.
//
// Supported arg types are nil, int64, float64, bool, []byte, string and
// time.Time; any other type returns a "cannot convert type" error.
// []byte/string values of at least longDataSize bytes are sent beforehand
// via writeCommandLongData instead of inline.
//
// Errors:
//   - an error if len(args) differs from stmt.paramCount,
//   - errBadConnNoWrite if the connection buffer is busy,
//   - any long-data or write error.
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
	if len(args) != stmt.paramCount {
		return fmt.Errorf(
			"argument count mismatch (got: %d; has: %d)",
			len(args),
			stmt.paramCount,
		)
	}

	// header [4] + command [1] + statement_id [4] + flags [1] + iteration_count [4]
	const minPktLen = 4 + 1 + 4 + 1 + 4
	mc := stmt.mc

	// Determine threshold dynamically to avoid packet size shortage.
	// Values at or above it are sent as long data.
	longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
	if longDataSize < 64 {
		longDataSize = 64
	}

	// Reset packet-sequence
	mc.sequence = 0

	var data []byte

	if len(args) == 0 {
		data = mc.buf.takeBuffer(minPktLen)
	} else {
		data = mc.buf.takeCompleteBuffer()
	}
	if data == nil {
		// cannot take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// command [1 byte]
	data[4] = comStmtExecute

	// statement_id [4 bytes]
	data[5] = byte(stmt.id)
	data[6] = byte(stmt.id >> 8)
	data[7] = byte(stmt.id >> 16)
	data[8] = byte(stmt.id >> 24)

	// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
	data[9] = 0x00

	// iteration_count (uint32(1)) [4 bytes]
	data[10] = 0x01
	data[11] = 0x00
	data[12] = 0x00
	data[13] = 0x00

	if len(args) > 0 {
		pos := minPktLen

		var nullMask []byte
		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
			// buffer has to be extended but we don't know by how much so
			// we depend on append after all data with known sizes fit.
			// We stop at that because we deal with a lot of columns here
			// which makes the required allocation size hard to guess.
			tmp := make([]byte, pos+maskLen+typesLen)
			copy(tmp[:pos], data[:pos])
			data = tmp
			nullMask = data[pos : pos+maskLen]
			pos += maskLen
		} else {
			// Reusing the connection buffer: clear the NULL bitmap, since
			// bits are only ever set below.
			nullMask = data[pos : pos+maskLen]
			for i := 0; i < maskLen; i++ {
				nullMask[i] = 0
			}
			pos += maskLen
		}

		// newParameterBoundFlag 1 [1 byte]
		data[pos] = 0x01
		pos++

		// type of each parameter [len(args)*2 bytes]
		paramTypes := data[pos:]
		pos += len(args) * 2

		// value of each parameter [n bytes]
		// paramValues starts empty but shares data's backing array, so
		// values are written in place until its capacity runs out.
		paramValues := data[pos:pos]
		valuesCap := cap(paramValues)

		for i, arg := range args {
			// build NULL-bitmap
			if arg == nil {
				nullMask[i/8] |= 1 << (uint(i) & 7)
				paramTypes[i+i] = byte(fieldTypeNULL)
				paramTypes[i+i+1] = 0x00
				continue
			}

			// cache types and values
			switch v := arg.(type) {
			case int64:
				paramTypes[i+i] = byte(fieldTypeLongLong)
				paramTypes[i+i+1] = 0x00

				if cap(paramValues)-len(paramValues)-8 >= 0 {
					paramValues = paramValues[:len(paramValues)+8]
					binary.LittleEndian.PutUint64(
						paramValues[len(paramValues)-8:],
						uint64(v),
					)
				} else {
					paramValues = append(paramValues,
						uint64ToBytes(uint64(v))...,
					)
				}

			case float64:
				paramTypes[i+i] = byte(fieldTypeDouble)
				paramTypes[i+i+1] = 0x00

				if cap(paramValues)-len(paramValues)-8 >= 0 {
					paramValues = paramValues[:len(paramValues)+8]
					binary.LittleEndian.PutUint64(
						paramValues[len(paramValues)-8:],
						math.Float64bits(v),
					)
				} else {
					paramValues = append(paramValues,
						uint64ToBytes(math.Float64bits(v))...,
					)
				}

			case bool:
				paramTypes[i+i] = byte(fieldTypeTiny)
				paramTypes[i+i+1] = 0x00

				if v {
					paramValues = append(paramValues, 0x01)
				} else {
					paramValues = append(paramValues, 0x00)
				}

			case []byte:
				// Common case (non-nil value) first
				if v != nil {
					paramTypes[i+i] = byte(fieldTypeString)
					paramTypes[i+i+1] = 0x00

					if len(v) < longDataSize {
						paramValues = appendLengthEncodedInteger(paramValues,
							uint64(len(v)),
						)
						paramValues = append(paramValues, v...)
					} else {
						// Large value: sent separately; nothing inline.
						if err := stmt.writeCommandLongData(i, v); err != nil {
							return err
						}
					}
					continue
				}

				// Handle []byte(nil) as a NULL value
				nullMask[i/8] |= 1 << (uint(i) & 7)
				paramTypes[i+i] = byte(fieldTypeNULL)
				paramTypes[i+i+1] = 0x00

			case string:
				paramTypes[i+i] = byte(fieldTypeString)
				paramTypes[i+i+1] = 0x00

				if len(v) < longDataSize {
					paramValues = appendLengthEncodedInteger(paramValues,
						uint64(len(v)),
					)
					paramValues = append(paramValues, v...)
				} else {
					// Large value: sent separately; nothing inline.
					if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
						return err
					}
				}

			case time.Time:
				paramTypes[i+i] = byte(fieldTypeString)
				paramTypes[i+i+1] = 0x00

				var a [64]byte
				var b = a[:0]

				if v.IsZero() {
					b = append(b, "0000-00-00"...)
				} else {
					b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
				}

				paramValues = appendLengthEncodedInteger(paramValues,
					uint64(len(b)),
				)
				paramValues = append(paramValues, b...)

			default:
				return fmt.Errorf("cannot convert type: %T", arg)
			}
		}

		// Check if param values exceeded the available buffer
		// In that case we must build the data packet with the new values buffer
		// (and the grown slice is stored back into mc.buf.buf).
		if valuesCap != cap(paramValues) {
			data = append(data[:pos], paramValues...)
			mc.buf.buf = data
		}

		pos += len(paramValues)
		data = data[:pos]
	}

	return mc.writePacket(data)
}
  911. func (mc *mysqlConn) discardResults() error {
  912. for mc.status&statusMoreResultsExists != 0 {
  913. resLen, err := mc.readResultSetHeaderPacket()
  914. if err != nil {
  915. return err
  916. }
  917. if resLen > 0 {
  918. // columns
  919. if err := mc.readUntilEOF(); err != nil {
  920. return err
  921. }
  922. // rows
  923. if err := mc.readUntilEOF(); err != nil {
  924. return err
  925. }
  926. }
  927. }
  928. return nil
  929. }
  930. // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
  931. func (rows *binaryRows) readRow(dest []driver.Value) error {
  932. data, err := rows.mc.readPacket()
  933. if err != nil {
  934. return err
  935. }
  936. // packet indicator [1 byte]
  937. if data[0] != iOK {
  938. // EOF Packet
  939. if data[0] == iEOF && len(data) == 5 {
  940. rows.mc.status = readStatus(data[3:])
  941. rows.rs.done = true
  942. if !rows.HasNextResultSet() {
  943. rows.mc = nil
  944. }
  945. return io.EOF
  946. }
  947. mc := rows.mc
  948. rows.mc = nil
  949. // Error otherwise
  950. return mc.handleErrorPacket(data)
  951. }
  952. // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
  953. pos := 1 + (len(dest)+7+2)>>3
  954. nullMask := data[1:pos]
  955. for i := range dest {
  956. // Field is NULL
  957. // (byte >> bit-pos) % 2 == 1
  958. if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
  959. dest[i] = nil
  960. continue
  961. }
  962. // Convert to byte-coded string
  963. switch rows.rs.columns[i].fieldType {
  964. case fieldTypeNULL:
  965. dest[i] = nil
  966. continue
  967. // Numeric Types
  968. case fieldTypeTiny:
  969. if rows.rs.columns[i].flags&flagUnsigned != 0 {
  970. dest[i] = int64(data[pos])
  971. } else {
  972. dest[i] = int64(int8(data[pos]))
  973. }
  974. pos++
  975. continue
  976. case fieldTypeShort, fieldTypeYear:
  977. if rows.rs.columns[i].flags&flagUnsigned != 0 {
  978. dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
  979. } else {
  980. dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
  981. }
  982. pos += 2
  983. continue
  984. case fieldTypeInt24, fieldTypeLong:
  985. if rows.rs.columns[i].flags&flagUnsigned != 0 {
  986. dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
  987. } else {
  988. dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
  989. }
  990. pos += 4
  991. continue
  992. case fieldTypeLongLong:
  993. if rows.rs.columns[i].flags&flagUnsigned != 0 {
  994. val := binary.LittleEndian.Uint64(data[pos : pos+8])
  995. if val > math.MaxInt64 {
  996. dest[i] = uint64ToString(val)
  997. } else {
  998. dest[i] = int64(val)
  999. }
  1000. } else {
  1001. dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
  1002. }
  1003. pos += 8
  1004. continue
  1005. case fieldTypeFloat:
  1006. dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
  1007. pos += 4
  1008. continue
  1009. case fieldTypeDouble:
  1010. dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
  1011. pos += 8
  1012. continue
  1013. // Length coded Binary Strings
  1014. case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
  1015. fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
  1016. fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
  1017. fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
  1018. var isNull bool
  1019. var n int
  1020. dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
  1021. pos += n
  1022. if err == nil {
  1023. if !isNull {
  1024. continue
  1025. } else {
  1026. dest[i] = nil
  1027. continue
  1028. }
  1029. }
  1030. return err
  1031. case
  1032. fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
  1033. fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal]
  1034. fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
  1035. num, isNull, n := readLengthEncodedInteger(data[pos:])
  1036. pos += n
  1037. switch {
  1038. case isNull:
  1039. dest[i] = nil
  1040. continue
  1041. case rows.rs.columns[i].fieldType == fieldTypeTime:
  1042. // database/sql does not support an equivalent to TIME, return a string
  1043. var dstlen uint8
  1044. switch decimals := rows.rs.columns[i].decimals; decimals {
  1045. case 0x00, 0x1f:
  1046. dstlen = 8
  1047. case 1, 2, 3, 4, 5, 6:
  1048. dstlen = 8 + 1 + decimals
  1049. default:
  1050. return fmt.Errorf(
  1051. "protocol error, illegal decimals value %d",
  1052. rows.rs.columns[i].decimals,
  1053. )
  1054. }
  1055. dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
  1056. case rows.mc.parseTime:
  1057. dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
  1058. default:
  1059. var dstlen uint8
  1060. if rows.rs.columns[i].fieldType == fieldTypeDate {
  1061. dstlen = 10
  1062. } else {
  1063. switch decimals := rows.rs.columns[i].decimals; decimals {
  1064. case 0x00, 0x1f:
  1065. dstlen = 19
  1066. case 1, 2, 3, 4, 5, 6:
  1067. dstlen = 19 + 1 + decimals
  1068. default:
  1069. return fmt.Errorf(
  1070. "protocol error, illegal decimals value %d",
  1071. rows.rs.columns[i].decimals,
  1072. )
  1073. }
  1074. }
  1075. dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
  1076. }
  1077. if err == nil {
  1078. pos += int(num)
  1079. continue
  1080. } else {
  1081. return err
  1082. }
  1083. // Please report if this happens!
  1084. default:
  1085. return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
  1086. }
  1087. }
  1088. return nil
  1089. }
上海开阖软件有限公司 沪ICP备12045867号-1