本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

94 lines
1.7KB

  1. package mssql
  2. import (
  3. "context"
  4. "database/sql/driver"
  5. "encoding/json"
  6. "errors"
  7. )
  8. type copyin struct {
  9. cn *Conn
  10. bulkcopy *Bulk
  11. closed bool
  12. }
  13. type serializableBulkConfig struct {
  14. TableName string
  15. ColumnsName []string
  16. Options BulkOptions
  17. }
  18. func (d *Driver) OpenConnection(dsn string) (*Conn, error) {
  19. return d.open(context.Background(), dsn)
  20. }
  21. func (c *Conn) prepareCopyIn(ctx context.Context, query string) (_ driver.Stmt, err error) {
  22. config_json := query[11:]
  23. bulkconfig := serializableBulkConfig{}
  24. err = json.Unmarshal([]byte(config_json), &bulkconfig)
  25. if err != nil {
  26. return
  27. }
  28. bulkcopy := c.CreateBulkContext(ctx, bulkconfig.TableName, bulkconfig.ColumnsName)
  29. bulkcopy.Options = bulkconfig.Options
  30. ci := &copyin{
  31. cn: c,
  32. bulkcopy: bulkcopy,
  33. }
  34. return ci, nil
  35. }
  36. func CopyIn(table string, options BulkOptions, columns ...string) string {
  37. bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns}
  38. config_json, err := json.Marshal(bulkconfig)
  39. if err != nil {
  40. panic(err)
  41. }
  42. stmt := "INSERTBULK " + string(config_json)
  43. return stmt
  44. }
  45. func (ci *copyin) NumInput() int {
  46. return -1
  47. }
  48. func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
  49. panic("should never be called")
  50. }
  51. func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
  52. if ci.closed {
  53. return nil, errors.New("copyin query is closed")
  54. }
  55. if len(v) == 0 {
  56. rowCount, err := ci.bulkcopy.Done()
  57. ci.closed = true
  58. return driver.RowsAffected(rowCount), err
  59. }
  60. t := make([]interface{}, len(v))
  61. for i, val := range v {
  62. t[i] = val
  63. }
  64. err = ci.bulkcopy.AddRow(t)
  65. if err != nil {
  66. return
  67. }
  68. return driver.RowsAffected(0), nil
  69. }
  70. func (ci *copyin) Close() (err error) {
  71. return nil
  72. }
上海开阖软件有限公司 沪ICP备12045867号-1