Bump cel-go to v0.10.0

This commit is contained in:
Joe Betz
2022-03-07 20:47:04 -05:00
parent f93be6584e
commit 2a6b85c395
66 changed files with 3332 additions and 817 deletions

View File

@@ -8,6 +8,7 @@ package(
go_library(
name = "go_default_library",
srcs = [
"cost.go",
"error.go",
"errors.go",
"location.go",

40
vendor/github.com/google/cel-go/common/cost.go generated vendored Normal file
View File

@@ -0,0 +1,40 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
const (
// SelectAndIdentCost is the cost of an operation that accesses an identifier or performs a select.
SelectAndIdentCost = 1
// ConstCost is the cost of an operation that accesses a constant.
ConstCost = 0
// ListCreateBaseCost is the base cost of any operation that creates a new list.
ListCreateBaseCost = 10
// MapCreateBaseCost is the base cost of any operation that creates a new map.
MapCreateBaseCost = 30
// StructCreateBaseCost is the base cost of any operation that creates a new struct.
StructCreateBaseCost = 40
// StringTraversalCostFactor is multiplied to a length of a string when computing the cost of traversing the entire
// string once.
StringTraversalCostFactor = 0.1
// RegexStringLengthCostFactor is multiplied ot the length of a regex string pattern when computing the cost of
// applying the regex to a string of unit cost.
RegexStringLengthCostFactor = 0.25
)

View File

@@ -70,46 +70,35 @@ var (
"!=": NotEquals,
"-": Subtract,
}
reverseOperators = map[string]string{
Add: "+",
Divide: "/",
Equals: "==",
Greater: ">",
GreaterEquals: ">=",
In: "in",
Less: "<",
LessEquals: "<=",
LogicalAnd: "&&",
LogicalNot: "!",
LogicalOr: "||",
Modulo: "%",
Multiply: "*",
Negate: "-",
NotEquals: "!=",
OldIn: "in",
Subtract: "-",
}
// precedence of the operator, where the higher value means higher.
precedence = map[string]int{
Conditional: 8,
LogicalOr: 7,
LogicalAnd: 6,
Equals: 5,
Greater: 5,
GreaterEquals: 5,
In: 5,
Less: 5,
LessEquals: 5,
NotEquals: 5,
OldIn: 5,
Add: 4,
Subtract: 4,
Divide: 3,
Modulo: 3,
Multiply: 3,
LogicalNot: 2,
Negate: 2,
Index: 1,
// operatorMap of the operator symbol which refers to a struct containing the display name,
// if applicable, the operator precedence, and the arity.
//
// If the symbol does not have a display name listed in the map, it is only because it requires
// special casing to render properly as text.
operatorMap = map[string]struct {
displayName string
precedence int
arity int
}{
Conditional: {displayName: "", precedence: 8, arity: 3},
LogicalOr: {displayName: "||", precedence: 7, arity: 2},
LogicalAnd: {displayName: "&&", precedence: 6, arity: 2},
Equals: {displayName: "==", precedence: 5, arity: 2},
Greater: {displayName: ">", precedence: 5, arity: 2},
GreaterEquals: {displayName: ">=", precedence: 5, arity: 2},
In: {displayName: "in", precedence: 5, arity: 2},
Less: {displayName: "<", precedence: 5, arity: 2},
LessEquals: {displayName: "<=", precedence: 5, arity: 2},
NotEquals: {displayName: "!=", precedence: 5, arity: 2},
OldIn: {displayName: "in", precedence: 5, arity: 2},
Add: {displayName: "+", precedence: 4, arity: 2},
Subtract: {displayName: "-", precedence: 4, arity: 2},
Divide: {displayName: "/", precedence: 3, arity: 2},
Modulo: {displayName: "%", precedence: 3, arity: 2},
Multiply: {displayName: "*", precedence: 3, arity: 2},
LogicalNot: {displayName: "!", precedence: 2, arity: 1},
Negate: {displayName: "-", precedence: 2, arity: 1},
Index: {displayName: "", precedence: 1, arity: 2},
}
)
@@ -120,26 +109,35 @@ func Find(text string) (string, bool) {
}
// FindReverse returns the unmangled, text representation of the operator.
func FindReverse(op string) (string, bool) {
txt, found := reverseOperators[op]
return txt, found
func FindReverse(symbol string) (string, bool) {
op, found := operatorMap[symbol]
if !found {
return "", false
}
return op.displayName, true
}
// FindReverseBinaryOperator returns the unmangled, text representation of a binary operator.
func FindReverseBinaryOperator(op string) (string, bool) {
if op == LogicalNot || op == Negate {
//
// If the symbol does refer to an operator, but the operator does not have a display name the
// result is false.
func FindReverseBinaryOperator(symbol string) (string, bool) {
op, found := operatorMap[symbol]
if !found || op.arity != 2 {
return "", false
}
txt, found := reverseOperators[op]
return txt, found
if op.displayName == "" {
return "", false
}
return op.displayName, true
}
// Precedence returns the operator precedence, where the higher the number indicates
// higher precedence operations.
func Precedence(op string) int {
p, found := precedence[op]
if found {
return p
func Precedence(symbol string) int {
op, found := operatorMap[symbol]
if !found {
return 0
}
return 0
return op.precedence
}

View File

@@ -18,45 +18,69 @@ package overloads
// Boolean logic overloads
const (
Conditional = "conditional"
LogicalAnd = "logical_and"
LogicalOr = "logical_or"
LogicalNot = "logical_not"
NotStrictlyFalse = "not_strictly_false"
Equals = "equals"
NotEquals = "not_equals"
LessBool = "less_bool"
LessInt64 = "less_int64"
LessUint64 = "less_uint64"
LessDouble = "less_double"
LessString = "less_string"
LessBytes = "less_bytes"
LessTimestamp = "less_timestamp"
LessDuration = "less_duration"
LessEqualsBool = "less_equals_bool"
LessEqualsInt64 = "less_equals_int64"
LessEqualsUint64 = "less_equals_uint64"
LessEqualsDouble = "less_equals_double"
LessEqualsString = "less_equals_string"
LessEqualsBytes = "less_equals_bytes"
LessEqualsTimestamp = "less_equals_timestamp"
LessEqualsDuration = "less_equals_duration"
GreaterBool = "greater_bool"
GreaterInt64 = "greater_int64"
GreaterUint64 = "greater_uint64"
GreaterDouble = "greater_double"
GreaterString = "greater_string"
GreaterBytes = "greater_bytes"
GreaterTimestamp = "greater_timestamp"
GreaterDuration = "greater_duration"
GreaterEqualsBool = "greater_equals_bool"
GreaterEqualsInt64 = "greater_equals_int64"
GreaterEqualsUint64 = "greater_equals_uint64"
GreaterEqualsDouble = "greater_equals_double"
GreaterEqualsString = "greater_equals_string"
GreaterEqualsBytes = "greater_equals_bytes"
GreaterEqualsTimestamp = "greater_equals_timestamp"
GreaterEqualsDuration = "greater_equals_duration"
Conditional = "conditional"
LogicalAnd = "logical_and"
LogicalOr = "logical_or"
LogicalNot = "logical_not"
NotStrictlyFalse = "not_strictly_false"
Equals = "equals"
NotEquals = "not_equals"
LessBool = "less_bool"
LessInt64 = "less_int64"
LessInt64Double = "less_int64_double"
LessInt64Uint64 = "less_int64_uint64"
LessUint64 = "less_uint64"
LessUint64Double = "less_uint64_double"
LessUint64Int64 = "less_uint64_int64"
LessDouble = "less_double"
LessDoubleInt64 = "less_double_int64"
LessDoubleUint64 = "less_double_uint64"
LessString = "less_string"
LessBytes = "less_bytes"
LessTimestamp = "less_timestamp"
LessDuration = "less_duration"
LessEqualsBool = "less_equals_bool"
LessEqualsInt64 = "less_equals_int64"
LessEqualsInt64Double = "less_equals_int64_double"
LessEqualsInt64Uint64 = "less_equals_int64_uint64"
LessEqualsUint64 = "less_equals_uint64"
LessEqualsUint64Double = "less_equals_uint64_double"
LessEqualsUint64Int64 = "less_equals_uint64_int64"
LessEqualsDouble = "less_equals_double"
LessEqualsDoubleInt64 = "less_equals_double_int64"
LessEqualsDoubleUint64 = "less_equals_double_uint64"
LessEqualsString = "less_equals_string"
LessEqualsBytes = "less_equals_bytes"
LessEqualsTimestamp = "less_equals_timestamp"
LessEqualsDuration = "less_equals_duration"
GreaterBool = "greater_bool"
GreaterInt64 = "greater_int64"
GreaterInt64Double = "greater_int64_double"
GreaterInt64Uint64 = "greater_int64_uint64"
GreaterUint64 = "greater_uint64"
GreaterUint64Double = "greater_uint64_double"
GreaterUint64Int64 = "greater_uint64_int64"
GreaterDouble = "greater_double"
GreaterDoubleInt64 = "greater_double_int64"
GreaterDoubleUint64 = "greater_double_uint64"
GreaterString = "greater_string"
GreaterBytes = "greater_bytes"
GreaterTimestamp = "greater_timestamp"
GreaterDuration = "greater_duration"
GreaterEqualsBool = "greater_equals_bool"
GreaterEqualsInt64 = "greater_equals_int64"
GreaterEqualsInt64Double = "greater_equals_int64_double"
GreaterEqualsInt64Uint64 = "greater_equals_int64_uint64"
GreaterEqualsUint64 = "greater_equals_uint64"
GreaterEqualsUint64Double = "greater_equals_uint64_double"
GreaterEqualsUint64Int64 = "greater_equals_uint64_int64"
GreaterEqualsDouble = "greater_equals_double"
GreaterEqualsDoubleInt64 = "greater_equals_double_int64"
GreaterEqualsDoubleUint64 = "greater_equals_double_uint64"
GreaterEqualsString = "greater_equals_string"
GreaterEqualsBytes = "greater_equals_bytes"
GreaterEqualsTimestamp = "greater_equals_timestamp"
GreaterEqualsDuration = "greater_equals_duration"
)
// Math overloads

View File

@@ -11,6 +11,7 @@ go_library(
"any_value.go",
"bool.go",
"bytes.go",
"compare.go",
"double.go",
"duration.go",
"err.go",
@@ -38,6 +39,9 @@ go_library(
"//common/types/traits:go_default_library",
"@com_github_stoewer_go_strcase//:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto//googleapis/rpc/status:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",

View File

@@ -111,10 +111,7 @@ func (b Bool) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements the ref.Val interface method.
func (b Bool) Equal(other ref.Val) ref.Val {
otherBool, ok := other.(Bool)
if !ok {
return ValOrErr(other, "no such overload")
}
return Bool(b == otherBool)
return Bool(ok && b == otherBool)
}
// Negate implements the traits.Negater interface method.

View File

@@ -113,10 +113,7 @@ func (b Bytes) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements the ref.Val interface method.
func (b Bytes) Equal(other ref.Val) ref.Val {
otherBytes, ok := other.(Bytes)
if !ok {
return ValOrErr(other, "no such overload")
}
return Bool(bytes.Equal(b, otherBytes))
return Bool(ok && bytes.Equal(b, otherBytes))
}
// Size implements the traits.Sizer interface method.

View File

@@ -0,0 +1,95 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"math"
)
func compareDoubleInt(d Double, i Int) Int {
if d < math.MinInt64 {
return IntNegOne
}
if d > math.MaxInt64 {
return IntOne
}
return compareDouble(d, Double(i))
}
func compareIntDouble(i Int, d Double) Int {
return -compareDoubleInt(d, i)
}
func compareDoubleUint(d Double, u Uint) Int {
if d < 0 {
return IntNegOne
}
if d > math.MaxUint64 {
return IntOne
}
return compareDouble(d, Double(u))
}
func compareUintDouble(u Uint, d Double) Int {
return -compareDoubleUint(d, u)
}
func compareIntUint(i Int, u Uint) Int {
if i < 0 || u > math.MaxInt64 {
return IntNegOne
}
cmp := i - Int(u)
if cmp < 0 {
return IntNegOne
}
if cmp > 0 {
return IntOne
}
return IntZero
}
func compareUintInt(u Uint, i Int) Int {
return -compareIntUint(i, u)
}
func compareDouble(a, b Double) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}
func compareInt(a, b Int) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}
func compareUint(a, b Uint) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}

View File

@@ -16,6 +16,7 @@ package types
import (
"fmt"
"math"
"reflect"
"github.com/google/cel-go/common/types/ref"
@@ -58,17 +59,22 @@ func (d Double) Add(other ref.Val) ref.Val {
// Compare implements traits.Comparer.Compare.
func (d Double) Compare(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
if math.IsNaN(float64(d)) {
return NewErr("NaN values cannot be ordered")
}
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return NewErr("NaN values cannot be ordered")
}
return compareDouble(d, ov)
case Int:
return compareDoubleInt(d, ov)
case Uint:
return compareDoubleUint(d, ov)
default:
return MaybeNoSuchOverloadErr(other)
}
if d < otherDouble {
return IntNegOne
}
if d > otherDouble {
return IntOne
}
return IntZero
}
// ConvertToNative implements ref.Val.ConvertToNative.
@@ -158,12 +164,22 @@ func (d Double) Divide(other ref.Val) ref.Val {
// Equal implements ref.Val.Equal.
func (d Double) Equal(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
return MaybeNoSuchOverloadErr(other)
if math.IsNaN(float64(d)) {
return False
}
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return False
}
return Bool(d == ov)
case Int:
return Bool(compareDoubleInt(d, ov) == 0)
case Uint:
return Bool(compareDoubleUint(d, ov) == 0)
default:
return False
}
// TODO: Handle NaNs properly.
return Bool(d == otherDouble)
}
// Multiply implements traits.Multiplier.Multiply.

View File

@@ -135,10 +135,7 @@ func (d Duration) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (d Duration) Equal(other ref.Val) ref.Val {
otherDur, ok := other.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(other)
}
return Bool(d.Duration == otherDur.Duration)
return Bool(ok && d.Duration == otherDur.Duration)
}
// Negate implements traits.Negater.Negate.

View File

@@ -16,6 +16,7 @@ package types
import (
"fmt"
"math"
"reflect"
"strconv"
"time"
@@ -72,17 +73,19 @@ func (i Int) Add(other ref.Val) ref.Val {
// Compare implements traits.Comparer.Compare.
func (i Int) Compare(other ref.Val) ref.Val {
otherInt, ok := other.(Int)
if !ok {
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return NewErr("NaN values cannot be ordered")
}
return compareIntDouble(i, ov)
case Int:
return compareInt(i, ov)
case Uint:
return compareIntUint(i, ov)
default:
return MaybeNoSuchOverloadErr(other)
}
if i < otherInt {
return IntNegOne
}
if i > otherInt {
return IntOne
}
return IntZero
}
// ConvertToNative implements ref.Val.ConvertToNative.
@@ -208,11 +211,19 @@ func (i Int) Divide(other ref.Val) ref.Val {
// Equal implements ref.Val.Equal.
func (i Int) Equal(other ref.Val) ref.Val {
otherInt, ok := other.(Int)
if !ok {
return MaybeNoSuchOverloadErr(other)
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return False
}
return Bool(compareIntDouble(i, ov) == 0)
case Int:
return Bool(i == ov)
case Uint:
return Bool(compareIntUint(i, ov) == 0)
default:
return False
}
return Bool(i == otherInt)
}
// Modulo implements traits.Modder.Modulo.

View File

@@ -95,6 +95,18 @@ func NewJSONList(adapter ref.TypeAdapter, l *structpb.ListValue) traits.Lister {
}
}
// NewMutableList creates a new mutable list whose internal state can be modified.
//
// The mutable list only handles `Add` calls correctly as it is intended only for use within
// comprehension loops which generate an immutable result upon completion.
func NewMutableList(adapter ref.TypeAdapter) traits.Lister {
return &mutableList{
TypeAdapter: adapter,
baseList: nil,
mutableValues: []ref.Val{},
}
}
// baseList points to a list containing elements of any type.
// The `value` is an array of native values, and refValue is its reflection object.
// The `ref.TypeAdapter` enables native type to CEL type conversions.
@@ -131,28 +143,14 @@ func (l *baseList) Add(other ref.Val) ref.Val {
// Contains implements the traits.Container interface method.
func (l *baseList) Contains(elem ref.Val) ref.Val {
if IsUnknownOrError(elem) {
return elem
}
var err ref.Val
for i := 0; i < l.size; i++ {
val := l.NativeToValue(l.get(i))
cmp := elem.Equal(val)
b, ok := cmp.(Bool)
// When there is an error on the contain check, this is not necessarily terminal.
// The contains call could find the element and return True, just as though the user
// had written a per-element comparison in an exists() macro or logical ||, e.g.
// list.exists(e, e == elem)
if !ok && err == nil {
err = ValOrErr(cmp, "no such overload")
}
if b == True {
if ok && b == True {
return True
}
}
if err != nil {
return err
}
return False
}
@@ -222,25 +220,18 @@ func (l *baseList) ConvertToType(typeVal ref.Type) ref.Val {
func (l *baseList) Equal(other ref.Val) ref.Val {
otherList, ok := other.(traits.Lister)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if l.Size() != otherList.Size() {
return False
}
var maybeErr ref.Val
for i := IntZero; i < l.Size().(Int); i++ {
thisElem := l.Get(i)
otherElem := otherList.Get(i)
elemEq := thisElem.Equal(otherElem)
elemEq := Equal(thisElem, otherElem)
if elemEq == False {
return False
}
if maybeErr == nil && IsUnknownOrError(elemEq) {
maybeErr = elemEq
}
}
if maybeErr != nil {
return maybeErr
}
return True
}
@@ -279,6 +270,32 @@ func (l *baseList) Value() interface{} {
return l.value
}
// mutableList aggregates values into its internal storage. For use with internal CEL variables only.
type mutableList struct {
ref.TypeAdapter
*baseList
mutableValues []ref.Val
}
// Add copies elements from the other list into the internal storage of the mutable list.
func (l *mutableList) Add(other ref.Val) ref.Val {
otherList, ok := other.(traits.Lister)
if !ok {
return MaybeNoSuchOverloadErr(otherList)
}
for i := IntZero; i < otherList.Size().(Int); i++ {
l.mutableValues = append(l.mutableValues, otherList.Get(i))
}
return l
}
// ToImmutableList returns an immutable list based on the internal storage of the mutable list.
func (l *mutableList) ToImmutableList() traits.Lister {
// The reference to internal state is guaranteed to be safe as this call is only performed
// when mutations have been completed.
return NewRefValList(l.TypeAdapter, l.mutableValues)
}
// concatList combines two list implementations together into a view.
// The `ref.TypeAdapter` enables native type to CEL type conversions.
type concatList struct {
@@ -349,7 +366,7 @@ func (l *concatList) ConvertToType(typeVal ref.Type) ref.Val {
func (l *concatList) Equal(other ref.Val) ref.Val {
otherList, ok := other.(traits.Lister)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if l.Size() != otherList.Size() {
return False
@@ -358,7 +375,7 @@ func (l *concatList) Equal(other ref.Val) ref.Val {
for i := IntZero; i < l.Size().(Int); i++ {
thisElem := l.Get(i)
otherElem := otherList.Get(i)
elemEq := thisElem.Equal(otherElem)
elemEq := Equal(thisElem, otherElem)
if elemEq == False {
return False
}

View File

@@ -108,8 +108,6 @@ type mapAccessor interface {
// Find returns a value, if one exists, for the inpput key.
//
// If the key is not found the function returns (nil, false).
// If the input key is not valid for the map, or is Err or Unknown the function returns
// (Unknown|Err, false).
Find(ref.Val) (ref.Val, bool)
// Iterator returns an Iterator over the map key set.
@@ -135,11 +133,7 @@ type baseMap struct {
// Contains implements the traits.Container interface method.
func (m *baseMap) Contains(index ref.Val) ref.Val {
val, found := m.Find(index)
// When the index is not found and val is non-nil, this is an error or unknown value.
if !found && val != nil && IsUnknownOrError(val) {
return val
}
_, found := m.Find(index)
return Bool(found)
}
@@ -251,36 +245,23 @@ func (m *baseMap) ConvertToType(typeVal ref.Type) ref.Val {
func (m *baseMap) Equal(other ref.Val) ref.Val {
otherMap, ok := other.(traits.Mapper)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if m.Size() != otherMap.Size() {
return False
}
it := m.Iterator()
var maybeErr ref.Val
for it.HasNext() == True {
key := it.Next()
thisVal, _ := m.Find(key)
otherVal, found := otherMap.Find(key)
if !found {
if otherVal == nil {
return False
}
if maybeErr == nil {
maybeErr = MaybeNoSuchOverloadErr(otherVal)
}
continue
return False
}
valEq := thisVal.Equal(otherVal)
valEq := Equal(thisVal, otherVal)
if valEq == False {
return False
}
if maybeErr == nil && IsUnknownOrError(valEq) {
maybeErr = valEq
}
}
if maybeErr != nil {
return maybeErr
}
return True
}
@@ -325,12 +306,10 @@ type jsonStructAccessor struct {
// found.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a String, or is an Err or Unknown, the function returns
// (Unknown|Err, false).
func (a *jsonStructAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
return nil, false
}
keyVal, found := a.st[string(strKey)]
if !found {
@@ -373,39 +352,58 @@ type reflectMapAccessor struct {
// returning (value, true) if present.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a String, or is an Err or Unknown, the function returns
// (Unknown|Err, false).
func (a *reflectMapAccessor) Find(key ref.Val) (ref.Val, bool) {
if IsUnknownOrError(key) {
return MaybeNoSuchOverloadErr(key), false
}
if a.refValue.Len() == 0 {
func (m *reflectMapAccessor) Find(key ref.Val) (ref.Val, bool) {
if m.refValue.Len() == 0 {
return nil, false
}
k, err := key.ConvertToNative(a.keyType)
if err != nil {
return NewErr("unsupported key type: %v", key.Type()), false
if keyVal, found := m.findInternal(key); found {
return keyVal, true
}
refKey := reflect.ValueOf(k)
val := a.refValue.MapIndex(refKey)
if val.IsValid() {
return a.NativeToValue(val.Interface()), true
}
mapIt := a.refValue.MapRange()
for mapIt.Next() {
if refKey.Kind() == mapIt.Key().Kind() {
return nil, false
switch k := key.(type) {
// Double is not a valid proto map key type, so check for the key as an int or uint.
case Double:
if ik, ok := doubleToInt64Lossless(float64(k)); ok {
if keyVal, found := m.findInternal(Int(ik)); found {
return keyVal, true
}
}
if uk, ok := doubleToUint64Lossless(float64(k)); ok {
return m.findInternal(Uint(uk))
}
// map keys of type double are not supported.
case Int:
if uk, ok := int64ToUint64Lossless(int64(k)); ok {
return m.findInternal(Uint(uk))
}
case Uint:
if ik, ok := uint64ToInt64Lossless(uint64(k)); ok {
return m.findInternal(Int(ik))
}
}
return NewErr("unsupported key type: %v", key.Type()), false
return nil, false
}
// findInternal attempts to convert the incoming key to the map's internal native type
// and then returns the value, if found.
func (m *reflectMapAccessor) findInternal(key ref.Val) (ref.Val, bool) {
k, err := key.ConvertToNative(m.keyType)
if err != nil {
return nil, false
}
refKey := reflect.ValueOf(k)
val := m.refValue.MapIndex(refKey)
if val.IsValid() {
return m.NativeToValue(val.Interface()), true
}
return nil, false
}
// Iterator creates a Golang reflection based traits.Iterator.
func (a *reflectMapAccessor) Iterator() traits.Iterator {
func (m *reflectMapAccessor) Iterator() traits.Iterator {
return &mapIterator{
TypeAdapter: a.TypeAdapter,
mapKeys: a.refValue.MapRange(),
len: a.refValue.Len(),
TypeAdapter: m.TypeAdapter,
mapKeys: m.refValue.MapRange(),
len: m.refValue.Len(),
}
}
@@ -420,24 +418,37 @@ type refValMapAccessor struct {
// Find uses native map accesses to find the key, returning (value, true) if present.
//
// If the key is not found the function returns (nil, false).
// If the input key is an Err or Unknown, the function returns (Unknown|Err, false).
func (a *refValMapAccessor) Find(key ref.Val) (ref.Val, bool) {
if IsUnknownOrError(key) {
return key, false
}
if len(a.mapVal) == 0 {
return nil, false
}
keyVal, found := a.mapVal[key]
if found {
if keyVal, found := a.mapVal[key]; found {
return keyVal, true
}
for k := range a.mapVal {
if k.Type().TypeName() == key.Type().TypeName() {
return nil, false
switch k := key.(type) {
case Double:
if ik, ok := doubleToInt64Lossless(float64(k)); ok {
if keyVal, found := a.mapVal[Int(ik)]; found {
return keyVal, true
}
}
if uk, ok := doubleToUint64Lossless(float64(k)); ok {
keyVal, found := a.mapVal[Uint(uk)]
return keyVal, found
}
// map keys of type double are not supported.
case Int:
if uk, ok := int64ToUint64Lossless(int64(k)); ok {
keyVal, found := a.mapVal[Uint(uk)]
return keyVal, found
}
case Uint:
if ik, ok := uint64ToInt64Lossless(uint64(k)); ok {
keyVal, found := a.mapVal[Int(ik)]
return keyVal, found
}
}
return NewErr("unsupported key type: %v", key.Type()), found
return nil, false
}
// Iterator produces a new traits.Iterator which iterates over the map keys via Golang reflection.
@@ -460,12 +471,10 @@ type stringMapAccessor struct {
// Find uses native map accesses to find the key, returning (value, true) if present.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a String, or is an Err or Unknown, the function returns
// (Unknown|Err, false).
func (a *stringMapAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
return nil, false
}
keyVal, found := a.mapVal[string(strKey)]
if !found {
@@ -504,12 +513,10 @@ type stringIfaceMapAccessor struct {
// Find uses native map accesses to find the key, returning (value, true) if present.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a String, or is an Err or Unknown, the function returns
// (Unknown|Err, false).
func (a *stringIfaceMapAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
return nil, false
}
keyVal, found := a.mapVal[string(strKey)]
if !found {
@@ -542,11 +549,7 @@ type protoMap struct {
// Contains returns whether the map contains the given key.
func (m *protoMap) Contains(key ref.Val) ref.Val {
val, found := m.Find(key)
// When the index is not found and val is non-nil, this is an error or unknown value.
if !found && val != nil && IsUnknownOrError(val) {
return val
}
_, found := m.Find(key)
return Bool(found)
}
@@ -642,7 +645,7 @@ func (m *protoMap) ConvertToType(typeVal ref.Type) ref.Val {
func (m *protoMap) Equal(other ref.Val) ref.Val {
otherMap, ok := other.(traits.Mapper)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if m.value.Map.Len() != int(otherMap.Size().(Int)) {
return False
@@ -653,14 +656,10 @@ func (m *protoMap) Equal(other ref.Val) ref.Val {
valVal := m.NativeToValue(val)
otherVal, found := otherMap.Find(keyVal)
if !found {
if otherVal == nil {
retVal = False
return false
}
retVal = MaybeNoSuchOverloadErr(otherVal)
retVal = False
return false
}
valEq := valVal.Equal(otherVal)
valEq := Equal(valVal, otherVal)
if valEq != True {
retVal = valEq
return false
@@ -673,17 +672,41 @@ func (m *protoMap) Equal(other ref.Val) ref.Val {
// Find returns whether the protoreflect.Map contains the input key.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a supported proto map key type, or is an Err or Unknown,
// the function returns
// (Unknown|Err, false).
func (m *protoMap) Find(key ref.Val) (ref.Val, bool) {
if IsUnknownOrError(key) {
return key, false
if keyVal, found := m.findInternal(key); found {
return keyVal, true
}
switch k := key.(type) {
// Double is not a valid proto map key type, so check for the key as an int or uint.
case Double:
if ik, ok := doubleToInt64Lossless(float64(k)); ok {
if keyVal, found := m.findInternal(Int(ik)); found {
return keyVal, true
}
}
if uk, ok := doubleToUint64Lossless(float64(k)); ok {
return m.findInternal(Uint(uk))
}
// map keys of type double are not supported.
case Int:
if uk, ok := int64ToUint64Lossless(int64(k)); ok {
return m.findInternal(Uint(uk))
}
case Uint:
if ik, ok := uint64ToInt64Lossless(uint64(k)); ok {
return m.findInternal(Int(ik))
}
}
return nil, false
}
// findInternal attempts to convert the incoming key to the map's internal native type
// and then returns the value, if found.
func (m *protoMap) findInternal(key ref.Val) (ref.Val, bool) {
// Convert the input key to the expected protobuf key type.
ntvKey, err := key.ConvertToNative(m.value.KeyType.ReflectType())
if err != nil {
return NewErr("unsupported key type: %v", key.Type()), false
return nil, false
}
// Use protoreflection to get the key value.
val := m.value.Get(protoreflect.ValueOf(ntvKey).MapKey())
@@ -694,7 +717,7 @@ func (m *protoMap) Find(key ref.Val) (ref.Val, bool) {
switch v := val.Interface().(type) {
case protoreflect.List, protoreflect.Map:
// Maps do not support list or map values
return NewErr("unsupported map element type: (%T)%v", v, v), false
return nil, false
default:
return m.NativeToValue(v), true
}

View File

@@ -83,10 +83,7 @@ func (n Null) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (n Null) Equal(other ref.Val) ref.Val {
if NullType != other.Type() {
return ValOrErr(other, "no such overload")
}
return True
return Bool(NullType == other.Type())
}
// Type implements ref.Val.Type.

View File

@@ -109,10 +109,8 @@ func (o *protoObj) ConvertToType(typeVal ref.Type) ref.Val {
}
func (o *protoObj) Equal(other ref.Val) ref.Val {
if o.typeDesc.Name() != other.Type().TypeName() {
return MaybeNoSuchOverloadErr(other)
}
return Bool(proto.Equal(o.value, other.Value().(proto.Message)))
otherPB, ok := other.Value().(proto.Message)
return Bool(ok && pb.Equal(o.value, otherPB))
}
// IsSet tests whether a field which is defined is set to a non-default value.

View File

@@ -355,3 +355,35 @@ func uint64ToInt64Checked(v uint64) (int64, error) {
}
return int64(v), nil
}
func doubleToUint64Lossless(v float64) (uint64, bool) {
u, err := doubleToUint64Checked(v)
if err != nil {
return 0, false
}
if float64(u) != v {
return 0, false
}
return u, true
}
func doubleToInt64Lossless(v float64) (int64, bool) {
i, err := doubleToInt64Checked(v)
if err != nil {
return 0, false
}
if float64(i) != v {
return 0, false
}
return i, true
}
func int64ToUint64Lossless(v int64) (uint64, bool) {
u, err := int64ToUint64Checked(v)
return u, err == nil
}
func uint64ToInt64Lossless(v uint64) (int64, bool) {
i, err := uint64ToInt64Checked(v)
return i, err == nil
}

View File

@@ -10,6 +10,7 @@ go_library(
srcs = [
"checked.go",
"enum.go",
"equal.go",
"file.go",
"pb.go",
"type.go",
@@ -17,6 +18,7 @@ go_library(
importpath = "github.com/google/cel-go/common/types/pb",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protowire:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//reflect/protoregistry:go_default_library",
@@ -34,6 +36,7 @@ go_test(
name = "go_default_test",
size = "small",
srcs = [
"equal_test.go",
"file_test.go",
"pb_test.go",
"type_test.go",

View File

@@ -0,0 +1,205 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pb
import (
"bytes"
"reflect"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/anypb"
)
// Equal returns whether two proto.Message instances are equal using the following criteria:
//
// - Messages must share the same instance of the type descriptor
// - Known set fields are compared using semantics equality
// - Bytes are compared using bytes.Equal
// - Scalar values are compared with operator ==
// - List and map types are equal if they have the same length and all elements are equal
// - Messages are equal if they share the same descriptor and all set fields are equal
// - Unknown fields are compared using byte equality
// - NaN values are not equal to each other
// - google.protobuf.Any values are unpacked before comparison
// - If the type descriptor for a protobuf.Any cannot be found, byte equality is used rather than
// semantic equality.
//
// This method of proto equality mirrors the behavior of the C++ protobuf MessageDifferencer
// whereas the golang proto.Equal implementation mirrors the Java protobuf equals() methods
// behaviors which needed to treat NaN values as equal due to Java semantics.
func Equal(x, y proto.Message) bool {
if x == nil || y == nil {
return x == nil && y == nil
}
xRef := x.ProtoReflect()
yRef := y.ProtoReflect()
return equalMessage(xRef, yRef)
}
func equalMessage(mx, my protoreflect.Message) bool {
// Note, the original proto.Equal upon which this implementation is based does not specifically handle the
// case when both messages are invalid. It is assumed that the descriptors will be equal and that byte-wise
// comparison will be used, though the semantics of validity are neither clear, nor promised within the
// proto.Equal implementation.
if mx.IsValid() != my.IsValid() || mx.Descriptor() != my.Descriptor() {
return false
}
// This is an innovation on the default proto.Equal where protobuf.Any values are unpacked before comparison
// as otherwise the Any values are compared by bytes rather than structurally.
if isAny(mx) && isAny(my) {
ax := mx.Interface().(*anypb.Any)
ay := my.Interface().(*anypb.Any)
// If the values are not the same type url, return false.
if ax.GetTypeUrl() != ay.GetTypeUrl() {
return false
}
// If the values are byte equal, then return true.
if bytes.Equal(ax.GetValue(), ay.GetValue()) {
return true
}
// Otherwise fall through to the semantic comparison of the any values.
x, err := ax.UnmarshalNew()
if err != nil {
return false
}
y, err := ay.UnmarshalNew()
if err != nil {
return false
}
// Recursively compare the unwrapped messages to ensure nested Any values are unwrapped accordingly.
return equalMessage(x.ProtoReflect(), y.ProtoReflect())
}
// Walk the set fields to determine field-wise equality
nx := 0
equal := true
mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool {
nx++
equal = my.Has(fd) && equalField(fd, vx, my.Get(fd))
return equal
})
if !equal {
return false
}
// Establish the count of set fields on message y
ny := 0
my.Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool {
ny++
return true
})
// If the number of set fields is not equal return false.
if nx != ny {
return false
}
return equalUnknown(mx.GetUnknown(), my.GetUnknown())
}
func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
switch {
case fd.IsMap():
return equalMap(fd, x.Map(), y.Map())
case fd.IsList():
return equalList(fd, x.List(), y.List())
default:
return equalValue(fd, x, y)
}
}
func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool {
if x.Len() != y.Len() {
return false
}
equal := true
x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool {
vy := y.Get(k)
equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
return equal
})
return equal
}
func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool {
if x.Len() != y.Len() {
return false
}
for i := x.Len() - 1; i >= 0; i-- {
if !equalValue(fd, x.Get(i), y.Get(i)) {
return false
}
}
return true
}
func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
switch fd.Kind() {
case protoreflect.BoolKind:
return x.Bool() == y.Bool()
case protoreflect.EnumKind:
return x.Enum() == y.Enum()
case protoreflect.Int32Kind, protoreflect.Sint32Kind,
protoreflect.Int64Kind, protoreflect.Sint64Kind,
protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind:
return x.Int() == y.Int()
case protoreflect.Uint32Kind, protoreflect.Uint64Kind,
protoreflect.Fixed32Kind, protoreflect.Fixed64Kind:
return x.Uint() == y.Uint()
case protoreflect.FloatKind, protoreflect.DoubleKind:
return x.Float() == y.Float()
case protoreflect.StringKind:
return x.String() == y.String()
case protoreflect.BytesKind:
return bytes.Equal(x.Bytes(), y.Bytes())
case protoreflect.MessageKind, protoreflect.GroupKind:
return equalMessage(x.Message(), y.Message())
default:
return x.Interface() == y.Interface()
}
}
func equalUnknown(x, y protoreflect.RawFields) bool {
lenX := len(x)
lenY := len(y)
if lenX != lenY {
return false
}
if lenX == 0 {
return true
}
if bytes.Equal([]byte(x), []byte(y)) {
return true
}
mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
for len(x) > 0 {
fnum, _, n := protowire.ConsumeField(x)
mx[fnum] = append(mx[fnum], x[:n]...)
x = x[n:]
}
for len(y) > 0 {
fnum, _, n := protowire.ConsumeField(y)
my[fnum] = append(my[fnum], y[:n]...)
y = y[n:]
}
return reflect.DeepEqual(mx, my)
}
func isAny(m protoreflect.Message) bool {
return string(m.Descriptor().FullName()) == "google.protobuf.Any"
}

View File

@@ -151,10 +151,7 @@ func (s String) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (s String) Equal(other ref.Val) ref.Val {
otherString, ok := other.(String)
if !ok {
return MaybeNoSuchOverloadErr(other)
}
return Bool(s == otherString)
return Bool(ok && s == otherString)
}
// Match implements traits.Matcher.Match.

View File

@@ -134,10 +134,8 @@ func (t Timestamp) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (t Timestamp) Equal(other ref.Val) ref.Val {
if TimestampType != other.Type() {
return MaybeNoSuchOverloadErr(other)
}
return Bool(t.Time.Equal(other.(Timestamp).Time))
otherTime, ok := other.(Timestamp)
return Bool(ok && t.Time.Equal(otherTime.Time))
}
// Receive implements traits.Reciever.Receive.

View File

@@ -25,3 +25,8 @@ type Lister interface {
Iterable
Sizer
}
// MutableLister interface which emits an immutable result after an intermediate computation.
type MutableLister interface {
ToImmutableList() Lister
}

View File

@@ -71,10 +71,8 @@ func (t *TypeValue) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (t *TypeValue) Equal(other ref.Val) ref.Val {
if TypeType != other.Type() {
return ValOrErr(other, "no such overload")
}
return Bool(t.TypeName() == other.(ref.Type).TypeName())
otherType, ok := other.(ref.Type)
return Bool(ok && t.TypeName() == otherType.TypeName())
}
// HasTrait indicates whether the type supports the given trait.

View File

@@ -16,6 +16,7 @@ package types
import (
"fmt"
"math"
"reflect"
"strconv"
@@ -65,17 +66,19 @@ func (i Uint) Add(other ref.Val) ref.Val {
// Compare implements traits.Comparer.Compare.
func (i Uint) Compare(other ref.Val) ref.Val {
otherUint, ok := other.(Uint)
if !ok {
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return NewErr("NaN values cannot be ordered")
}
return compareUintDouble(i, ov)
case Int:
return compareUintInt(i, ov)
case Uint:
return compareUint(i, ov)
default:
return MaybeNoSuchOverloadErr(other)
}
if i < otherUint {
return IntNegOne
}
if i > otherUint {
return IntOne
}
return IntZero
}
// ConvertToNative implements ref.Val.ConvertToNative.
@@ -176,11 +179,19 @@ func (i Uint) Divide(other ref.Val) ref.Val {
// Equal implements ref.Val.Equal.
func (i Uint) Equal(other ref.Val) ref.Val {
otherUint, ok := other.(Uint)
if !ok {
return MaybeNoSuchOverloadErr(other)
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return False
}
return Bool(compareUintDouble(i, ov) == 0)
case Int:
return Bool(compareUintInt(i, ov) == 0)
case Uint:
return Bool(i == ov)
default:
return False
}
return Bool(i == otherUint)
}
// Modulo implements traits.Modder.Modulo.

View File

@@ -36,3 +36,13 @@ func IsPrimitiveType(val ref.Val) bool {
}
return false
}
// Equal returns whether the two ref.Value are heterogeneously equivalent.
func Equal(lhs ref.Val, rhs ref.Val) ref.Val {
lNull := lhs == NullValue
rNull := rhs == NullValue
if lNull || rNull {
return Bool(lNull == rNull)
}
return lhs.Equal(rhs)
}