389 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			389 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright 2019 The Go Authors. All rights reserved.
 | 
						|
// Use of this source code is governed by a BSD-style
 | 
						|
// license that can be found in the LICENSE file.
 | 
						|
 | 
						|
package impl
 | 
						|
 | 
						|
import (
 | 
						|
	"reflect"
 | 
						|
	"sort"
 | 
						|
 | 
						|
	"google.golang.org/protobuf/encoding/protowire"
 | 
						|
	"google.golang.org/protobuf/internal/genid"
 | 
						|
	pref "google.golang.org/protobuf/reflect/protoreflect"
 | 
						|
)
 | 
						|
 | 
						|
// mapInfo bundles the per-field state that encoderFuncsForMap builds once and
// that the map size, marshal, unmarshal and isInit helpers share.
type mapInfo struct {
	goType     reflect.Type    // Go type of the map field; used to allocate the map during unmarshal
	keyWiretag uint64          // encoded tag for the entry key (field 1)
	valWiretag uint64          // encoded tag for the entry value (field 2)
	keyFuncs   valueCoderFuncs // coder for entry keys
	valFuncs   valueCoderFuncs // coder for entry values; not used by the paths that go through f.mi
	keyZero    pref.Value      // starting key for a decoded entry (the key field's Default())
	keyKind    pref.Kind       // kind of the map key field
	conv       *mapConverter   // converts keys/values between Go reflect values and protoreflect values
}
 | 
						|
 | 
						|
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
 | 
						|
	// TODO: Consider generating specialized map coders.
 | 
						|
	keyField := fd.MapKey()
 | 
						|
	valField := fd.MapValue()
 | 
						|
	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
 | 
						|
	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
 | 
						|
	keyFuncs := encoderFuncsForValue(keyField)
 | 
						|
	valFuncs := encoderFuncsForValue(valField)
 | 
						|
	conv := newMapConverter(ft, fd)
 | 
						|
 | 
						|
	mapi := &mapInfo{
 | 
						|
		goType:     ft,
 | 
						|
		keyWiretag: keyWiretag,
 | 
						|
		valWiretag: valWiretag,
 | 
						|
		keyFuncs:   keyFuncs,
 | 
						|
		valFuncs:   valFuncs,
 | 
						|
		keyZero:    keyField.Default(),
 | 
						|
		keyKind:    keyField.Kind(),
 | 
						|
		conv:       conv,
 | 
						|
	}
 | 
						|
	if valField.Kind() == pref.MessageKind {
 | 
						|
		valueMessage = getMessageInfo(ft.Elem())
 | 
						|
	}
 | 
						|
 | 
						|
	funcs = pointerCoderFuncs{
 | 
						|
		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
 | 
						|
			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
 | 
						|
		},
 | 
						|
		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 | 
						|
			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
 | 
						|
		},
 | 
						|
		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
 | 
						|
			mp := p.AsValueOf(ft)
 | 
						|
			if mp.Elem().IsNil() {
 | 
						|
				mp.Elem().Set(reflect.MakeMap(mapi.goType))
 | 
						|
			}
 | 
						|
			if f.mi == nil {
 | 
						|
				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
 | 
						|
			} else {
 | 
						|
				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
 | 
						|
			}
 | 
						|
		},
 | 
						|
	}
 | 
						|
	switch valField.Kind() {
 | 
						|
	case pref.MessageKind:
 | 
						|
		funcs.merge = mergeMapOfMessage
 | 
						|
	case pref.BytesKind:
 | 
						|
		funcs.merge = mergeMapOfBytes
 | 
						|
	default:
 | 
						|
		funcs.merge = mergeMap
 | 
						|
	}
 | 
						|
	if valFuncs.isInit != nil {
 | 
						|
		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
 | 
						|
			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return valueMessage, funcs
 | 
						|
}
 | 
						|
 | 
						|
// Encoded sizes, in bytes, of the tags for the key (field 1) and value
// (field 2) fields of a map entry. A tag is the varint of
// (fieldNumber<<3 | wireType), so for field numbers 1 and 2 it always fits in
// a single byte whatever the wire type.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
 | 
						|
 | 
						|
func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
 | 
						|
	if mapv.Len() == 0 {
 | 
						|
		return 0
 | 
						|
	}
 | 
						|
	n := 0
 | 
						|
	iter := mapRange(mapv)
 | 
						|
	for iter.Next() {
 | 
						|
		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
 | 
						|
		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 | 
						|
		var valSize int
 | 
						|
		value := mapi.conv.valConv.PBValueOf(iter.Value())
 | 
						|
		if f.mi == nil {
 | 
						|
			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
 | 
						|
		} else {
 | 
						|
			p := pointerOfValue(iter.Value())
 | 
						|
			valSize += mapValTagSize
 | 
						|
			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
 | 
						|
		}
 | 
						|
		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
 | 
						|
	}
 | 
						|
	return n
 | 
						|
}
 | 
						|
 | 
						|
func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 | 
						|
	if wtyp != protowire.BytesType {
 | 
						|
		return out, errUnknown
 | 
						|
	}
 | 
						|
	b, n := protowire.ConsumeBytes(b)
 | 
						|
	if n < 0 {
 | 
						|
		return out, errDecode
 | 
						|
	}
 | 
						|
	var (
 | 
						|
		key = mapi.keyZero
 | 
						|
		val = mapi.conv.valConv.New()
 | 
						|
	)
 | 
						|
	for len(b) > 0 {
 | 
						|
		num, wtyp, n := protowire.ConsumeTag(b)
 | 
						|
		if n < 0 {
 | 
						|
			return out, errDecode
 | 
						|
		}
 | 
						|
		if num > protowire.MaxValidNumber {
 | 
						|
			return out, errDecode
 | 
						|
		}
 | 
						|
		b = b[n:]
 | 
						|
		err := errUnknown
 | 
						|
		switch num {
 | 
						|
		case genid.MapEntry_Key_field_number:
 | 
						|
			var v pref.Value
 | 
						|
			var o unmarshalOutput
 | 
						|
			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
 | 
						|
			if err != nil {
 | 
						|
				break
 | 
						|
			}
 | 
						|
			key = v
 | 
						|
			n = o.n
 | 
						|
		case genid.MapEntry_Value_field_number:
 | 
						|
			var v pref.Value
 | 
						|
			var o unmarshalOutput
 | 
						|
			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
 | 
						|
			if err != nil {
 | 
						|
				break
 | 
						|
			}
 | 
						|
			val = v
 | 
						|
			n = o.n
 | 
						|
		}
 | 
						|
		if err == errUnknown {
 | 
						|
			n = protowire.ConsumeFieldValue(num, wtyp, b)
 | 
						|
			if n < 0 {
 | 
						|
				return out, errDecode
 | 
						|
			}
 | 
						|
		} else if err != nil {
 | 
						|
			return out, err
 | 
						|
		}
 | 
						|
		b = b[n:]
 | 
						|
	}
 | 
						|
	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
 | 
						|
	out.n = n
 | 
						|
	return out, nil
 | 
						|
}
 | 
						|
 | 
						|
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 | 
						|
	if wtyp != protowire.BytesType {
 | 
						|
		return out, errUnknown
 | 
						|
	}
 | 
						|
	b, n := protowire.ConsumeBytes(b)
 | 
						|
	if n < 0 {
 | 
						|
		return out, errDecode
 | 
						|
	}
 | 
						|
	var (
 | 
						|
		key = mapi.keyZero
 | 
						|
		val = reflect.New(f.mi.GoReflectType.Elem())
 | 
						|
	)
 | 
						|
	for len(b) > 0 {
 | 
						|
		num, wtyp, n := protowire.ConsumeTag(b)
 | 
						|
		if n < 0 {
 | 
						|
			return out, errDecode
 | 
						|
		}
 | 
						|
		if num > protowire.MaxValidNumber {
 | 
						|
			return out, errDecode
 | 
						|
		}
 | 
						|
		b = b[n:]
 | 
						|
		err := errUnknown
 | 
						|
		switch num {
 | 
						|
		case 1:
 | 
						|
			var v pref.Value
 | 
						|
			var o unmarshalOutput
 | 
						|
			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
 | 
						|
			if err != nil {
 | 
						|
				break
 | 
						|
			}
 | 
						|
			key = v
 | 
						|
			n = o.n
 | 
						|
		case 2:
 | 
						|
			if wtyp != protowire.BytesType {
 | 
						|
				break
 | 
						|
			}
 | 
						|
			var v []byte
 | 
						|
			v, n = protowire.ConsumeBytes(b)
 | 
						|
			if n < 0 {
 | 
						|
				return out, errDecode
 | 
						|
			}
 | 
						|
			var o unmarshalOutput
 | 
						|
			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
 | 
						|
			if o.initialized {
 | 
						|
				// Consider this map item initialized so long as we see
 | 
						|
				// an initialized value.
 | 
						|
				out.initialized = true
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if err == errUnknown {
 | 
						|
			n = protowire.ConsumeFieldValue(num, wtyp, b)
 | 
						|
			if n < 0 {
 | 
						|
				return out, errDecode
 | 
						|
			}
 | 
						|
		} else if err != nil {
 | 
						|
			return out, err
 | 
						|
		}
 | 
						|
		b = b[n:]
 | 
						|
	}
 | 
						|
	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
 | 
						|
	out.n = n
 | 
						|
	return out, nil
 | 
						|
}
 | 
						|
 | 
						|
func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 | 
						|
	if f.mi == nil {
 | 
						|
		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
 | 
						|
		val := mapi.conv.valConv.PBValueOf(valrv)
 | 
						|
		size := 0
 | 
						|
		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 | 
						|
		size += mapi.valFuncs.size(val, mapValTagSize, opts)
 | 
						|
		b = protowire.AppendVarint(b, uint64(size))
 | 
						|
		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
 | 
						|
	} else {
 | 
						|
		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
 | 
						|
		val := pointerOfValue(valrv)
 | 
						|
		valSize := f.mi.sizePointer(val, opts)
 | 
						|
		size := 0
 | 
						|
		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 | 
						|
		size += mapValTagSize + protowire.SizeBytes(valSize)
 | 
						|
		b = protowire.AppendVarint(b, uint64(size))
 | 
						|
		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		b = protowire.AppendVarint(b, mapi.valWiretag)
 | 
						|
		b = protowire.AppendVarint(b, uint64(valSize))
 | 
						|
		return f.mi.marshalAppendPointer(b, val, opts)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 | 
						|
	if mapv.Len() == 0 {
 | 
						|
		return b, nil
 | 
						|
	}
 | 
						|
	if opts.Deterministic() {
 | 
						|
		return appendMapDeterministic(b, mapv, mapi, f, opts)
 | 
						|
	}
 | 
						|
	iter := mapRange(mapv)
 | 
						|
	for iter.Next() {
 | 
						|
		var err error
 | 
						|
		b = protowire.AppendVarint(b, f.wiretag)
 | 
						|
		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
 | 
						|
		if err != nil {
 | 
						|
			return b, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return b, nil
 | 
						|
}
 | 
						|
 | 
						|
func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 | 
						|
	keys := mapv.MapKeys()
 | 
						|
	sort.Slice(keys, func(i, j int) bool {
 | 
						|
		switch keys[i].Kind() {
 | 
						|
		case reflect.Bool:
 | 
						|
			return !keys[i].Bool() && keys[j].Bool()
 | 
						|
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 | 
						|
			return keys[i].Int() < keys[j].Int()
 | 
						|
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
 | 
						|
			return keys[i].Uint() < keys[j].Uint()
 | 
						|
		case reflect.Float32, reflect.Float64:
 | 
						|
			return keys[i].Float() < keys[j].Float()
 | 
						|
		case reflect.String:
 | 
						|
			return keys[i].String() < keys[j].String()
 | 
						|
		default:
 | 
						|
			panic("invalid kind: " + keys[i].Kind().String())
 | 
						|
		}
 | 
						|
	})
 | 
						|
	for _, key := range keys {
 | 
						|
		var err error
 | 
						|
		b = protowire.AppendVarint(b, f.wiretag)
 | 
						|
		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
 | 
						|
		if err != nil {
 | 
						|
			return b, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return b, nil
 | 
						|
}
 | 
						|
 | 
						|
func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
 | 
						|
	if mi := f.mi; mi != nil {
 | 
						|
		mi.init()
 | 
						|
		if !mi.needsInitCheck {
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
		iter := mapRange(mapv)
 | 
						|
		for iter.Next() {
 | 
						|
			val := pointerOfValue(iter.Value())
 | 
						|
			if err := mi.checkInitializedPointer(val); err != nil {
 | 
						|
				return err
 | 
						|
			}
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		iter := mapRange(mapv)
 | 
						|
		for iter.Next() {
 | 
						|
			val := mapi.conv.valConv.PBValueOf(iter.Value())
 | 
						|
			if err := mapi.valFuncs.isInit(val); err != nil {
 | 
						|
				return err
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 | 
						|
	dstm := dst.AsValueOf(f.ft).Elem()
 | 
						|
	srcm := src.AsValueOf(f.ft).Elem()
 | 
						|
	if srcm.Len() == 0 {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	if dstm.IsNil() {
 | 
						|
		dstm.Set(reflect.MakeMap(f.ft))
 | 
						|
	}
 | 
						|
	iter := mapRange(srcm)
 | 
						|
	for iter.Next() {
 | 
						|
		dstm.SetMapIndex(iter.Key(), iter.Value())
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 | 
						|
	dstm := dst.AsValueOf(f.ft).Elem()
 | 
						|
	srcm := src.AsValueOf(f.ft).Elem()
 | 
						|
	if srcm.Len() == 0 {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	if dstm.IsNil() {
 | 
						|
		dstm.Set(reflect.MakeMap(f.ft))
 | 
						|
	}
 | 
						|
	iter := mapRange(srcm)
 | 
						|
	for iter.Next() {
 | 
						|
		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 | 
						|
	dstm := dst.AsValueOf(f.ft).Elem()
 | 
						|
	srcm := src.AsValueOf(f.ft).Elem()
 | 
						|
	if srcm.Len() == 0 {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	if dstm.IsNil() {
 | 
						|
		dstm.Set(reflect.MakeMap(f.ft))
 | 
						|
	}
 | 
						|
	iter := mapRange(srcm)
 | 
						|
	for iter.Next() {
 | 
						|
		val := reflect.New(f.ft.Elem().Elem())
 | 
						|
		if f.mi != nil {
 | 
						|
			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
 | 
						|
		} else {
 | 
						|
			opts.Merge(asMessage(val), asMessage(iter.Value()))
 | 
						|
		}
 | 
						|
		dstm.SetMapIndex(iter.Key(), val)
 | 
						|
	}
 | 
						|
}
 |