本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

454 lines
11KB

  1. package mssql
  2. import (
  3. "fmt"
  4. "net"
  5. "net/url"
  6. "os"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "unicode"
  11. )
  12. type connectParams struct {
  13. logFlags uint64
  14. port uint64
  15. host string
  16. instance string
  17. database string
  18. user string
  19. password string
  20. dial_timeout time.Duration
  21. conn_timeout time.Duration
  22. keepAlive time.Duration
  23. encrypt bool
  24. disableEncryption bool
  25. trustServerCertificate bool
  26. certificate string
  27. hostInCertificate string
  28. hostInCertificateProvided bool
  29. serverSPN string
  30. workstation string
  31. appname string
  32. typeFlags uint8
  33. failOverPartner string
  34. failOverPort uint64
  35. packetSize uint16
  36. }
  37. func parseConnectParams(dsn string) (connectParams, error) {
  38. var p connectParams
  39. var params map[string]string
  40. if strings.HasPrefix(dsn, "odbc:") {
  41. parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
  42. if err != nil {
  43. return p, err
  44. }
  45. params = parameters
  46. } else if strings.HasPrefix(dsn, "sqlserver://") {
  47. parameters, err := splitConnectionStringURL(dsn)
  48. if err != nil {
  49. return p, err
  50. }
  51. params = parameters
  52. } else {
  53. params = splitConnectionString(dsn)
  54. }
  55. strlog, ok := params["log"]
  56. if ok {
  57. var err error
  58. p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
  59. if err != nil {
  60. return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
  61. }
  62. }
  63. server := params["server"]
  64. parts := strings.SplitN(server, `\`, 2)
  65. p.host = parts[0]
  66. if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
  67. p.host = "localhost"
  68. }
  69. if len(parts) > 1 {
  70. p.instance = parts[1]
  71. }
  72. p.database = params["database"]
  73. p.user = params["user id"]
  74. p.password = params["password"]
  75. p.port = 1433
  76. strport, ok := params["port"]
  77. if ok {
  78. var err error
  79. p.port, err = strconv.ParseUint(strport, 10, 16)
  80. if err != nil {
  81. f := "Invalid tcp port '%v': %v"
  82. return p, fmt.Errorf(f, strport, err.Error())
  83. }
  84. }
  85. // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
  86. // Default packet size remains at 4096 bytes
  87. p.packetSize = 4096
  88. strpsize, ok := params["packet size"]
  89. if ok {
  90. var err error
  91. psize, err := strconv.ParseUint(strpsize, 0, 16)
  92. if err != nil {
  93. f := "Invalid packet size '%v': %v"
  94. return p, fmt.Errorf(f, strpsize, err.Error())
  95. }
  96. // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
  97. // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
  98. // a higher packet size, the server will respond with an ENVCHANGE request to
  99. // alter the packet size to 16383 bytes.
  100. p.packetSize = uint16(psize)
  101. if p.packetSize < 512 {
  102. p.packetSize = 512
  103. } else if p.packetSize > 32767 {
  104. p.packetSize = 32767
  105. }
  106. }
  107. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  108. //
  109. // Do not set a connection timeout. Use Context to manage such things.
  110. // Default to zero, but still allow it to be set.
  111. if strconntimeout, ok := params["connection timeout"]; ok {
  112. timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
  113. if err != nil {
  114. f := "Invalid connection timeout '%v': %v"
  115. return p, fmt.Errorf(f, strconntimeout, err.Error())
  116. }
  117. p.conn_timeout = time.Duration(timeout) * time.Second
  118. }
  119. p.dial_timeout = 15 * time.Second
  120. if strdialtimeout, ok := params["dial timeout"]; ok {
  121. timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
  122. if err != nil {
  123. f := "Invalid dial timeout '%v': %v"
  124. return p, fmt.Errorf(f, strdialtimeout, err.Error())
  125. }
  126. p.dial_timeout = time.Duration(timeout) * time.Second
  127. }
  128. // default keep alive should be 30 seconds according to spec:
  129. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  130. p.keepAlive = 30 * time.Second
  131. if keepAlive, ok := params["keepalive"]; ok {
  132. timeout, err := strconv.ParseUint(keepAlive, 10, 64)
  133. if err != nil {
  134. f := "Invalid keepAlive value '%s': %s"
  135. return p, fmt.Errorf(f, keepAlive, err.Error())
  136. }
  137. p.keepAlive = time.Duration(timeout) * time.Second
  138. }
  139. encrypt, ok := params["encrypt"]
  140. if ok {
  141. if strings.EqualFold(encrypt, "DISABLE") {
  142. p.disableEncryption = true
  143. } else {
  144. var err error
  145. p.encrypt, err = strconv.ParseBool(encrypt)
  146. if err != nil {
  147. f := "Invalid encrypt '%s': %s"
  148. return p, fmt.Errorf(f, encrypt, err.Error())
  149. }
  150. }
  151. } else {
  152. p.trustServerCertificate = true
  153. }
  154. trust, ok := params["trustservercertificate"]
  155. if ok {
  156. var err error
  157. p.trustServerCertificate, err = strconv.ParseBool(trust)
  158. if err != nil {
  159. f := "Invalid trust server certificate '%s': %s"
  160. return p, fmt.Errorf(f, trust, err.Error())
  161. }
  162. }
  163. p.certificate = params["certificate"]
  164. p.hostInCertificate, ok = params["hostnameincertificate"]
  165. if ok {
  166. p.hostInCertificateProvided = true
  167. } else {
  168. p.hostInCertificate = p.host
  169. p.hostInCertificateProvided = false
  170. }
  171. serverSPN, ok := params["serverspn"]
  172. if ok {
  173. p.serverSPN = serverSPN
  174. } else {
  175. p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
  176. }
  177. workstation, ok := params["workstation id"]
  178. if ok {
  179. p.workstation = workstation
  180. } else {
  181. workstation, err := os.Hostname()
  182. if err == nil {
  183. p.workstation = workstation
  184. }
  185. }
  186. appname, ok := params["app name"]
  187. if !ok {
  188. appname = "go-mssqldb"
  189. }
  190. p.appname = appname
  191. appintent, ok := params["applicationintent"]
  192. if ok {
  193. if appintent == "ReadOnly" {
  194. p.typeFlags |= fReadOnlyIntent
  195. }
  196. }
  197. failOverPartner, ok := params["failoverpartner"]
  198. if ok {
  199. p.failOverPartner = failOverPartner
  200. }
  201. failOverPort, ok := params["failoverport"]
  202. if ok {
  203. var err error
  204. p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
  205. if err != nil {
  206. f := "Invalid tcp port '%v': %v"
  207. return p, fmt.Errorf(f, failOverPort, err.Error())
  208. }
  209. }
  210. return p, nil
  211. }
  212. func splitConnectionString(dsn string) (res map[string]string) {
  213. res = map[string]string{}
  214. parts := strings.Split(dsn, ";")
  215. for _, part := range parts {
  216. if len(part) == 0 {
  217. continue
  218. }
  219. lst := strings.SplitN(part, "=", 2)
  220. name := strings.TrimSpace(strings.ToLower(lst[0]))
  221. if len(name) == 0 {
  222. continue
  223. }
  224. var value string = ""
  225. if len(lst) > 1 {
  226. value = strings.TrimSpace(lst[1])
  227. }
  228. res[name] = value
  229. }
  230. return res
  231. }
  232. // Splits a URL of the form sqlserver://username:password@host/instance?param1=value&param2=value
  233. func splitConnectionStringURL(dsn string) (map[string]string, error) {
  234. res := map[string]string{}
  235. u, err := url.Parse(dsn)
  236. if err != nil {
  237. return res, err
  238. }
  239. if u.Scheme != "sqlserver" {
  240. return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
  241. }
  242. if u.User != nil {
  243. res["user id"] = u.User.Username()
  244. p, exists := u.User.Password()
  245. if exists {
  246. res["password"] = p
  247. }
  248. }
  249. host, port, err := net.SplitHostPort(u.Host)
  250. if err != nil {
  251. host = u.Host
  252. }
  253. if len(u.Path) > 0 {
  254. res["server"] = host + "\\" + u.Path[1:]
  255. } else {
  256. res["server"] = host
  257. }
  258. if len(port) > 0 {
  259. res["port"] = port
  260. }
  261. query := u.Query()
  262. for k, v := range query {
  263. if len(v) > 1 {
  264. return res, fmt.Errorf("key %s provided more than once", k)
  265. }
  266. res[strings.ToLower(k)] = v[0]
  267. }
  268. return res, nil
  269. }
  270. // Splits a URL in the ODBC format
  271. func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
  272. res := map[string]string{}
  273. type parserState int
  274. const (
  275. // Before the start of a key
  276. parserStateBeforeKey parserState = iota
  277. // Inside a key
  278. parserStateKey
  279. // Beginning of a value. May be bare or braced
  280. parserStateBeginValue
  281. // Inside a bare value
  282. parserStateBareValue
  283. // Inside a braced value
  284. parserStateBracedValue
  285. // A closing brace inside a braced value.
  286. // May be the end of the value or an escaped closing brace, depending on the next character
  287. parserStateBracedValueClosingBrace
  288. // After a value. Next character should be a semicolon or whitespace.
  289. parserStateEndValue
  290. )
  291. var state = parserStateBeforeKey
  292. var key string
  293. var value string
  294. for i, c := range dsn {
  295. switch state {
  296. case parserStateBeforeKey:
  297. switch {
  298. case c == '=':
  299. return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
  300. case !unicode.IsSpace(c) && c != ';':
  301. state = parserStateKey
  302. key += string(c)
  303. }
  304. case parserStateKey:
  305. switch c {
  306. case '=':
  307. key = normalizeOdbcKey(key)
  308. state = parserStateBeginValue
  309. case ';':
  310. // Key without value
  311. key = normalizeOdbcKey(key)
  312. res[key] = value
  313. key = ""
  314. value = ""
  315. state = parserStateBeforeKey
  316. default:
  317. key += string(c)
  318. }
  319. case parserStateBeginValue:
  320. switch {
  321. case c == '{':
  322. state = parserStateBracedValue
  323. case c == ';':
  324. // Empty value
  325. res[key] = value
  326. key = ""
  327. state = parserStateBeforeKey
  328. case unicode.IsSpace(c):
  329. // Ignore whitespace
  330. default:
  331. state = parserStateBareValue
  332. value += string(c)
  333. }
  334. case parserStateBareValue:
  335. if c == ';' {
  336. res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
  337. key = ""
  338. value = ""
  339. state = parserStateBeforeKey
  340. } else {
  341. value += string(c)
  342. }
  343. case parserStateBracedValue:
  344. if c == '}' {
  345. state = parserStateBracedValueClosingBrace
  346. } else {
  347. value += string(c)
  348. }
  349. case parserStateBracedValueClosingBrace:
  350. if c == '}' {
  351. // Escaped closing brace
  352. value += string(c)
  353. state = parserStateBracedValue
  354. continue
  355. }
  356. // End of braced value
  357. res[key] = value
  358. key = ""
  359. value = ""
  360. // This character is the first character past the end,
  361. // so it needs to be parsed like the parserStateEndValue state.
  362. state = parserStateEndValue
  363. switch {
  364. case c == ';':
  365. state = parserStateBeforeKey
  366. case unicode.IsSpace(c):
  367. // Ignore whitespace
  368. default:
  369. return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
  370. }
  371. case parserStateEndValue:
  372. switch {
  373. case c == ';':
  374. state = parserStateBeforeKey
  375. case unicode.IsSpace(c):
  376. // Ignore whitespace
  377. default:
  378. return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
  379. }
  380. }
  381. }
  382. switch state {
  383. case parserStateBeforeKey: // Okay
  384. case parserStateKey: // Unfinished key. Treat as key without value.
  385. key = normalizeOdbcKey(key)
  386. res[key] = value
  387. case parserStateBeginValue: // Empty value
  388. res[key] = value
  389. case parserStateBareValue:
  390. res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
  391. case parserStateBracedValue:
  392. return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
  393. case parserStateBracedValueClosingBrace: // End of braced value
  394. res[key] = value
  395. case parserStateEndValue: // Okay
  396. }
  397. return res, nil
  398. }
  399. // Normalizes the given string as an ODBC-format key
  400. func normalizeOdbcKey(s string) string {
  401. return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
  402. }
上海开阖软件有限公司 沪ICP备12045867号-1