本站源代码
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

584 lines
14KB

  1. package mssql
  import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"math"
	"reflect"
	"strconv"
	"strings"
	"time"

	"github.com/denisenkom/go-mssqldb/internal/decimal"
)
  13. type Bulk struct {
  14. // ctx is used only for AddRow and Done methods.
  15. // This could be removed if AddRow and Done accepted
  16. // a ctx field as well, which is available with the
  17. // database/sql call.
  18. ctx context.Context
  19. cn *Conn
  20. metadata []columnStruct
  21. bulkColumns []columnStruct
  22. columnsName []string
  23. tablename string
  24. numRows int
  25. headerSent bool
  26. Options BulkOptions
  27. Debug bool
  28. }
  // BulkOptions selects the hints placed in the WITH (...) clause of the
// INSERT BULK statement. A zero/false/empty field leaves its hint out.
type BulkOptions struct {
	CheckConstraints  bool     // adds CHECK_CONSTRAINTS
	FireTriggers      bool     // adds FIRE_TRIGGERS
	KeepNulls         bool     // adds KEEP_NULLS
	KilobytesPerBatch int      // adds KILOBYTES_PER_BATCH = n when n > 0
	RowsPerBatch      int      // adds ROWS_PER_BATCH = n when n > 0
	Order             []string // adds ORDER(a,b,...) when non-empty
	Tablock           bool     // adds TABLOCK
}
  type (
	// DataValue is the (untyped) value handed to makeParam for one column.
	DataValue interface{}
)
  // sqlDateFormat is the time.Parse layout used for string values bound to
// date columns.
const sqlDateFormat = "2006-01-02"

// sqlTimeFormat is the time.Parse layout used for string values bound to
// datetime, datetime2 and datetimeoffset columns; the Z07:00 element means
// the string must end in "Z" or a numeric zone offset.
const sqlTimeFormat = "2006-01-02 15:04:05.999999999Z07:00"
  43. func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
  44. b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns}
  45. b.Debug = false
  46. return &b
  47. }
  48. func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) {
  49. b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns}
  50. b.Debug = false
  51. return &b
  52. }
  53. func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
  54. //get table columns info
  55. err = b.getMetadata(ctx)
  56. if err != nil {
  57. return err
  58. }
  59. //match the columns
  60. for _, colname := range b.columnsName {
  61. var bulkCol *columnStruct
  62. for _, m := range b.metadata {
  63. if m.ColName == colname {
  64. bulkCol = &m
  65. break
  66. }
  67. }
  68. if bulkCol != nil {
  69. if bulkCol.ti.TypeId == typeUdt {
  70. //send udt as binary
  71. bulkCol.ti.TypeId = typeBigVarBin
  72. }
  73. b.bulkColumns = append(b.bulkColumns, *bulkCol)
  74. b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
  75. } else {
  76. return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
  77. }
  78. }
  79. //create the bulk command
  80. //columns definitions
  81. var col_defs bytes.Buffer
  82. for i, col := range b.bulkColumns {
  83. if i != 0 {
  84. col_defs.WriteString(", ")
  85. }
  86. col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti))
  87. }
  88. //options
  89. var with_opts []string
  90. if b.Options.CheckConstraints {
  91. with_opts = append(with_opts, "CHECK_CONSTRAINTS")
  92. }
  93. if b.Options.FireTriggers {
  94. with_opts = append(with_opts, "FIRE_TRIGGERS")
  95. }
  96. if b.Options.KeepNulls {
  97. with_opts = append(with_opts, "KEEP_NULLS")
  98. }
  99. if b.Options.KilobytesPerBatch > 0 {
  100. with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch))
  101. }
  102. if b.Options.RowsPerBatch > 0 {
  103. with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch))
  104. }
  105. if len(b.Options.Order) > 0 {
  106. with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ",")))
  107. }
  108. if b.Options.Tablock {
  109. with_opts = append(with_opts, "TABLOCK")
  110. }
  111. var with_part string
  112. if len(with_opts) > 0 {
  113. with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ","))
  114. }
  115. query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
  116. stmt, err := b.cn.PrepareContext(ctx, query)
  117. if err != nil {
  118. return fmt.Errorf("Prepare failed: %s", err.Error())
  119. }
  120. b.dlogf(query)
  121. _, err = stmt.(*Stmt).ExecContext(ctx, nil)
  122. if err != nil {
  123. return err
  124. }
  125. b.headerSent = true
  126. var buf = b.cn.sess.buf
  127. buf.BeginPacket(packBulkLoadBCP, false)
  128. // Send the columns metadata.
  129. columnMetadata := b.createColMetadata()
  130. _, err = buf.Write(columnMetadata)
  131. return
  132. }
  133. // AddRow immediately writes the row to the destination table.
  134. // The arguments are the row values in the order they were specified.
  135. func (b *Bulk) AddRow(row []interface{}) (err error) {
  136. if !b.headerSent {
  137. err = b.sendBulkCommand(b.ctx)
  138. if err != nil {
  139. return
  140. }
  141. }
  142. if len(row) != len(b.bulkColumns) {
  143. return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
  144. len(row), len(b.bulkColumns))
  145. }
  146. bytes, err := b.makeRowData(row)
  147. if err != nil {
  148. return
  149. }
  150. _, err = b.cn.sess.buf.Write(bytes)
  151. if err != nil {
  152. return
  153. }
  154. b.numRows = b.numRows + 1
  155. return
  156. }
  157. func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
  158. buf := new(bytes.Buffer)
  159. buf.WriteByte(byte(tokenRow))
  160. var logcol bytes.Buffer
  161. for i, col := range b.bulkColumns {
  162. if b.Debug {
  163. logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i]))
  164. }
  165. param, err := b.makeParam(row[i], col)
  166. if err != nil {
  167. return nil, fmt.Errorf("bulkcopy: %s", err.Error())
  168. }
  169. if col.ti.Writer == nil {
  170. return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
  171. col.ColName, col.ti.TypeId)
  172. }
  173. err = col.ti.Writer(buf, param.ti, param.buffer)
  174. if err != nil {
  175. return nil, fmt.Errorf("bulkcopy: %s", err.Error())
  176. }
  177. }
  178. b.dlogf("row[%d] %s\n", b.numRows, logcol.String())
  179. return buf.Bytes(), nil
  180. }
  181. func (b *Bulk) Done() (rowcount int64, err error) {
  182. if b.headerSent == false {
  183. //no rows had been sent
  184. return 0, nil
  185. }
  186. var buf = b.cn.sess.buf
  187. buf.WriteByte(byte(tokenDone))
  188. binary.Write(buf, binary.LittleEndian, uint16(doneFinal))
  189. binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd
  190. if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
  191. binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0
  192. } else {
  193. binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0
  194. }
  195. buf.FinishPacket()
  196. tokchan := make(chan tokenStruct, 5)
  197. go processResponse(b.ctx, b.cn.sess, tokchan, nil)
  198. var rowCount int64
  199. for token := range tokchan {
  200. switch token := token.(type) {
  201. case doneStruct:
  202. if token.Status&doneCount != 0 {
  203. rowCount = int64(token.RowCount)
  204. }
  205. if token.isError() {
  206. return 0, token.getError()
  207. }
  208. case error:
  209. return 0, b.cn.checkBadConn(token)
  210. }
  211. }
  212. return rowCount, nil
  213. }
  214. func (b *Bulk) createColMetadata() []byte {
  215. buf := new(bytes.Buffer)
  216. buf.WriteByte(byte(tokenColMetadata)) // token
  217. binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
  218. for i, col := range b.bulkColumns {
  219. if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
  220. binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0?
  221. } else {
  222. binary.Write(buf, binary.LittleEndian, uint16(col.UserType))
  223. }
  224. binary.Write(buf, binary.LittleEndian, uint16(col.Flags))
  225. writeTypeInfo(buf, &b.bulkColumns[i].ti)
  226. if col.ti.TypeId == typeNText ||
  227. col.ti.TypeId == typeText ||
  228. col.ti.TypeId == typeImage {
  229. tablename_ucs2 := str2ucs2(b.tablename)
  230. binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2))
  231. buf.Write(tablename_ucs2)
  232. }
  233. colname_ucs2 := str2ucs2(col.ColName)
  234. buf.WriteByte(uint8(len(colname_ucs2) / 2))
  235. buf.Write(colname_ucs2)
  236. }
  237. return buf.Bytes()
  238. }
  239. func (b *Bulk) getMetadata(ctx context.Context) (err error) {
  240. stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON")
  241. if err != nil {
  242. return
  243. }
  244. _, err = stmt.ExecContext(ctx, nil)
  245. if err != nil {
  246. return
  247. }
  248. // Get columns info.
  249. stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
  250. if err != nil {
  251. return
  252. }
  253. rows, err := stmt.QueryContext(ctx, nil)
  254. if err != nil {
  255. return fmt.Errorf("get columns info failed: %v", err)
  256. }
  257. b.metadata = rows.(*Rows).cols
  258. if b.Debug {
  259. for _, col := range b.metadata {
  260. b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
  261. col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
  262. col.Flags, col.ti.Collation.LcidAndFlags)
  263. }
  264. }
  265. return rows.Close()
  266. }
  267. func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) {
  268. res.ti.Size = col.ti.Size
  269. res.ti.TypeId = col.ti.TypeId
  270. if val == nil {
  271. res.ti.Size = 0
  272. return
  273. }
  274. switch col.ti.TypeId {
  275. case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
  276. var intvalue int64
  277. switch val := val.(type) {
  278. case int:
  279. intvalue = int64(val)
  280. case int32:
  281. intvalue = int64(val)
  282. case int64:
  283. intvalue = val
  284. default:
  285. err = fmt.Errorf("mssql: invalid type for int column: %T", val)
  286. return
  287. }
  288. res.buffer = make([]byte, res.ti.Size)
  289. if col.ti.Size == 1 {
  290. res.buffer[0] = byte(intvalue)
  291. } else if col.ti.Size == 2 {
  292. binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
  293. } else if col.ti.Size == 4 {
  294. binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
  295. } else if col.ti.Size == 8 {
  296. binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue))
  297. }
  298. case typeFlt4, typeFlt8, typeFltN:
  299. var floatvalue float64
  300. switch val := val.(type) {
  301. case float32:
  302. floatvalue = float64(val)
  303. case float64:
  304. floatvalue = val
  305. case int:
  306. floatvalue = float64(val)
  307. case int64:
  308. floatvalue = float64(val)
  309. default:
  310. err = fmt.Errorf("mssql: invalid type for float column: %T %s", val, val)
  311. return
  312. }
  313. if col.ti.Size == 4 {
  314. res.buffer = make([]byte, 4)
  315. binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue)))
  316. } else if col.ti.Size == 8 {
  317. res.buffer = make([]byte, 8)
  318. binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue))
  319. }
  320. case typeNVarChar, typeNText, typeNChar:
  321. switch val := val.(type) {
  322. case string:
  323. res.buffer = str2ucs2(val)
  324. case []byte:
  325. res.buffer = val
  326. default:
  327. err = fmt.Errorf("mssql: invalid type for nvarchar column: %T %s", val, val)
  328. return
  329. }
  330. res.ti.Size = len(res.buffer)
  331. case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar:
  332. switch val := val.(type) {
  333. case string:
  334. res.buffer = []byte(val)
  335. case []byte:
  336. res.buffer = val
  337. default:
  338. err = fmt.Errorf("mssql: invalid type for varchar column: %T %s", val, val)
  339. return
  340. }
  341. res.ti.Size = len(res.buffer)
  342. case typeBit, typeBitN:
  343. if reflect.TypeOf(val).Kind() != reflect.Bool {
  344. err = fmt.Errorf("mssql: invalid type for bit column: %T %s", val, val)
  345. return
  346. }
  347. res.ti.TypeId = typeBitN
  348. res.ti.Size = 1
  349. res.buffer = make([]byte, 1)
  350. if val.(bool) {
  351. res.buffer[0] = 1
  352. }
  353. case typeDateTime2N:
  354. switch val := val.(type) {
  355. case time.Time:
  356. res.buffer = encodeDateTime2(val, int(col.ti.Scale))
  357. res.ti.Size = len(res.buffer)
  358. case string:
  359. var t time.Time
  360. if t, err = time.Parse(sqlTimeFormat, val); err != nil {
  361. return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
  362. }
  363. res.buffer = encodeDateTime2(t, int(col.ti.Scale))
  364. res.ti.Size = len(res.buffer)
  365. default:
  366. err = fmt.Errorf("mssql: invalid type for datetime2 column: %T %s", val, val)
  367. return
  368. }
  369. case typeDateTimeOffsetN:
  370. switch val := val.(type) {
  371. case time.Time:
  372. res.buffer = encodeDateTimeOffset(val, int(col.ti.Scale))
  373. res.ti.Size = len(res.buffer)
  374. case string:
  375. var t time.Time
  376. if t, err = time.Parse(sqlTimeFormat, val); err != nil {
  377. return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
  378. }
  379. res.buffer = encodeDateTimeOffset(t, int(col.ti.Scale))
  380. res.ti.Size = len(res.buffer)
  381. default:
  382. err = fmt.Errorf("mssql: invalid type for datetimeoffset column: %T %s", val, val)
  383. return
  384. }
  385. case typeDateN:
  386. switch val := val.(type) {
  387. case time.Time:
  388. res.buffer = encodeDate(val)
  389. res.ti.Size = len(res.buffer)
  390. case string:
  391. var t time.Time
  392. if t, err = time.ParseInLocation(sqlDateFormat, val, time.UTC); err != nil {
  393. return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
  394. }
  395. res.buffer = encodeDate(t)
  396. res.ti.Size = len(res.buffer)
  397. default:
  398. err = fmt.Errorf("mssql: invalid type for date column: %T %s", val, val)
  399. return
  400. }
  401. case typeDateTime, typeDateTimeN, typeDateTim4:
  402. var t time.Time
  403. switch val := val.(type) {
  404. case time.Time:
  405. t = val
  406. case string:
  407. if t, err = time.Parse(sqlTimeFormat, val); err != nil {
  408. return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
  409. }
  410. default:
  411. err = fmt.Errorf("mssql: invalid type for datetime column: %T %s", val, val)
  412. return
  413. }
  414. if col.ti.Size == 4 {
  415. res.buffer = encodeDateTim4(t)
  416. res.ti.Size = len(res.buffer)
  417. } else if col.ti.Size == 8 {
  418. res.buffer = encodeDateTime(t)
  419. res.ti.Size = len(res.buffer)
  420. } else {
  421. err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size)
  422. }
  423. // case typeMoney, typeMoney4, typeMoneyN:
  424. case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
  425. prec := col.ti.Prec
  426. scale := col.ti.Scale
  427. var dec decimal.Decimal
  428. switch v := val.(type) {
  429. case int:
  430. dec = decimal.Int64ToDecimalScale(int64(v), 0)
  431. case int8:
  432. dec = decimal.Int64ToDecimalScale(int64(v), 0)
  433. case int16:
  434. dec = decimal.Int64ToDecimalScale(int64(v), 0)
  435. case int32:
  436. dec = decimal.Int64ToDecimalScale(int64(v), 0)
  437. case int64:
  438. dec = decimal.Int64ToDecimalScale(int64(v), 0)
  439. case float32:
  440. dec, err = decimal.Float64ToDecimalScale(float64(v), scale)
  441. case float64:
  442. dec, err = decimal.Float64ToDecimalScale(float64(v), scale)
  443. case string:
  444. dec, err = decimal.StringToDecimalScale(v, scale)
  445. default:
  446. return res, fmt.Errorf("unknown value for decimal: %T %#v", v, v)
  447. }
  448. if err != nil {
  449. return res, err
  450. }
  451. dec.SetPrec(prec)
  452. var length byte
  453. switch {
  454. case prec <= 9:
  455. length = 4
  456. case prec <= 19:
  457. length = 8
  458. case prec <= 28:
  459. length = 12
  460. default:
  461. length = 16
  462. }
  463. buf := make([]byte, length+1)
  464. // first byte length written by typeInfo.writer
  465. res.ti.Size = int(length) + 1
  466. // second byte sign
  467. if !dec.IsPositive() {
  468. buf[0] = 0
  469. } else {
  470. buf[0] = 1
  471. }
  472. ub := dec.UnscaledBytes()
  473. l := len(ub)
  474. if l > int(length) {
  475. err = fmt.Errorf("decimal out of range: %s", dec)
  476. return res, err
  477. }
  478. // reverse the bytes
  479. for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
  480. buf[i] = ub[j]
  481. }
  482. res.buffer = buf
  483. case typeBigVarBin, typeBigBinary:
  484. switch val := val.(type) {
  485. case []byte:
  486. res.ti.Size = len(val)
  487. res.buffer = val
  488. default:
  489. err = fmt.Errorf("mssql: invalid type for Binary column: %T %s", val, val)
  490. return
  491. }
  492. case typeGuid:
  493. switch val := val.(type) {
  494. case []byte:
  495. res.ti.Size = len(val)
  496. res.buffer = val
  497. default:
  498. err = fmt.Errorf("mssql: invalid type for Guid column: %T %s", val, val)
  499. return
  500. }
  501. default:
  502. err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
  503. }
  504. return
  505. }
  506. func (b *Bulk) dlogf(format string, v ...interface{}) {
  507. if b.Debug {
  508. b.cn.sess.log.Printf(format, v...)
  509. }
  510. }
上海开阖软件有限公司 沪ICP备12045867号-1