本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

232 lines
6.2KB

  1. // +build go1.9
  2. package mssql
  3. import (
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "strings"
  10. "time"
  11. )
  12. const (
  13. jsonTag = "json"
  14. tvpTag = "tvp"
  15. skipTagValue = "-"
  16. sqlSeparator = "."
  17. )
  18. var (
  19. ErrorEmptyTVPTypeName = errors.New("TypeName must not be empty")
  20. ErrorTypeSlice = errors.New("TVP must be slice type")
  21. ErrorTypeSliceIsEmpty = errors.New("TVP mustn't be null value")
  22. ErrorSkip = errors.New("all fields mustn't skip")
  23. ErrorObjectName = errors.New("wrong tvp name")
  24. ErrorWrongTyping = errors.New("the number of elements in columnStr and tvpFieldIndexes do not align")
  25. )
  26. //TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
  27. type TVP struct {
  28. //TypeName mustn't be default value
  29. TypeName string
  30. //Value must be the slice, mustn't be nil
  31. Value interface{}
  32. }
  33. func (tvp TVP) check() error {
  34. if len(tvp.TypeName) == 0 {
  35. return ErrorEmptyTVPTypeName
  36. }
  37. if !isProc(tvp.TypeName) {
  38. return ErrorEmptyTVPTypeName
  39. }
  40. if sepCount := getCountSQLSeparators(tvp.TypeName); sepCount > 1 {
  41. return ErrorObjectName
  42. }
  43. valueOf := reflect.ValueOf(tvp.Value)
  44. if valueOf.Kind() != reflect.Slice {
  45. return ErrorTypeSlice
  46. }
  47. if valueOf.IsNil() {
  48. return ErrorTypeSliceIsEmpty
  49. }
  50. if reflect.TypeOf(tvp.Value).Elem().Kind() != reflect.Struct {
  51. return ErrorTypeSlice
  52. }
  53. return nil
  54. }
  55. func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int) ([]byte, error) {
  56. if len(columnStr) != len(tvpFieldIndexes) {
  57. return nil, ErrorWrongTyping
  58. }
  59. preparedBuffer := make([]byte, 0, 20+(10*len(columnStr)))
  60. buf := bytes.NewBuffer(preparedBuffer)
  61. err := writeBVarChar(buf, "")
  62. if err != nil {
  63. return nil, err
  64. }
  65. writeBVarChar(buf, schema)
  66. writeBVarChar(buf, name)
  67. binary.Write(buf, binary.LittleEndian, uint16(len(columnStr)))
  68. for i, column := range columnStr {
  69. binary.Write(buf, binary.LittleEndian, uint32(column.UserType))
  70. binary.Write(buf, binary.LittleEndian, uint16(column.Flags))
  71. writeTypeInfo(buf, &columnStr[i].ti)
  72. writeBVarChar(buf, "")
  73. }
  74. // The returned error is always nil
  75. buf.WriteByte(_TVP_END_TOKEN)
  76. conn := new(Conn)
  77. conn.sess = new(tdsSession)
  78. conn.sess.loginAck = loginAckStruct{TDSVersion: verTDS73}
  79. stmt := &Stmt{
  80. c: conn,
  81. }
  82. val := reflect.ValueOf(tvp.Value)
  83. for i := 0; i < val.Len(); i++ {
  84. refStr := reflect.ValueOf(val.Index(i).Interface())
  85. buf.WriteByte(_TVP_ROW_TOKEN)
  86. for columnStrIdx, fieldIdx := range tvpFieldIndexes {
  87. field := refStr.Field(fieldIdx)
  88. tvpVal := field.Interface()
  89. valOf := reflect.ValueOf(tvpVal)
  90. elemKind := field.Kind()
  91. if elemKind == reflect.Ptr && valOf.IsNil() {
  92. switch tvpVal.(type) {
  93. case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64, *int:
  94. binary.Write(buf, binary.LittleEndian, uint8(0))
  95. continue
  96. default:
  97. binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL))
  98. continue
  99. }
  100. }
  101. if elemKind == reflect.Slice && valOf.IsNil() {
  102. binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL))
  103. continue
  104. }
  105. cval, err := convertInputParameter(tvpVal)
  106. if err != nil {
  107. return nil, fmt.Errorf("failed to convert tvp parameter row col: %s", err)
  108. }
  109. param, err := stmt.makeParam(cval)
  110. if err != nil {
  111. return nil, fmt.Errorf("failed to make tvp parameter row col: %s", err)
  112. }
  113. columnStr[columnStrIdx].ti.Writer(buf, param.ti, param.buffer)
  114. }
  115. }
  116. buf.WriteByte(_TVP_END_TOKEN)
  117. return buf.Bytes(), nil
  118. }
  119. func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
  120. val := reflect.ValueOf(tvp.Value)
  121. var firstRow interface{}
  122. if val.Len() != 0 {
  123. firstRow = val.Index(0).Interface()
  124. } else {
  125. firstRow = reflect.New(reflect.TypeOf(tvp.Value).Elem()).Elem().Interface()
  126. }
  127. tvpRow := reflect.TypeOf(firstRow)
  128. columnCount := tvpRow.NumField()
  129. defaultValues := make([]interface{}, 0, columnCount)
  130. tvpFieldIndexes := make([]int, 0, columnCount)
  131. for i := 0; i < columnCount; i++ {
  132. field := tvpRow.Field(i)
  133. tvpTagValue, isTvpTag := field.Tag.Lookup(tvpTag)
  134. jsonTagValue, isJsonTag := field.Tag.Lookup(jsonTag)
  135. if IsSkipField(tvpTagValue, isTvpTag, jsonTagValue, isJsonTag) {
  136. continue
  137. }
  138. tvpFieldIndexes = append(tvpFieldIndexes, i)
  139. if field.Type.Kind() == reflect.Ptr {
  140. v := reflect.New(field.Type.Elem())
  141. defaultValues = append(defaultValues, v.Interface())
  142. continue
  143. }
  144. defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface())
  145. }
  146. if columnCount-len(tvpFieldIndexes) == columnCount {
  147. return nil, nil, ErrorSkip
  148. }
  149. conn := new(Conn)
  150. conn.sess = new(tdsSession)
  151. conn.sess.loginAck = loginAckStruct{TDSVersion: verTDS73}
  152. stmt := &Stmt{
  153. c: conn,
  154. }
  155. columnConfiguration := make([]columnStruct, 0, columnCount)
  156. for index, val := range defaultValues {
  157. cval, err := convertInputParameter(val)
  158. if err != nil {
  159. return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val, err)
  160. }
  161. param, err := stmt.makeParam(cval)
  162. if err != nil {
  163. return nil, nil, err
  164. }
  165. column := columnStruct{
  166. ti: param.ti,
  167. }
  168. switch param.ti.TypeId {
  169. case typeNVarChar, typeBigVarBin:
  170. column.ti.Size = 0
  171. }
  172. columnConfiguration = append(columnConfiguration, column)
  173. }
  174. return columnConfiguration, tvpFieldIndexes, nil
  175. }
  176. func IsSkipField(tvpTagValue string, isTvpValue bool, jsonTagValue string, isJsonTagValue bool) bool {
  177. if !isTvpValue && !isJsonTagValue {
  178. return false
  179. } else if isTvpValue && tvpTagValue != skipTagValue {
  180. return false
  181. } else if !isTvpValue && isJsonTagValue && jsonTagValue != skipTagValue {
  182. return false
  183. }
  184. return true
  185. }
  186. func getSchemeAndName(tvpName string) (string, string, error) {
  187. if len(tvpName) == 0 {
  188. return "", "", ErrorEmptyTVPTypeName
  189. }
  190. splitVal := strings.Split(tvpName, ".")
  191. if len(splitVal) > 2 {
  192. return "", "", errors.New("wrong tvp name")
  193. }
  194. if len(splitVal) == 2 {
  195. res := make([]string, 2)
  196. for key, value := range splitVal {
  197. tmp := strings.Replace(value, "[", "", -1)
  198. tmp = strings.Replace(tmp, "]", "", -1)
  199. res[key] = tmp
  200. }
  201. return res[0], res[1], nil
  202. }
  203. tmp := strings.Replace(splitVal[0], "[", "", -1)
  204. tmp = strings.Replace(tmp, "]", "", -1)
  205. return "", tmp, nil
  206. }
  207. func getCountSQLSeparators(str string) int {
  208. return strings.Count(str, sqlSeparator)
  209. }
上海开阖软件有限公司 沪ICP备12045867号-1