|
- // Copyright (C) MongoDB, Inc. 2017-present.
- //
- // Licensed under the Apache License, Version 2.0 (the "License"); you may
- // not use this file except in compliance with the License. You may obtain
- // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
-
- package bsoncodec
-
- import (
- "errors"
- "fmt"
- "reflect"
- "strings"
- "sync"
-
- "go.mongodb.org/mongo-driver/bson/bsonrw"
- "go.mongodb.org/mongo-driver/bson/bsontype"
- )
-
- var defaultStructCodec = &StructCodec{
- cache: make(map[reflect.Type]*structDescription),
- parser: DefaultStructTagParser,
- }
-
// Zeroer lets a custom struct type report for itself whether it is in its
// zero state. A struct type that does not implement Zeroer, or whose IsZero
// method returns false, is considered to be not zero.
type Zeroer interface {
	// IsZero reports whether the receiver should be treated as zero.
	IsZero() bool
}
-
- // StructCodec is the Codec used for struct values.
- type StructCodec struct {
- cache map[reflect.Type]*structDescription
- l sync.RWMutex
- parser StructTagParser
- }
-
- var _ ValueEncoder = &StructCodec{}
- var _ ValueDecoder = &StructCodec{}
-
- // NewStructCodec returns a StructCodec that uses p for struct tag parsing.
- func NewStructCodec(p StructTagParser) (*StructCodec, error) {
- if p == nil {
- return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
- }
-
- return &StructCodec{
- cache: make(map[reflect.Type]*structDescription),
- parser: p,
- }, nil
- }
-
- // EncodeValue handles encoding generic struct types.
- func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
- if !val.IsValid() || val.Kind() != reflect.Struct {
- return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
- }
-
- sd, err := sc.describeStruct(r.Registry, val.Type())
- if err != nil {
- return err
- }
-
- dw, err := vw.WriteDocument()
- if err != nil {
- return err
- }
- var rv reflect.Value
- for _, desc := range sd.fl {
- if desc.inline == nil {
- rv = val.Field(desc.idx)
- } else {
- rv = val.FieldByIndex(desc.inline)
- }
-
- if desc.encoder == nil {
- return ErrNoEncoder{Type: rv.Type()}
- }
-
- encoder := desc.encoder
-
- iszero := sc.isZero
- if iz, ok := encoder.(CodecZeroer); ok {
- iszero = iz.IsTypeZero
- }
-
- if desc.omitEmpty && iszero(rv.Interface()) {
- continue
- }
-
- vw2, err := dw.WriteDocumentElement(desc.name)
- if err != nil {
- return err
- }
-
- ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
- err = encoder.EncodeValue(ectx, vw2, rv)
- if err != nil {
- return err
- }
- }
-
- if sd.inlineMap >= 0 {
- rv := val.Field(sd.inlineMap)
- collisionFn := func(key string) bool {
- _, exists := sd.fm[key]
- return exists
- }
-
- return defaultValueEncoders.mapEncodeValue(r, dw, rv, collisionFn)
- }
-
- return dw.WriteDocumentEnd()
- }
-
- // DecodeValue implements the Codec interface.
- // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
- // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
- func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
- if !val.CanSet() || val.Kind() != reflect.Struct {
- return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
- }
-
- switch vr.Type() {
- case bsontype.Type(0), bsontype.EmbeddedDocument:
- default:
- return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type())
- }
-
- sd, err := sc.describeStruct(r.Registry, val.Type())
- if err != nil {
- return err
- }
-
- var decoder ValueDecoder
- var inlineMap reflect.Value
- if sd.inlineMap >= 0 {
- inlineMap = val.Field(sd.inlineMap)
- if inlineMap.IsNil() {
- inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
- }
- decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
- if err != nil {
- return err
- }
- }
-
- dr, err := vr.ReadDocument()
- if err != nil {
- return err
- }
-
- for {
- name, vr, err := dr.ReadElement()
- if err == bsonrw.ErrEOD {
- break
- }
- if err != nil {
- return err
- }
-
- fd, exists := sd.fm[name]
- if !exists {
- // if the original name isn't found in the struct description, try again with the name in lowercase
- // this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
- // names
- fd, exists = sd.fm[strings.ToLower(name)]
- }
-
- if !exists {
- if sd.inlineMap < 0 {
- // The encoding/json package requires a flag to return on error for non-existent fields.
- // This functionality seems appropriate for the struct codec.
- err = vr.Skip()
- if err != nil {
- return err
- }
- continue
- }
-
- elem := reflect.New(inlineMap.Type().Elem()).Elem()
- err = decoder.DecodeValue(r, vr, elem)
- if err != nil {
- return err
- }
- inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
- continue
- }
-
- var field reflect.Value
- if fd.inline == nil {
- field = val.Field(fd.idx)
- } else {
- field = val.FieldByIndex(fd.inline)
- }
-
- if !field.CanSet() { // Being settable is a super set of being addressable.
- return fmt.Errorf("cannot decode element '%s' into field %v; it is not settable", name, field)
- }
- if field.Kind() == reflect.Ptr && field.IsNil() {
- field.Set(reflect.New(field.Type().Elem()))
- }
- field = field.Addr()
-
- dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate}
- if fd.decoder == nil {
- return ErrNoDecoder{Type: field.Elem().Type()}
- }
-
- if decoder, ok := fd.decoder.(ValueDecoder); ok {
- err = decoder.DecodeValue(dctx, vr, field.Elem())
- if err != nil {
- return err
- }
- continue
- }
- err = fd.decoder.DecodeValue(dctx, vr, field)
- if err != nil {
- return err
- }
- }
-
- return nil
- }
-
- func (sc *StructCodec) isZero(i interface{}) bool {
- v := reflect.ValueOf(i)
-
- // check the value validity
- if !v.IsValid() {
- return true
- }
-
- if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
- return z.IsZero()
- }
-
- switch v.Kind() {
- case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
- return v.Len() == 0
- case reflect.Bool:
- return !v.Bool()
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return v.Int() == 0
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- return v.Uint() == 0
- case reflect.Float32, reflect.Float64:
- return v.Float() == 0
- case reflect.Interface, reflect.Ptr:
- return v.IsNil()
- }
-
- return false
- }
-
- type structDescription struct {
- fm map[string]fieldDescription
- fl []fieldDescription
- inlineMap int
- }
-
- type fieldDescription struct {
- name string
- idx int
- omitEmpty bool
- minSize bool
- truncate bool
- inline []int
- encoder ValueEncoder
- decoder ValueDecoder
- }
-
- func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
- // We need to analyze the struct, including getting the tags, collecting
- // information about inlining, and create a map of the field name to the field.
- sc.l.RLock()
- ds, exists := sc.cache[t]
- sc.l.RUnlock()
- if exists {
- return ds, nil
- }
-
- numFields := t.NumField()
- sd := &structDescription{
- fm: make(map[string]fieldDescription, numFields),
- fl: make([]fieldDescription, 0, numFields),
- inlineMap: -1,
- }
-
- for i := 0; i < numFields; i++ {
- sf := t.Field(i)
- if sf.PkgPath != "" {
- // unexported, ignore
- continue
- }
-
- encoder, err := r.LookupEncoder(sf.Type)
- if err != nil {
- encoder = nil
- }
- decoder, err := r.LookupDecoder(sf.Type)
- if err != nil {
- decoder = nil
- }
-
- description := fieldDescription{idx: i, encoder: encoder, decoder: decoder}
-
- stags, err := sc.parser.ParseStructTags(sf)
- if err != nil {
- return nil, err
- }
- if stags.Skip {
- continue
- }
- description.name = stags.Name
- description.omitEmpty = stags.OmitEmpty
- description.minSize = stags.MinSize
- description.truncate = stags.Truncate
-
- if stags.Inline {
- switch sf.Type.Kind() {
- case reflect.Map:
- if sd.inlineMap >= 0 {
- return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
- }
- if sf.Type.Key() != tString {
- return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
- }
- sd.inlineMap = description.idx
- case reflect.Struct:
- inlinesf, err := sc.describeStruct(r, sf.Type)
- if err != nil {
- return nil, err
- }
- for _, fd := range inlinesf.fl {
- if _, exists := sd.fm[fd.name]; exists {
- return nil, fmt.Errorf("(struct %s) duplicated key %s", t.String(), fd.name)
- }
- if fd.inline == nil {
- fd.inline = []int{i, fd.idx}
- } else {
- fd.inline = append([]int{i}, fd.inline...)
- }
- sd.fm[fd.name] = fd
- sd.fl = append(sd.fl, fd)
- }
- default:
- return nil, fmt.Errorf("(struct %s) inline fields must be either a struct or a map", t.String())
- }
- continue
- }
-
- if _, exists := sd.fm[description.name]; exists {
- return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), description.name)
- }
-
- sd.fm[description.name] = description
- sd.fl = append(sd.fl, description)
- }
-
- sc.l.Lock()
- sc.cache[t] = sd
- sc.l.Unlock()
-
- return sd, nil
- }
|