本站源代码
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

915 lines
22KB

  1. package mssql
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "io/ioutil"
  11. "net"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "unicode/utf16"
  16. "unicode/utf8"
  17. )
  18. func parseInstances(msg []byte) map[string]map[string]string {
  19. results := map[string]map[string]string{}
  20. if len(msg) > 3 && msg[0] == 5 {
  21. out_s := string(msg[3:])
  22. tokens := strings.Split(out_s, ";")
  23. instdict := map[string]string{}
  24. got_name := false
  25. var name string
  26. for _, token := range tokens {
  27. if got_name {
  28. instdict[name] = token
  29. got_name = false
  30. } else {
  31. name = token
  32. if len(name) == 0 {
  33. if len(instdict) == 0 {
  34. break
  35. }
  36. results[strings.ToUpper(instdict["InstanceName"])] = instdict
  37. instdict = map[string]string{}
  38. continue
  39. }
  40. got_name = true
  41. }
  42. }
  43. }
  44. return results
  45. }
  46. func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
  47. conn, err := d.DialContext(ctx, "udp", address+":1434")
  48. if err != nil {
  49. return nil, err
  50. }
  51. defer conn.Close()
  52. deadline, _ := ctx.Deadline()
  53. conn.SetDeadline(deadline)
  54. _, err = conn.Write([]byte{3})
  55. if err != nil {
  56. return nil, err
  57. }
  58. var resp = make([]byte, 16*1024-1)
  59. read, err := conn.Read(resp)
  60. if err != nil {
  61. return nil, err
  62. }
  63. return parseInstances(resp[:read]), nil
  64. }
// tds versions
//
// TDS protocol version identifiers. connect sends verTDS74 as the LOGIN7
// TDSVersion. The remaining values are presumably used elsewhere in the
// package to interpret the version reported by the server (not visible here).
const (
	verTDS70     = 0x70000000
	verTDS71     = 0x71000000
	verTDS71rev1 = 0x71000001
	verTDS72     = 0x72090002
	verTDS73A    = 0x730A0003
	verTDS73     = verTDS73A // plain "7.3" is an alias for the 7.3A revision
	verTDS73B    = 0x730B0003
	verTDS74     = 0x74000004
)
// packet types
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
//
// Only packSQLBatch is declared with type packetType. Every other constant in
// this block has an explicit value, so Go does not carry the type over: they
// are untyped integer constants that convert implicitly wherever a packetType
// is expected.
const (
	packSQLBatch   packetType = 1
	packRPCRequest            = 3
	packReply                 = 4 // server response packets, e.g. the PRELOGIN response
	// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
	// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
	packAttention   = 6
	packBulkLoadBCP = 7
	packTransMgrReq = 14
	packNormal      = 15
	packLogin7      = 16
	packSSPIMessage = 17
	packPrelogin    = 18
)
  92. // prelogin fields
  93. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  94. const (
  95. preloginVERSION = 0
  96. preloginENCRYPTION = 1
  97. preloginINSTOPT = 2
  98. preloginTHREADID = 3
  99. preloginMARS = 4
  100. preloginTRACEID = 5
  101. preloginTERMINATOR = 0xff
  102. )
  103. const (
  104. encryptOff = 0 // Encryption is available but off.
  105. encryptOn = 1 // Encryption is available and on.
  106. encryptNotSup = 2 // Encryption is not available.
  107. encryptReq = 3 // Encryption is required.
  108. )
// tdsSession is the per-connection state of a TDS session. It holds the
// packet buffer used for all reads and writes, values reported by the server
// during login, and logging configuration.
type tdsSession struct {
	buf          *tdsBuffer
	loginAck     loginAckStruct // LOGINACK token received by connect
	database     string         // presumably the current database as reported by the server — TODO confirm
	partner      string         // presumably the failover partner reported by the server — TODO confirm
	columns      []columnStruct // presumably metadata of the current result set
	tranid       uint64         // presumably the current transaction descriptor (cf. transDescrHdr)
	logFlags     uint64         // copied from connectParams.logFlags; presumably a mask of log* constants
	log          optionalLogger
	routedServer string // non-empty when the server redirected the login; connect then reconnects here
	routedPort   uint16 // port of routedServer
}
  121. const (
  122. logErrors = 1
  123. logMessages = 2
  124. logRows = 4
  125. logSQL = 8
  126. logParams = 16
  127. logTransaction = 32
  128. logDebug = 64
  129. )
// columnStruct describes one column of a result set (see tdsSession.columns).
// NOTE(review): presumably filled from the server's column-metadata token;
// confirm against the parser that populates it.
type columnStruct struct {
	UserType uint32
	Flags    uint16
	ColName  string
	ti       typeInfo // type information used to decode the column's values
}
  136. type keySlice []uint8
  137. func (p keySlice) Len() int { return len(p) }
  138. func (p keySlice) Less(i, j int) bool { return p[i] < p[j] }
  139. func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
  140. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  141. func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
  142. var err error
  143. w.BeginPacket(packPrelogin, false)
  144. offset := uint16(5*len(fields) + 1)
  145. keys := make(keySlice, 0, len(fields))
  146. for k, _ := range fields {
  147. keys = append(keys, k)
  148. }
  149. sort.Sort(keys)
  150. // writing header
  151. for _, k := range keys {
  152. err = w.WriteByte(k)
  153. if err != nil {
  154. return err
  155. }
  156. err = binary.Write(w, binary.BigEndian, offset)
  157. if err != nil {
  158. return err
  159. }
  160. v := fields[k]
  161. size := uint16(len(v))
  162. err = binary.Write(w, binary.BigEndian, size)
  163. if err != nil {
  164. return err
  165. }
  166. offset += size
  167. }
  168. err = w.WriteByte(preloginTERMINATOR)
  169. if err != nil {
  170. return err
  171. }
  172. // writing values
  173. for _, k := range keys {
  174. v := fields[k]
  175. written, err := w.Write(v)
  176. if err != nil {
  177. return err
  178. }
  179. if written != len(v) {
  180. return errors.New("Write method didn't write the whole value")
  181. }
  182. }
  183. return w.FinishPacket()
  184. }
  185. func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
  186. packet_type, err := r.BeginRead()
  187. if err != nil {
  188. return nil, err
  189. }
  190. struct_buf, err := ioutil.ReadAll(r)
  191. if err != nil {
  192. return nil, err
  193. }
  194. if packet_type != 4 {
  195. return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
  196. }
  197. offset := 0
  198. results := map[uint8][]byte{}
  199. for true {
  200. rec_type := struct_buf[offset]
  201. if rec_type == preloginTERMINATOR {
  202. break
  203. }
  204. rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
  205. rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
  206. value := struct_buf[rec_offset : rec_offset+rec_len]
  207. results[rec_type] = value
  208. offset += 5
  209. }
  210. return results, nil
  211. }
  212. // OptionFlags2
  213. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  214. const (
  215. fLanguageFatal = 1
  216. fODBC = 2
  217. fTransBoundary = 4
  218. fCacheConnect = 8
  219. fIntSecurity = 0x80
  220. )
  221. // TypeFlags
  222. const (
  223. // 4 bits for fSQLType
  224. // 1 bit for fOLEDB
  225. fReadOnlyIntent = 32
  226. )
// login holds the logical contents of a LOGIN7 message. sendLogin encodes the
// string fields as UTF-16LE and obfuscates Password with manglePassword. It
// then derives the offsets and lengths of loginHeader from the encoded data.
type login struct {
	TDSVersion     uint32
	PacketSize     uint32
	ClientProgVer  uint32
	ClientPID      uint32
	ConnectionID   uint32
	OptionFlags1   uint8
	OptionFlags2   uint8 // see the f* OptionFlags2 constants
	TypeFlags      uint8 // see fReadOnlyIntent
	OptionFlags3   uint8
	ClientTimeZone int32
	ClientLCID     uint32
	HostName       string
	UserName       string
	Password       string // plain text here; mangled only when sent
	AppName        string
	ServerName     string
	CtlIntName     string
	Language       string
	Database       string
	ClientID       [6]byte
	SSPI           []byte // initial SSPI token; connect sets it from auth.InitialBytes
	AtchDBFile     string
	ChangePassword string
}
// loginHeader is the fixed-size part of a LOGIN7 message. sendLogin writes it
// with binary.Write in little-endian order, so the field order and sizes ARE
// the wire layout (94 bytes in total); do not reorder, resize or insert
// fields.
//
// Each *Offset/*Length pair locates a variable-length field that follows the
// header. Offsets are measured from the start of the message. String lengths
// count characters and SSPILength counts bytes.
type loginHeader struct {
	Length               uint32 // total message length, header included
	TDSVersion           uint32
	PacketSize           uint32
	ClientProgVer        uint32
	ClientPID            uint32
	ConnectionID         uint32
	OptionFlags1         uint8
	OptionFlags2         uint8
	TypeFlags            uint8
	OptionFlags3         uint8
	ClientTimeZone       int32
	ClientLCID           uint32
	HostNameOffset       uint16
	HostNameLength       uint16
	UserNameOffset       uint16
	UserNameLength       uint16
	PasswordOffset       uint16
	PasswordLength       uint16
	AppNameOffset        uint16
	AppNameLength        uint16
	ServerNameOffset     uint16
	ServerNameLength     uint16
	ExtensionOffset      uint16 // left zero by sendLogin
	ExtensionLenght      uint16 // (sic) left zero by sendLogin
	CtlIntNameOffset     uint16
	CtlIntNameLength     uint16
	LanguageOffset       uint16
	LanguageLength       uint16
	DatabaseOffset       uint16
	DatabaseLength       uint16
	ClientID             [6]byte
	SSPIOffset           uint16
	SSPILength           uint16
	AtchDBFileOffset     uint16
	AtchDBFileLength     uint16
	ChangePasswordOffset uint16
	ChangePasswordLength uint16
	SSPILongLength       uint32 // left zero by sendLogin
}
  292. // convert Go string to UTF-16 encoded []byte (littleEndian)
  293. // done manually rather than using bytes and binary packages
  294. // for performance reasons
  295. func str2ucs2(s string) []byte {
  296. res := utf16.Encode([]rune(s))
  297. ucs2 := make([]byte, 2*len(res))
  298. for i := 0; i < len(res); i++ {
  299. ucs2[2*i] = byte(res[i])
  300. ucs2[2*i+1] = byte(res[i] >> 8)
  301. }
  302. return ucs2
  303. }
  304. func ucs22str(s []byte) (string, error) {
  305. if len(s)%2 != 0 {
  306. return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
  307. }
  308. buf := make([]uint16, len(s)/2)
  309. for i := 0; i < len(s); i += 2 {
  310. buf[i/2] = binary.LittleEndian.Uint16(s[i:])
  311. }
  312. return string(utf16.Decode(buf)), nil
  313. }
  314. func manglePassword(password string) []byte {
  315. var ucs2password []byte = str2ucs2(password)
  316. for i, ch := range ucs2password {
  317. ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
  318. }
  319. return ucs2password
  320. }
  321. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  322. func sendLogin(w *tdsBuffer, login login) error {
  323. w.BeginPacket(packLogin7, false)
  324. hostname := str2ucs2(login.HostName)
  325. username := str2ucs2(login.UserName)
  326. password := manglePassword(login.Password)
  327. appname := str2ucs2(login.AppName)
  328. servername := str2ucs2(login.ServerName)
  329. ctlintname := str2ucs2(login.CtlIntName)
  330. language := str2ucs2(login.Language)
  331. database := str2ucs2(login.Database)
  332. atchdbfile := str2ucs2(login.AtchDBFile)
  333. changepassword := str2ucs2(login.ChangePassword)
  334. hdr := loginHeader{
  335. TDSVersion: login.TDSVersion,
  336. PacketSize: login.PacketSize,
  337. ClientProgVer: login.ClientProgVer,
  338. ClientPID: login.ClientPID,
  339. ConnectionID: login.ConnectionID,
  340. OptionFlags1: login.OptionFlags1,
  341. OptionFlags2: login.OptionFlags2,
  342. TypeFlags: login.TypeFlags,
  343. OptionFlags3: login.OptionFlags3,
  344. ClientTimeZone: login.ClientTimeZone,
  345. ClientLCID: login.ClientLCID,
  346. HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
  347. UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
  348. PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
  349. AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
  350. ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
  351. CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
  352. LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
  353. DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
  354. ClientID: login.ClientID,
  355. SSPILength: uint16(len(login.SSPI)),
  356. AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
  357. ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
  358. }
  359. offset := uint16(binary.Size(hdr))
  360. hdr.HostNameOffset = offset
  361. offset += uint16(len(hostname))
  362. hdr.UserNameOffset = offset
  363. offset += uint16(len(username))
  364. hdr.PasswordOffset = offset
  365. offset += uint16(len(password))
  366. hdr.AppNameOffset = offset
  367. offset += uint16(len(appname))
  368. hdr.ServerNameOffset = offset
  369. offset += uint16(len(servername))
  370. hdr.CtlIntNameOffset = offset
  371. offset += uint16(len(ctlintname))
  372. hdr.LanguageOffset = offset
  373. offset += uint16(len(language))
  374. hdr.DatabaseOffset = offset
  375. offset += uint16(len(database))
  376. hdr.SSPIOffset = offset
  377. offset += uint16(len(login.SSPI))
  378. hdr.AtchDBFileOffset = offset
  379. offset += uint16(len(atchdbfile))
  380. hdr.ChangePasswordOffset = offset
  381. offset += uint16(len(changepassword))
  382. hdr.Length = uint32(offset)
  383. var err error
  384. err = binary.Write(w, binary.LittleEndian, &hdr)
  385. if err != nil {
  386. return err
  387. }
  388. _, err = w.Write(hostname)
  389. if err != nil {
  390. return err
  391. }
  392. _, err = w.Write(username)
  393. if err != nil {
  394. return err
  395. }
  396. _, err = w.Write(password)
  397. if err != nil {
  398. return err
  399. }
  400. _, err = w.Write(appname)
  401. if err != nil {
  402. return err
  403. }
  404. _, err = w.Write(servername)
  405. if err != nil {
  406. return err
  407. }
  408. _, err = w.Write(ctlintname)
  409. if err != nil {
  410. return err
  411. }
  412. _, err = w.Write(language)
  413. if err != nil {
  414. return err
  415. }
  416. _, err = w.Write(database)
  417. if err != nil {
  418. return err
  419. }
  420. _, err = w.Write(login.SSPI)
  421. if err != nil {
  422. return err
  423. }
  424. _, err = w.Write(atchdbfile)
  425. if err != nil {
  426. return err
  427. }
  428. _, err = w.Write(changepassword)
  429. if err != nil {
  430. return err
  431. }
  432. return w.FinishPacket()
  433. }
  434. func readUcs2(r io.Reader, numchars int) (res string, err error) {
  435. buf := make([]byte, numchars*2)
  436. _, err = io.ReadFull(r, buf)
  437. if err != nil {
  438. return "", err
  439. }
  440. return ucs22str(buf)
  441. }
  442. func readUsVarChar(r io.Reader) (res string, err error) {
  443. numchars, err := readUshort(r)
  444. if err != nil {
  445. return
  446. }
  447. return readUcs2(r, int(numchars))
  448. }
  449. func writeUsVarChar(w io.Writer, s string) (err error) {
  450. buf := str2ucs2(s)
  451. var numchars int = len(buf) / 2
  452. if numchars > 0xffff {
  453. panic("invalid size for US_VARCHAR")
  454. }
  455. err = binary.Write(w, binary.LittleEndian, uint16(numchars))
  456. if err != nil {
  457. return
  458. }
  459. _, err = w.Write(buf)
  460. return
  461. }
  462. func readBVarChar(r io.Reader) (res string, err error) {
  463. numchars, err := readByte(r)
  464. if err != nil {
  465. return "", err
  466. }
  467. // A zero length could be returned, return an empty string
  468. if numchars == 0 {
  469. return "", nil
  470. }
  471. return readUcs2(r, int(numchars))
  472. }
  473. func writeBVarChar(w io.Writer, s string) (err error) {
  474. buf := str2ucs2(s)
  475. var numchars int = len(buf) / 2
  476. if numchars > 0xff {
  477. panic("invalid size for B_VARCHAR")
  478. }
  479. err = binary.Write(w, binary.LittleEndian, uint8(numchars))
  480. if err != nil {
  481. return
  482. }
  483. _, err = w.Write(buf)
  484. return
  485. }
  486. func readBVarByte(r io.Reader) (res []byte, err error) {
  487. length, err := readByte(r)
  488. if err != nil {
  489. return
  490. }
  491. res = make([]byte, length)
  492. _, err = io.ReadFull(r, res)
  493. return
  494. }
  495. func readUshort(r io.Reader) (res uint16, err error) {
  496. err = binary.Read(r, binary.LittleEndian, &res)
  497. return
  498. }
  499. func readByte(r io.Reader) (res byte, err error) {
  500. var b [1]byte
  501. _, err = r.Read(b[:])
  502. res = b[0]
  503. return
  504. }
// Packet Data Stream Headers
// http://msdn.microsoft.com/en-us/library/dd304953.aspx
//
// headerStruct is one entry of the header section that writeAllHeaders emits
// before a request payload.
type headerStruct struct {
	hdrtype uint16 // header type; presumably one of the dataStmHdr* constants
	data    []byte // already-serialized header body (e.g. from a pack method)
}
  511. const (
  512. dataStmHdrQueryNotif = 1 // query notifications
  513. dataStmHdrTransDescr = 2 // MARS transaction descriptor (required)
  514. dataStmHdrTraceActivity = 3
  515. )
  516. // Query Notifications Header
  517. // http://msdn.microsoft.com/en-us/library/dd304949.aspx
  518. type queryNotifHdr struct {
  519. notifyId string
  520. ssbDeployment string
  521. notifyTimeout uint32
  522. }
  523. func (hdr queryNotifHdr) pack() (res []byte) {
  524. notifyId := str2ucs2(hdr.notifyId)
  525. ssbDeployment := str2ucs2(hdr.ssbDeployment)
  526. res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
  527. b := res
  528. binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
  529. b = b[2:]
  530. copy(b, notifyId)
  531. b = b[len(notifyId):]
  532. binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
  533. b = b[2:]
  534. copy(b, ssbDeployment)
  535. b = b[len(ssbDeployment):]
  536. binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
  537. return res
  538. }
  539. // MARS Transaction Descriptor Header
  540. // http://msdn.microsoft.com/en-us/library/dd340515.aspx
  541. type transDescrHdr struct {
  542. transDescr uint64 // transaction descriptor returned from ENVCHANGE
  543. outstandingReqCnt uint32 // outstanding request count
  544. }
  545. func (hdr transDescrHdr) pack() (res []byte) {
  546. res = make([]byte, 8+4)
  547. binary.LittleEndian.PutUint64(res, hdr.transDescr)
  548. binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt)
  549. return res
  550. }
  551. func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
  552. // Calculating total length.
  553. var totallen uint32 = 4
  554. for _, hdr := range headers {
  555. totallen += 4 + 2 + uint32(len(hdr.data))
  556. }
  557. // writing
  558. err = binary.Write(w, binary.LittleEndian, totallen)
  559. if err != nil {
  560. return err
  561. }
  562. for _, hdr := range headers {
  563. var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
  564. err = binary.Write(w, binary.LittleEndian, headerlen)
  565. if err != nil {
  566. return err
  567. }
  568. err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
  569. if err != nil {
  570. return err
  571. }
  572. _, err = w.Write(hdr.data)
  573. if err != nil {
  574. return err
  575. }
  576. }
  577. return nil
  578. }
  579. func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
  580. buf.BeginPacket(packSQLBatch, resetSession)
  581. if err = writeAllHeaders(buf, headers); err != nil {
  582. return
  583. }
  584. _, err = buf.Write(str2ucs2(sqltext))
  585. if err != nil {
  586. return
  587. }
  588. return buf.FinishPacket()
  589. }
  590. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  591. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  592. func sendAttention(buf *tdsBuffer) error {
  593. buf.BeginPacket(packAttention, false)
  594. return buf.FinishPacket()
  595. }
// auth is an integrated-authentication (SSPI) provider used during login.
// connect calls InitialBytes to obtain the SSPI token for LOGIN7. It calls
// NextBytes with each SSPI token the server sends back; a non-empty result is
// sent to the server in an SSPI message packet. Free is deferred until connect
// returns.
type auth interface {
	InitialBytes() ([]byte, error)
	NextBytes([]byte) ([]byte, error)
	Free()
}
  601. // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
  602. // list of IP addresses. So if there is more than one, try them all and
  603. // use the first one that allows a connection.
  604. func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
  605. var ips []net.IP
  606. ips, err = net.LookupIP(p.host)
  607. if err != nil {
  608. ip := net.ParseIP(p.host)
  609. if ip == nil {
  610. return nil, err
  611. }
  612. ips = []net.IP{ip}
  613. }
  614. if len(ips) == 1 {
  615. d := c.getDialer(&p)
  616. addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
  617. conn, err = d.DialContext(ctx, "tcp", addr)
  618. } else {
  619. //Try Dials in parallel to avoid waiting for timeouts.
  620. connChan := make(chan net.Conn, len(ips))
  621. errChan := make(chan error, len(ips))
  622. portStr := strconv.Itoa(int(p.port))
  623. for _, ip := range ips {
  624. go func(ip net.IP) {
  625. d := c.getDialer(&p)
  626. addr := net.JoinHostPort(ip.String(), portStr)
  627. conn, err := d.DialContext(ctx, "tcp", addr)
  628. if err == nil {
  629. connChan <- conn
  630. } else {
  631. errChan <- err
  632. }
  633. }(ip)
  634. }
  635. // Wait for either the *first* successful connection, or all the errors
  636. wait_loop:
  637. for i, _ := range ips {
  638. select {
  639. case conn = <-connChan:
  640. // Got a connection to use, close any others
  641. go func(n int) {
  642. for i := 0; i < n; i++ {
  643. select {
  644. case conn := <-connChan:
  645. conn.Close()
  646. case <-errChan:
  647. }
  648. }
  649. }(len(ips) - i - 1)
  650. // Remove any earlier errors we may have collected
  651. err = nil
  652. break wait_loop
  653. case err = <-errChan:
  654. }
  655. }
  656. }
  657. // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
  658. if conn == nil {
  659. f := "Unable to open tcp connection with host '%v:%v': %v"
  660. return nil, fmt.Errorf(f, p.host, p.port, err.Error())
  661. }
  662. return conn, err
  663. }
  664. func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
  665. dialCtx := ctx
  666. if p.dial_timeout > 0 {
  667. var cancel func()
  668. dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
  669. defer cancel()
  670. }
  671. // if instance is specified use instance resolution service
  672. if p.instance != "" {
  673. p.instance = strings.ToUpper(p.instance)
  674. d := c.getDialer(&p)
  675. instances, err := getInstances(dialCtx, d, p.host)
  676. if err != nil {
  677. f := "Unable to get instances from Sql Server Browser on host %v: %v"
  678. return nil, fmt.Errorf(f, p.host, err.Error())
  679. }
  680. strport, ok := instances[p.instance]["tcp"]
  681. if !ok {
  682. f := "No instance matching '%v' returned from host '%v'"
  683. return nil, fmt.Errorf(f, p.instance, p.host)
  684. }
  685. p.port, err = strconv.ParseUint(strport, 0, 16)
  686. if err != nil {
  687. f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
  688. return nil, fmt.Errorf(f, strport, err.Error())
  689. }
  690. }
  691. initiate_connection:
  692. conn, err := dialConnection(dialCtx, c, p)
  693. if err != nil {
  694. return nil, err
  695. }
  696. toconn := newTimeoutConn(conn, p.conn_timeout)
  697. outbuf := newTdsBuffer(p.packetSize, toconn)
  698. sess := tdsSession{
  699. buf: outbuf,
  700. log: log,
  701. logFlags: p.logFlags,
  702. }
  703. instance_buf := []byte(p.instance)
  704. instance_buf = append(instance_buf, 0) // zero terminate instance name
  705. var encrypt byte
  706. if p.disableEncryption {
  707. encrypt = encryptNotSup
  708. } else if p.encrypt {
  709. encrypt = encryptOn
  710. } else {
  711. encrypt = encryptOff
  712. }
  713. fields := map[uint8][]byte{
  714. preloginVERSION: {0, 0, 0, 0, 0, 0},
  715. preloginENCRYPTION: {encrypt},
  716. preloginINSTOPT: instance_buf,
  717. preloginTHREADID: {0, 0, 0, 0},
  718. preloginMARS: {0}, // MARS disabled
  719. }
  720. err = writePrelogin(outbuf, fields)
  721. if err != nil {
  722. return nil, err
  723. }
  724. fields, err = readPrelogin(outbuf)
  725. if err != nil {
  726. return nil, err
  727. }
  728. encryptBytes, ok := fields[preloginENCRYPTION]
  729. if !ok {
  730. return nil, fmt.Errorf("Encrypt negotiation failed")
  731. }
  732. encrypt = encryptBytes[0]
  733. if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
  734. return nil, fmt.Errorf("Server does not support encryption")
  735. }
  736. if encrypt != encryptNotSup {
  737. var config tls.Config
  738. if p.certificate != "" {
  739. pem, err := ioutil.ReadFile(p.certificate)
  740. if err != nil {
  741. return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
  742. }
  743. certs := x509.NewCertPool()
  744. certs.AppendCertsFromPEM(pem)
  745. config.RootCAs = certs
  746. }
  747. if p.trustServerCertificate {
  748. config.InsecureSkipVerify = true
  749. }
  750. config.ServerName = p.hostInCertificate
  751. // fix for https://github.com/denisenkom/go-mssqldb/issues/166
  752. // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
  753. // while SQL Server seems to expect one TCP segment per encrypted TDS package.
  754. // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
  755. config.DynamicRecordSizingDisabled = true
  756. // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
  757. handshakeConn := tlsHandshakeConn{buf: outbuf}
  758. passthrough := passthroughConn{c: &handshakeConn}
  759. tlsConn := tls.Client(&passthrough, &config)
  760. err = tlsConn.Handshake()
  761. passthrough.c = toconn
  762. outbuf.transport = tlsConn
  763. if err != nil {
  764. return nil, fmt.Errorf("TLS Handshake failed: %v", err)
  765. }
  766. if encrypt == encryptOff {
  767. outbuf.afterFirst = func() {
  768. outbuf.transport = toconn
  769. }
  770. }
  771. }
  772. login := login{
  773. TDSVersion: verTDS74,
  774. PacketSize: uint32(outbuf.PackageSize()),
  775. Database: p.database,
  776. OptionFlags2: fODBC, // to get unlimited TEXTSIZE
  777. HostName: p.workstation,
  778. ServerName: p.host,
  779. AppName: p.appname,
  780. TypeFlags: p.typeFlags,
  781. }
  782. auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
  783. if auth_ok {
  784. login.SSPI, err = auth.InitialBytes()
  785. if err != nil {
  786. return nil, err
  787. }
  788. login.OptionFlags2 |= fIntSecurity
  789. defer auth.Free()
  790. } else {
  791. login.UserName = p.user
  792. login.Password = p.password
  793. }
  794. err = sendLogin(outbuf, login)
  795. if err != nil {
  796. return nil, err
  797. }
  798. // processing login response
  799. success := false
  800. for {
  801. tokchan := make(chan tokenStruct, 5)
  802. go processResponse(context.Background(), &sess, tokchan, nil)
  803. for tok := range tokchan {
  804. switch token := tok.(type) {
  805. case sspiMsg:
  806. sspi_msg, err := auth.NextBytes(token)
  807. if err != nil {
  808. return nil, err
  809. }
  810. if sspi_msg != nil && len(sspi_msg) > 0 {
  811. outbuf.BeginPacket(packSSPIMessage, false)
  812. _, err = outbuf.Write(sspi_msg)
  813. if err != nil {
  814. return nil, err
  815. }
  816. err = outbuf.FinishPacket()
  817. if err != nil {
  818. return nil, err
  819. }
  820. sspi_msg = nil
  821. }
  822. case loginAckStruct:
  823. success = true
  824. sess.loginAck = token
  825. case error:
  826. return nil, fmt.Errorf("Login error: %s", token.Error())
  827. case doneStruct:
  828. if token.isError() {
  829. return nil, fmt.Errorf("Login error: %s", token.getError())
  830. }
  831. goto loginEnd
  832. }
  833. }
  834. }
  835. loginEnd:
  836. if !success {
  837. return nil, fmt.Errorf("Login failed")
  838. }
  839. if sess.routedServer != "" {
  840. toconn.Close()
  841. p.host = sess.routedServer
  842. p.port = uint64(sess.routedPort)
  843. if !p.hostInCertificateProvided {
  844. p.hostInCertificate = sess.routedServer
  845. }
  846. goto initiate_connection
  847. }
  848. return &sess, nil
  849. }
上海开阖软件有限公司 沪ICP备12045867号-1