本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

368 lines
9.0KB

  1. // Copyright (C) MongoDB, Inc. 2017-present.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License"); you may
  4. // not use this file except in compliance with the License. You may obtain
  5. // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
  6. package bsoncodec
  7. import (
  8. "errors"
  9. "fmt"
  10. "reflect"
  11. "strings"
  12. "sync"
  13. "go.mongodb.org/mongo-driver/bson/bsonrw"
  14. "go.mongodb.org/mongo-driver/bson/bsontype"
  15. )
  16. var defaultStructCodec = &StructCodec{
  17. cache: make(map[reflect.Type]*structDescription),
  18. parser: DefaultStructTagParser,
  19. }
// Zeroer allows custom types to report whether they are in their zero
// state. The struct codec consults IsZero when deciding whether a field
// marked omitempty should be skipped. Struct types that don't implement
// Zeroer, or whose IsZero returns false, are considered to be not zero.
type Zeroer interface {
	// IsZero reports whether the value should be treated as zero.
	IsZero() bool
}
// StructCodec is the Codec used for struct values.
type StructCodec struct {
	// cache holds the struct descriptions already built by describeStruct,
	// keyed by struct type.
	cache map[reflect.Type]*structDescription
	// l guards reads and writes of cache.
	l sync.RWMutex
	// parser extracts the per-field options from struct tags.
	parser StructTagParser
}
  32. var _ ValueEncoder = &StructCodec{}
  33. var _ ValueDecoder = &StructCodec{}
  34. // NewStructCodec returns a StructCodec that uses p for struct tag parsing.
  35. func NewStructCodec(p StructTagParser) (*StructCodec, error) {
  36. if p == nil {
  37. return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
  38. }
  39. return &StructCodec{
  40. cache: make(map[reflect.Type]*structDescription),
  41. parser: p,
  42. }, nil
  43. }
  44. // EncodeValue handles encoding generic struct types.
  45. func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
  46. if !val.IsValid() || val.Kind() != reflect.Struct {
  47. return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
  48. }
  49. sd, err := sc.describeStruct(r.Registry, val.Type())
  50. if err != nil {
  51. return err
  52. }
  53. dw, err := vw.WriteDocument()
  54. if err != nil {
  55. return err
  56. }
  57. var rv reflect.Value
  58. for _, desc := range sd.fl {
  59. if desc.inline == nil {
  60. rv = val.Field(desc.idx)
  61. } else {
  62. rv = val.FieldByIndex(desc.inline)
  63. }
  64. if desc.encoder == nil {
  65. return ErrNoEncoder{Type: rv.Type()}
  66. }
  67. encoder := desc.encoder
  68. iszero := sc.isZero
  69. if iz, ok := encoder.(CodecZeroer); ok {
  70. iszero = iz.IsTypeZero
  71. }
  72. if desc.omitEmpty && iszero(rv.Interface()) {
  73. continue
  74. }
  75. vw2, err := dw.WriteDocumentElement(desc.name)
  76. if err != nil {
  77. return err
  78. }
  79. ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
  80. err = encoder.EncodeValue(ectx, vw2, rv)
  81. if err != nil {
  82. return err
  83. }
  84. }
  85. if sd.inlineMap >= 0 {
  86. rv := val.Field(sd.inlineMap)
  87. collisionFn := func(key string) bool {
  88. _, exists := sd.fm[key]
  89. return exists
  90. }
  91. return defaultValueEncoders.mapEncodeValue(r, dw, rv, collisionFn)
  92. }
  93. return dw.WriteDocumentEnd()
  94. }
  95. // DecodeValue implements the Codec interface.
  96. // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
  97. // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
  98. func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
  99. if !val.CanSet() || val.Kind() != reflect.Struct {
  100. return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
  101. }
  102. switch vr.Type() {
  103. case bsontype.Type(0), bsontype.EmbeddedDocument:
  104. default:
  105. return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type())
  106. }
  107. sd, err := sc.describeStruct(r.Registry, val.Type())
  108. if err != nil {
  109. return err
  110. }
  111. var decoder ValueDecoder
  112. var inlineMap reflect.Value
  113. if sd.inlineMap >= 0 {
  114. inlineMap = val.Field(sd.inlineMap)
  115. if inlineMap.IsNil() {
  116. inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
  117. }
  118. decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
  119. if err != nil {
  120. return err
  121. }
  122. }
  123. dr, err := vr.ReadDocument()
  124. if err != nil {
  125. return err
  126. }
  127. for {
  128. name, vr, err := dr.ReadElement()
  129. if err == bsonrw.ErrEOD {
  130. break
  131. }
  132. if err != nil {
  133. return err
  134. }
  135. fd, exists := sd.fm[name]
  136. if !exists {
  137. // if the original name isn't found in the struct description, try again with the name in lowercase
  138. // this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
  139. // names
  140. fd, exists = sd.fm[strings.ToLower(name)]
  141. }
  142. if !exists {
  143. if sd.inlineMap < 0 {
  144. // The encoding/json package requires a flag to return on error for non-existent fields.
  145. // This functionality seems appropriate for the struct codec.
  146. err = vr.Skip()
  147. if err != nil {
  148. return err
  149. }
  150. continue
  151. }
  152. elem := reflect.New(inlineMap.Type().Elem()).Elem()
  153. err = decoder.DecodeValue(r, vr, elem)
  154. if err != nil {
  155. return err
  156. }
  157. inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
  158. continue
  159. }
  160. var field reflect.Value
  161. if fd.inline == nil {
  162. field = val.Field(fd.idx)
  163. } else {
  164. field = val.FieldByIndex(fd.inline)
  165. }
  166. if !field.CanSet() { // Being settable is a super set of being addressable.
  167. return fmt.Errorf("cannot decode element '%s' into field %v; it is not settable", name, field)
  168. }
  169. if field.Kind() == reflect.Ptr && field.IsNil() {
  170. field.Set(reflect.New(field.Type().Elem()))
  171. }
  172. field = field.Addr()
  173. dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate}
  174. if fd.decoder == nil {
  175. return ErrNoDecoder{Type: field.Elem().Type()}
  176. }
  177. if decoder, ok := fd.decoder.(ValueDecoder); ok {
  178. err = decoder.DecodeValue(dctx, vr, field.Elem())
  179. if err != nil {
  180. return err
  181. }
  182. continue
  183. }
  184. err = fd.decoder.DecodeValue(dctx, vr, field)
  185. if err != nil {
  186. return err
  187. }
  188. }
  189. return nil
  190. }
  191. func (sc *StructCodec) isZero(i interface{}) bool {
  192. v := reflect.ValueOf(i)
  193. // check the value validity
  194. if !v.IsValid() {
  195. return true
  196. }
  197. if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
  198. return z.IsZero()
  199. }
  200. switch v.Kind() {
  201. case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
  202. return v.Len() == 0
  203. case reflect.Bool:
  204. return !v.Bool()
  205. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  206. return v.Int() == 0
  207. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  208. return v.Uint() == 0
  209. case reflect.Float32, reflect.Float64:
  210. return v.Float() == 0
  211. case reflect.Interface, reflect.Ptr:
  212. return v.IsNil()
  213. }
  214. return false
  215. }
// structDescription is the cached result of analyzing a struct type.
type structDescription struct {
	// fm maps an element name to the description of the field it belongs to.
	fm map[string]fieldDescription
	// fl lists the field descriptions in the order they were discovered;
	// the encoder writes elements in this order.
	fl []fieldDescription
	// inlineMap is the index of the struct field holding the inline map,
	// or -1 when the struct has none.
	inlineMap int
}
// fieldDescription records how a single struct field is encoded and decoded.
type fieldDescription struct {
	// name is the element name taken from the parsed struct tags.
	name string
	// idx is the index of the field within the struct that declares it.
	idx int
	// omitEmpty, minSize and truncate mirror the options from the parsed
	// struct tags.
	omitEmpty bool
	minSize   bool
	truncate  bool
	// inline is the index path (for reflect.Value.FieldByIndex) to a field
	// promoted from an inlined struct; nil for a direct field.
	inline []int
	// encoder and decoder are the codecs looked up for the field's type;
	// either is nil when the registry lookup failed.
	encoder ValueEncoder
	decoder ValueDecoder
}
  231. func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
  232. // We need to analyze the struct, including getting the tags, collecting
  233. // information about inlining, and create a map of the field name to the field.
  234. sc.l.RLock()
  235. ds, exists := sc.cache[t]
  236. sc.l.RUnlock()
  237. if exists {
  238. return ds, nil
  239. }
  240. numFields := t.NumField()
  241. sd := &structDescription{
  242. fm: make(map[string]fieldDescription, numFields),
  243. fl: make([]fieldDescription, 0, numFields),
  244. inlineMap: -1,
  245. }
  246. for i := 0; i < numFields; i++ {
  247. sf := t.Field(i)
  248. if sf.PkgPath != "" {
  249. // unexported, ignore
  250. continue
  251. }
  252. encoder, err := r.LookupEncoder(sf.Type)
  253. if err != nil {
  254. encoder = nil
  255. }
  256. decoder, err := r.LookupDecoder(sf.Type)
  257. if err != nil {
  258. decoder = nil
  259. }
  260. description := fieldDescription{idx: i, encoder: encoder, decoder: decoder}
  261. stags, err := sc.parser.ParseStructTags(sf)
  262. if err != nil {
  263. return nil, err
  264. }
  265. if stags.Skip {
  266. continue
  267. }
  268. description.name = stags.Name
  269. description.omitEmpty = stags.OmitEmpty
  270. description.minSize = stags.MinSize
  271. description.truncate = stags.Truncate
  272. if stags.Inline {
  273. switch sf.Type.Kind() {
  274. case reflect.Map:
  275. if sd.inlineMap >= 0 {
  276. return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
  277. }
  278. if sf.Type.Key() != tString {
  279. return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
  280. }
  281. sd.inlineMap = description.idx
  282. case reflect.Struct:
  283. inlinesf, err := sc.describeStruct(r, sf.Type)
  284. if err != nil {
  285. return nil, err
  286. }
  287. for _, fd := range inlinesf.fl {
  288. if _, exists := sd.fm[fd.name]; exists {
  289. return nil, fmt.Errorf("(struct %s) duplicated key %s", t.String(), fd.name)
  290. }
  291. if fd.inline == nil {
  292. fd.inline = []int{i, fd.idx}
  293. } else {
  294. fd.inline = append([]int{i}, fd.inline...)
  295. }
  296. sd.fm[fd.name] = fd
  297. sd.fl = append(sd.fl, fd)
  298. }
  299. default:
  300. return nil, fmt.Errorf("(struct %s) inline fields must be either a struct or a map", t.String())
  301. }
  302. continue
  303. }
  304. if _, exists := sd.fm[description.name]; exists {
  305. return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), description.name)
  306. }
  307. sd.fm[description.name] = description
  308. sd.fl = append(sd.fl, description)
  309. }
  310. sc.l.Lock()
  311. sc.cache[t] = sd
  312. sc.l.Unlock()
  313. return sd, nil
  314. }
上海开阖软件有限公司 沪ICP备12045867号-1