本站源代码
No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

612 líneas
14KB

  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "bytes"
  11. "crypto/rsa"
  12. "crypto/tls"
  13. "errors"
  14. "fmt"
  15. "net"
  16. "net/url"
  17. "sort"
  18. "strconv"
  19. "strings"
  20. "time"
  21. )
  22. var (
  23. errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
  24. errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
  25. errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
  26. errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
  27. )
  28. // Config is a configuration parsed from a DSN string.
  29. // If a new Config is created instead of being parsed from a DSN string,
  30. // the NewConfig function should be used, which sets default values.
  31. type Config struct {
  32. User string // Username
  33. Passwd string // Password (requires User)
  34. Net string // Network type
  35. Addr string // Network address (requires Net)
  36. DBName string // Database name
  37. Params map[string]string // Connection parameters
  38. Collation string // Connection collation
  39. Loc *time.Location // Location for time.Time values
  40. MaxAllowedPacket int // Max packet size allowed
  41. ServerPubKey string // Server public key name
  42. pubKey *rsa.PublicKey // Server public key
  43. TLSConfig string // TLS configuration name
  44. tls *tls.Config // TLS configuration
  45. Timeout time.Duration // Dial timeout
  46. ReadTimeout time.Duration // I/O read timeout
  47. WriteTimeout time.Duration // I/O write timeout
  48. AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
  49. AllowCleartextPasswords bool // Allows the cleartext client side plugin
  50. AllowNativePasswords bool // Allows the native password authentication method
  51. AllowOldPasswords bool // Allows the old insecure password method
  52. ClientFoundRows bool // Return number of matching rows instead of rows changed
  53. ColumnsWithAlias bool // Prepend table alias to column names
  54. InterpolateParams bool // Interpolate placeholders into query string
  55. MultiStatements bool // Allow multiple statements in one query
  56. ParseTime bool // Parse time values to time.Time
  57. RejectReadOnly bool // Reject read-only connections
  58. }
  59. // NewConfig creates a new Config and sets default values.
  60. func NewConfig() *Config {
  61. return &Config{
  62. Collation: defaultCollation,
  63. Loc: time.UTC,
  64. MaxAllowedPacket: defaultMaxAllowedPacket,
  65. AllowNativePasswords: true,
  66. }
  67. }
  68. func (cfg *Config) normalize() error {
  69. if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
  70. return errInvalidDSNUnsafeCollation
  71. }
  72. // Set default network if empty
  73. if cfg.Net == "" {
  74. cfg.Net = "tcp"
  75. }
  76. // Set default address if empty
  77. if cfg.Addr == "" {
  78. switch cfg.Net {
  79. case "tcp":
  80. cfg.Addr = "127.0.0.1:3306"
  81. case "unix":
  82. cfg.Addr = "/tmp/mysql.sock"
  83. default:
  84. return errors.New("default addr for network '" + cfg.Net + "' unknown")
  85. }
  86. } else if cfg.Net == "tcp" {
  87. cfg.Addr = ensureHavePort(cfg.Addr)
  88. }
  89. if cfg.tls != nil {
  90. if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
  91. host, _, err := net.SplitHostPort(cfg.Addr)
  92. if err == nil {
  93. cfg.tls.ServerName = host
  94. }
  95. }
  96. }
  97. return nil
  98. }
  99. // FormatDSN formats the given Config into a DSN string which can be passed to
  100. // the driver.
  101. func (cfg *Config) FormatDSN() string {
  102. var buf bytes.Buffer
  103. // [username[:password]@]
  104. if len(cfg.User) > 0 {
  105. buf.WriteString(cfg.User)
  106. if len(cfg.Passwd) > 0 {
  107. buf.WriteByte(':')
  108. buf.WriteString(cfg.Passwd)
  109. }
  110. buf.WriteByte('@')
  111. }
  112. // [protocol[(address)]]
  113. if len(cfg.Net) > 0 {
  114. buf.WriteString(cfg.Net)
  115. if len(cfg.Addr) > 0 {
  116. buf.WriteByte('(')
  117. buf.WriteString(cfg.Addr)
  118. buf.WriteByte(')')
  119. }
  120. }
  121. // /dbname
  122. buf.WriteByte('/')
  123. buf.WriteString(cfg.DBName)
  124. // [?param1=value1&...&paramN=valueN]
  125. hasParam := false
  126. if cfg.AllowAllFiles {
  127. hasParam = true
  128. buf.WriteString("?allowAllFiles=true")
  129. }
  130. if cfg.AllowCleartextPasswords {
  131. if hasParam {
  132. buf.WriteString("&allowCleartextPasswords=true")
  133. } else {
  134. hasParam = true
  135. buf.WriteString("?allowCleartextPasswords=true")
  136. }
  137. }
  138. if !cfg.AllowNativePasswords {
  139. if hasParam {
  140. buf.WriteString("&allowNativePasswords=false")
  141. } else {
  142. hasParam = true
  143. buf.WriteString("?allowNativePasswords=false")
  144. }
  145. }
  146. if cfg.AllowOldPasswords {
  147. if hasParam {
  148. buf.WriteString("&allowOldPasswords=true")
  149. } else {
  150. hasParam = true
  151. buf.WriteString("?allowOldPasswords=true")
  152. }
  153. }
  154. if cfg.ClientFoundRows {
  155. if hasParam {
  156. buf.WriteString("&clientFoundRows=true")
  157. } else {
  158. hasParam = true
  159. buf.WriteString("?clientFoundRows=true")
  160. }
  161. }
  162. if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
  163. if hasParam {
  164. buf.WriteString("&collation=")
  165. } else {
  166. hasParam = true
  167. buf.WriteString("?collation=")
  168. }
  169. buf.WriteString(col)
  170. }
  171. if cfg.ColumnsWithAlias {
  172. if hasParam {
  173. buf.WriteString("&columnsWithAlias=true")
  174. } else {
  175. hasParam = true
  176. buf.WriteString("?columnsWithAlias=true")
  177. }
  178. }
  179. if cfg.InterpolateParams {
  180. if hasParam {
  181. buf.WriteString("&interpolateParams=true")
  182. } else {
  183. hasParam = true
  184. buf.WriteString("?interpolateParams=true")
  185. }
  186. }
  187. if cfg.Loc != time.UTC && cfg.Loc != nil {
  188. if hasParam {
  189. buf.WriteString("&loc=")
  190. } else {
  191. hasParam = true
  192. buf.WriteString("?loc=")
  193. }
  194. buf.WriteString(url.QueryEscape(cfg.Loc.String()))
  195. }
  196. if cfg.MultiStatements {
  197. if hasParam {
  198. buf.WriteString("&multiStatements=true")
  199. } else {
  200. hasParam = true
  201. buf.WriteString("?multiStatements=true")
  202. }
  203. }
  204. if cfg.ParseTime {
  205. if hasParam {
  206. buf.WriteString("&parseTime=true")
  207. } else {
  208. hasParam = true
  209. buf.WriteString("?parseTime=true")
  210. }
  211. }
  212. if cfg.ReadTimeout > 0 {
  213. if hasParam {
  214. buf.WriteString("&readTimeout=")
  215. } else {
  216. hasParam = true
  217. buf.WriteString("?readTimeout=")
  218. }
  219. buf.WriteString(cfg.ReadTimeout.String())
  220. }
  221. if cfg.RejectReadOnly {
  222. if hasParam {
  223. buf.WriteString("&rejectReadOnly=true")
  224. } else {
  225. hasParam = true
  226. buf.WriteString("?rejectReadOnly=true")
  227. }
  228. }
  229. if len(cfg.ServerPubKey) > 0 {
  230. if hasParam {
  231. buf.WriteString("&serverPubKey=")
  232. } else {
  233. hasParam = true
  234. buf.WriteString("?serverPubKey=")
  235. }
  236. buf.WriteString(url.QueryEscape(cfg.ServerPubKey))
  237. }
  238. if cfg.Timeout > 0 {
  239. if hasParam {
  240. buf.WriteString("&timeout=")
  241. } else {
  242. hasParam = true
  243. buf.WriteString("?timeout=")
  244. }
  245. buf.WriteString(cfg.Timeout.String())
  246. }
  247. if len(cfg.TLSConfig) > 0 {
  248. if hasParam {
  249. buf.WriteString("&tls=")
  250. } else {
  251. hasParam = true
  252. buf.WriteString("?tls=")
  253. }
  254. buf.WriteString(url.QueryEscape(cfg.TLSConfig))
  255. }
  256. if cfg.WriteTimeout > 0 {
  257. if hasParam {
  258. buf.WriteString("&writeTimeout=")
  259. } else {
  260. hasParam = true
  261. buf.WriteString("?writeTimeout=")
  262. }
  263. buf.WriteString(cfg.WriteTimeout.String())
  264. }
  265. if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
  266. if hasParam {
  267. buf.WriteString("&maxAllowedPacket=")
  268. } else {
  269. hasParam = true
  270. buf.WriteString("?maxAllowedPacket=")
  271. }
  272. buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))
  273. }
  274. // other params
  275. if cfg.Params != nil {
  276. var params []string
  277. for param := range cfg.Params {
  278. params = append(params, param)
  279. }
  280. sort.Strings(params)
  281. for _, param := range params {
  282. if hasParam {
  283. buf.WriteByte('&')
  284. } else {
  285. hasParam = true
  286. buf.WriteByte('?')
  287. }
  288. buf.WriteString(param)
  289. buf.WriteByte('=')
  290. buf.WriteString(url.QueryEscape(cfg.Params[param]))
  291. }
  292. }
  293. return buf.String()
  294. }
  295. // ParseDSN parses the DSN string to a Config
  296. func ParseDSN(dsn string) (cfg *Config, err error) {
  297. // New config with some default values
  298. cfg = NewConfig()
  299. // [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
  300. // Find the last '/' (since the password or the net addr might contain a '/')
  301. foundSlash := false
  302. for i := len(dsn) - 1; i >= 0; i-- {
  303. if dsn[i] == '/' {
  304. foundSlash = true
  305. var j, k int
  306. // left part is empty if i <= 0
  307. if i > 0 {
  308. // [username[:password]@][protocol[(address)]]
  309. // Find the last '@' in dsn[:i]
  310. for j = i; j >= 0; j-- {
  311. if dsn[j] == '@' {
  312. // username[:password]
  313. // Find the first ':' in dsn[:j]
  314. for k = 0; k < j; k++ {
  315. if dsn[k] == ':' {
  316. cfg.Passwd = dsn[k+1 : j]
  317. break
  318. }
  319. }
  320. cfg.User = dsn[:k]
  321. break
  322. }
  323. }
  324. // [protocol[(address)]]
  325. // Find the first '(' in dsn[j+1:i]
  326. for k = j + 1; k < i; k++ {
  327. if dsn[k] == '(' {
  328. // dsn[i-1] must be == ')' if an address is specified
  329. if dsn[i-1] != ')' {
  330. if strings.ContainsRune(dsn[k+1:i], ')') {
  331. return nil, errInvalidDSNUnescaped
  332. }
  333. return nil, errInvalidDSNAddr
  334. }
  335. cfg.Addr = dsn[k+1 : i-1]
  336. break
  337. }
  338. }
  339. cfg.Net = dsn[j+1 : k]
  340. }
  341. // dbname[?param1=value1&...&paramN=valueN]
  342. // Find the first '?' in dsn[i+1:]
  343. for j = i + 1; j < len(dsn); j++ {
  344. if dsn[j] == '?' {
  345. if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
  346. return
  347. }
  348. break
  349. }
  350. }
  351. cfg.DBName = dsn[i+1 : j]
  352. break
  353. }
  354. }
  355. if !foundSlash && len(dsn) > 0 {
  356. return nil, errInvalidDSNNoSlash
  357. }
  358. if err = cfg.normalize(); err != nil {
  359. return nil, err
  360. }
  361. return
  362. }
  363. // parseDSNParams parses the DSN "query string"
  364. // Values must be url.QueryEscape'ed
  365. func parseDSNParams(cfg *Config, params string) (err error) {
  366. for _, v := range strings.Split(params, "&") {
  367. param := strings.SplitN(v, "=", 2)
  368. if len(param) != 2 {
  369. continue
  370. }
  371. // cfg params
  372. switch value := param[1]; param[0] {
  373. // Disable INFILE whitelist / enable all files
  374. case "allowAllFiles":
  375. var isBool bool
  376. cfg.AllowAllFiles, isBool = readBool(value)
  377. if !isBool {
  378. return errors.New("invalid bool value: " + value)
  379. }
  380. // Use cleartext authentication mode (MySQL 5.5.10+)
  381. case "allowCleartextPasswords":
  382. var isBool bool
  383. cfg.AllowCleartextPasswords, isBool = readBool(value)
  384. if !isBool {
  385. return errors.New("invalid bool value: " + value)
  386. }
  387. // Use native password authentication
  388. case "allowNativePasswords":
  389. var isBool bool
  390. cfg.AllowNativePasswords, isBool = readBool(value)
  391. if !isBool {
  392. return errors.New("invalid bool value: " + value)
  393. }
  394. // Use old authentication mode (pre MySQL 4.1)
  395. case "allowOldPasswords":
  396. var isBool bool
  397. cfg.AllowOldPasswords, isBool = readBool(value)
  398. if !isBool {
  399. return errors.New("invalid bool value: " + value)
  400. }
  401. // Switch "rowsAffected" mode
  402. case "clientFoundRows":
  403. var isBool bool
  404. cfg.ClientFoundRows, isBool = readBool(value)
  405. if !isBool {
  406. return errors.New("invalid bool value: " + value)
  407. }
  408. // Collation
  409. case "collation":
  410. cfg.Collation = value
  411. break
  412. case "columnsWithAlias":
  413. var isBool bool
  414. cfg.ColumnsWithAlias, isBool = readBool(value)
  415. if !isBool {
  416. return errors.New("invalid bool value: " + value)
  417. }
  418. // Compression
  419. case "compress":
  420. return errors.New("compression not implemented yet")
  421. // Enable client side placeholder substitution
  422. case "interpolateParams":
  423. var isBool bool
  424. cfg.InterpolateParams, isBool = readBool(value)
  425. if !isBool {
  426. return errors.New("invalid bool value: " + value)
  427. }
  428. // Time Location
  429. case "loc":
  430. if value, err = url.QueryUnescape(value); err != nil {
  431. return
  432. }
  433. cfg.Loc, err = time.LoadLocation(value)
  434. if err != nil {
  435. return
  436. }
  437. // multiple statements in one query
  438. case "multiStatements":
  439. var isBool bool
  440. cfg.MultiStatements, isBool = readBool(value)
  441. if !isBool {
  442. return errors.New("invalid bool value: " + value)
  443. }
  444. // time.Time parsing
  445. case "parseTime":
  446. var isBool bool
  447. cfg.ParseTime, isBool = readBool(value)
  448. if !isBool {
  449. return errors.New("invalid bool value: " + value)
  450. }
  451. // I/O read Timeout
  452. case "readTimeout":
  453. cfg.ReadTimeout, err = time.ParseDuration(value)
  454. if err != nil {
  455. return
  456. }
  457. // Reject read-only connections
  458. case "rejectReadOnly":
  459. var isBool bool
  460. cfg.RejectReadOnly, isBool = readBool(value)
  461. if !isBool {
  462. return errors.New("invalid bool value: " + value)
  463. }
  464. // Server public key
  465. case "serverPubKey":
  466. name, err := url.QueryUnescape(value)
  467. if err != nil {
  468. return fmt.Errorf("invalid value for server pub key name: %v", err)
  469. }
  470. if pubKey := getServerPubKey(name); pubKey != nil {
  471. cfg.ServerPubKey = name
  472. cfg.pubKey = pubKey
  473. } else {
  474. return errors.New("invalid value / unknown server pub key name: " + name)
  475. }
  476. // Strict mode
  477. case "strict":
  478. panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
  479. // Dial Timeout
  480. case "timeout":
  481. cfg.Timeout, err = time.ParseDuration(value)
  482. if err != nil {
  483. return
  484. }
  485. // TLS-Encryption
  486. case "tls":
  487. boolValue, isBool := readBool(value)
  488. if isBool {
  489. if boolValue {
  490. cfg.TLSConfig = "true"
  491. cfg.tls = &tls.Config{}
  492. } else {
  493. cfg.TLSConfig = "false"
  494. }
  495. } else if vl := strings.ToLower(value); vl == "skip-verify" {
  496. cfg.TLSConfig = vl
  497. cfg.tls = &tls.Config{InsecureSkipVerify: true}
  498. } else {
  499. name, err := url.QueryUnescape(value)
  500. if err != nil {
  501. return fmt.Errorf("invalid value for TLS config name: %v", err)
  502. }
  503. if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
  504. cfg.TLSConfig = name
  505. cfg.tls = tlsConfig
  506. } else {
  507. return errors.New("invalid value / unknown config name: " + name)
  508. }
  509. }
  510. // I/O write Timeout
  511. case "writeTimeout":
  512. cfg.WriteTimeout, err = time.ParseDuration(value)
  513. if err != nil {
  514. return
  515. }
  516. case "maxAllowedPacket":
  517. cfg.MaxAllowedPacket, err = strconv.Atoi(value)
  518. if err != nil {
  519. return
  520. }
  521. default:
  522. // lazy init
  523. if cfg.Params == nil {
  524. cfg.Params = make(map[string]string)
  525. }
  526. if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
  527. return
  528. }
  529. }
  530. }
  531. return
  532. }
  533. func ensureHavePort(addr string) string {
  534. if _, _, err := net.SplitHostPort(addr); err != nil {
  535. return net.JoinHostPort(addr, "3306")
  536. }
  537. return addr
  538. }
上海开阖软件有限公司 沪ICP备12045867号-1