Bump CEL to 0.11.2

This commit is contained in:
Joe Betz
2022-03-24 11:34:14 -04:00
parent e7845861a5
commit 4c90653d19
139 changed files with 2688 additions and 1637 deletions

View File

@@ -168,11 +168,9 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
if found {
ident := c.env.LookupIdent(qname)
if ident != nil {
if sel.TestOnly {
c.errors.expressionDoesNotSelectField(c.location(e))
c.setType(e, decls.Bool)
return
}
// We don't check for a TestOnly expression here since the `found` result is
// always going to be false for TestOnly expressions.
// Rewrite the node to be a variable reference to the resolved fully-qualified
// variable name.
c.setType(e, ident.GetIdent().Type)
@@ -208,7 +206,7 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
resultType = fieldType.Type
}
case kindTypeParam:
// Set the operand type to DYN to prevent assignment to a potentionally incorrect type
// Set the operand type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
// substitutions for the type param under the covers.
c.isAssignable(decls.Dyn, targetType)
@@ -323,6 +321,12 @@ func (c *checker) resolveOverload(
var resultType *exprpb.Type
var checkedRef *exprpb.Reference
for _, overload := range fn.GetFunction().Overloads {
// Determine whether the overload is currently considered.
if c.env.isOverloadDisabled(overload.GetOverloadId()) {
continue
}
// Ensure the call style for the overload matches.
if (target == nil && overload.IsInstanceFunction) ||
(target != nil && !overload.IsInstanceFunction) {
// not a compatible call style.
@@ -330,26 +334,26 @@ func (c *checker) resolveOverload(
}
overloadType := decls.NewFunctionType(overload.ResultType, overload.Params...)
if len(overload.TypeParams) > 0 {
if len(overload.GetTypeParams()) > 0 {
// Instantiate overload's type with fresh type variables.
substitutions := newMapping()
for _, typePar := range overload.TypeParams {
for _, typePar := range overload.GetTypeParams() {
substitutions.add(decls.NewTypeParamType(typePar), c.newTypeVar())
}
overloadType = substitute(substitutions, overloadType, false)
}
candidateArgTypes := overloadType.GetFunction().ArgTypes
candidateArgTypes := overloadType.GetFunction().GetArgTypes()
if c.isAssignableList(argTypes, candidateArgTypes) {
if checkedRef == nil {
checkedRef = newFunctionReference(overload.OverloadId)
checkedRef = newFunctionReference(overload.GetOverloadId())
} else {
checkedRef.OverloadId = append(checkedRef.OverloadId, overload.OverloadId)
checkedRef.OverloadId = append(checkedRef.OverloadId, overload.GetOverloadId())
}
// First matching overload, determines result type.
fnResultType := substitute(c.mappings,
overloadType.GetFunction().ResultType,
overloadType.GetFunction().GetResultType(),
false)
if resultType == nil {
resultType = fnResultType
@@ -478,7 +482,7 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
// Ranges over the keys.
varType = rangeType.GetMapType().KeyType
case kindDyn, kindError, kindTypeParam:
// Set the range type to DYN to prevent assignment to a potentionally incorrect type
// Set the range type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
// substitutions for the type param under the covers.
c.isAssignable(decls.Dyn, rangeType)

View File

@@ -121,7 +121,7 @@ type SizeEstimate struct {
}
// Add adds to another SizeEstimate and returns the sum.
// If add would result in an uint64 overflow, the result is Maxuint64.
// If add would result in an uint64 overflow, the result is math.MaxUint64.
func (se SizeEstimate) Add(sizeEstimate SizeEstimate) SizeEstimate {
return SizeEstimate{
addUint64NoOverflow(se.Min, sizeEstimate.Min),
@@ -130,7 +130,7 @@ func (se SizeEstimate) Add(sizeEstimate SizeEstimate) SizeEstimate {
}
// Multiply multiplies by another SizeEstimate and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
// If multiply would result in an uint64 overflow, the result is math.MaxUint64.
func (se SizeEstimate) Multiply(sizeEstimate SizeEstimate) SizeEstimate {
return SizeEstimate{
multiplyUint64NoOverflow(se.Min, sizeEstimate.Min),
@@ -148,7 +148,7 @@ func (se SizeEstimate) MultiplyByCostFactor(costPerUnit float64) CostEstimate {
}
// MultiplyByCost multiplies by the cost and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
// If multiply would result in an uint64 overflow, the result is math.MaxUint64.
func (se SizeEstimate) MultiplyByCost(cost CostEstimate) CostEstimate {
return CostEstimate{
multiplyUint64NoOverflow(se.Min, cost.Min),
@@ -175,7 +175,7 @@ type CostEstimate struct {
}
// Add adds the costs and returns the sum.
// If add would result in an uint64 overflow for the min or max, the value is set to Maxuint64.
// If add would result in an uint64 overflow for the min or max, the value is set to math.MaxUint64.
func (ce CostEstimate) Add(cost CostEstimate) CostEstimate {
return CostEstimate{
addUint64NoOverflow(ce.Min, cost.Min),
@@ -184,7 +184,7 @@ func (ce CostEstimate) Add(cost CostEstimate) CostEstimate {
}
// Multiply multiplies by the cost and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
// If multiply would result in an uint64 overflow, the result is math.MaxUint64.
func (ce CostEstimate) Multiply(cost CostEstimate) CostEstimate {
return CostEstimate{
multiplyUint64NoOverflow(ce.Min, cost.Min),

View File

@@ -33,6 +33,19 @@ func NewScopes() *Scopes {
}
}
// Copy creates a copy of the current Scopes values, including a copy of its parent if non-nil.
func (s *Scopes) Copy() *Scopes {
cpy := NewScopes()
if s == nil {
return cpy
}
if s.parent != nil {
cpy.parent = s.parent.Copy()
}
cpy.scopes = s.scopes.copy()
return cpy
}
// Push creates a new Scopes value which references the current Scope as its parent.
func (s *Scopes) Push() *Scopes {
return &Scopes{
@@ -80,9 +93,9 @@ func (s *Scopes) FindIdentInScope(name string) *exprpb.Decl {
return nil
}
// AddFunction adds the function Decl to the current scope.
// SetFunction adds the function Decl to the current scope.
// Note: Any previous entry for a function in the current scope with the same name is overwritten.
func (s *Scopes) AddFunction(fn *exprpb.Decl) {
func (s *Scopes) SetFunction(fn *exprpb.Decl) {
s.scopes.functions[fn.Name] = fn
}
@@ -100,13 +113,30 @@ func (s *Scopes) FindFunction(name string) *exprpb.Decl {
}
// Group is a set of Decls that is pushed on or popped off a Scopes as a unit.
// Contains separate namespaces for idenifier and function Decls.
// Contains separate namespaces for identifier and function Decls.
// (Should be named "Scope" perhaps?)
type Group struct {
idents map[string]*exprpb.Decl
functions map[string]*exprpb.Decl
}
// copy creates a new Group instance with a shallow copy of the variables and functions.
// If callers need to mutate the exprpb.Decl definitions for a Function, they should copy-on-write.
func (g *Group) copy() *Group {
cpy := &Group{
idents: make(map[string]*exprpb.Decl, len(g.idents)),
functions: make(map[string]*exprpb.Decl, len(g.functions)),
}
for n, id := range g.idents {
cpy.idents[n] = id
}
for n, fn := range g.functions {
cpy.functions[n] = fn
}
return cpy
}
// newGroup creates a new Group with empty maps for identifiers and functions.
func newGroup() *Group {
return &Group{
idents: make(map[string]*exprpb.Decl),

View File

@@ -18,6 +18,8 @@ import (
"fmt"
"strings"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/overloads"
@@ -99,6 +101,9 @@ func NewEnv(container *containers.Container, provider ref.TypeProvider, opts ...
if envOptions.crossTypeNumericComparisons {
filteredOverloadIDs = make(map[string]struct{})
}
if envOptions.validatedDeclarations != nil {
declarations = envOptions.validatedDeclarations.Copy()
}
return &Env{
container: container,
provider: provider,
@@ -117,7 +122,7 @@ func (e *Env) Add(decls ...*exprpb.Decl) error {
case *exprpb.Decl_Ident:
errMsgs = append(errMsgs, e.addIdent(sanitizeIdent(decl)))
case *exprpb.Decl_Function:
errMsgs = append(errMsgs, e.addFunction(sanitizeFunction(decl))...)
errMsgs = append(errMsgs, e.setFunction(sanitizeFunction(decl))...)
}
}
return formatError(errMsgs)
@@ -204,22 +209,22 @@ func (e *Env) addOverload(f *exprpb.Decl, overload *exprpb.Decl_FunctionDecl_Ove
return errMsgs
}
// addFunction adds the function Decl to the Env.
// setFunction adds the function Decl to the Env.
// Adds a function decl if one doesn't already exist, then adds all overloads from the Decl.
// If overload overlaps with an existing overload, adds to the errors in the Env instead.
func (e *Env) addFunction(decl *exprpb.Decl) []errorMsg {
func (e *Env) setFunction(decl *exprpb.Decl) []errorMsg {
current := e.declarations.FindFunction(decl.Name)
if current == nil {
//Add the function declaration without overloads and check the overloads below.
current = decls.NewFunction(decl.Name)
e.declarations.AddFunction(current)
} else {
// Copy on write since we don't know where this original definition came from.
current = proto.Clone(current).(*exprpb.Decl)
}
e.declarations.SetFunction(current)
errorMsgs := make([]errorMsg, 0)
for _, overload := range decl.GetFunction().GetOverloads() {
if _, found := e.filteredOverloadIDs[overload.GetOverloadId()]; found {
continue
}
errorMsgs = append(errorMsgs, e.addOverload(current, overload)...)
}
return errorMsgs
@@ -236,6 +241,12 @@ func (e *Env) addIdent(decl *exprpb.Decl) errorMsg {
return ""
}
// isOverloadDisabled returns whether the overloadID is disabled in the current environment.
func (e *Env) isOverloadDisabled(overloadID string) bool {
_, found := e.filteredOverloadIDs[overloadID]
return found
}
// sanitizeFunction replaces well-known types referenced by message name with their equivalent
// CEL built-in type instances.
func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl {
@@ -313,6 +324,12 @@ func getObjectWellKnownType(t *exprpb.Type) *exprpb.Type {
return pb.CheckedWellKnowns[t.GetMessageType()]
}
// validatedDeclarations returns a reference to the validated variable and function declaration scope stack.
// must be copied before use.
func (e *Env) validatedDeclarations() *decls.Scopes {
return e.declarations
}
// enterScope creates a new Env instance with a new innermost declaration scope.
func (e *Env) enterScope() *Env {
childDecls := e.declarations.Push()

View File

@@ -29,10 +29,6 @@ func (e *typeErrors) undeclaredReference(l common.Location, container string, na
e.ReportError(l, "undeclared reference to '%s' (in container '%s')", name, container)
}
func (e *typeErrors) expressionDoesNotSelectField(l common.Location) {
e.ReportError(l, "expression does not select a field")
}
func (e *typeErrors) typeDoesNotSupportFieldSelection(l common.Location, t *exprpb.Type) {
e.ReportError(l, "type '%s' does not support field selection", t)
}

View File

@@ -14,9 +14,12 @@
package checker
import "github.com/google/cel-go/checker/decls"
type options struct {
crossTypeNumericComparisons bool
homogeneousAggregateLiterals bool
validatedDeclarations *decls.Scopes
}
// Option is a functional option for configuring the type-checker
@@ -39,3 +42,12 @@ func HomogeneousAggregateLiterals(enabled bool) Option {
return nil
}
}
// ValidatedDeclarations provides a references to validated declarations which will be copied
// into new checker instances.
func ValidatedDeclarations(env *Env) Option {
return func(opts *options) error {
opts.validatedDeclarations = env.validatedDeclarations()
return nil
}
}

View File

@@ -149,21 +149,6 @@ func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
}
}
return true
case kindFunction:
fn1 := t1.GetFunction()
fn2 := t2.GetFunction()
if len(fn1.ArgTypes) != len(fn2.ArgTypes) {
return false
}
if !isEqualOrLessSpecific(fn1.ResultType, fn2.ResultType) {
return false
}
for i, a1 := range fn1.ArgTypes {
if !isEqualOrLessSpecific(a1, fn2.ArgTypes[i]) {
return false
}
}
return true
case kindList:
return isEqualOrLessSpecific(t1.GetListType().ElemType, t2.GetListType().ElemType)
case kindMap:
@@ -180,43 +165,26 @@ func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
/// internalIsAssignable returns true if t1 is assignable to t2.
func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// A type is always assignable to itself.
// Early terminate the call to avoid cases of infinite recursion.
if proto.Equal(t1, t2) {
return true
}
// Process type parameters.
kind1, kind2 := kindOf(t1), kindOf(t2)
if kind2 == kindTypeParam {
if t2Sub, found := m.find(t2); found {
// If the types are compatible, pick the more general type and return true
if !internalIsAssignable(m, t1, t2Sub) {
return false
}
m.add(t2, mostGeneral(t1, t2Sub))
// If t2 is a valid type substitution for t1, return true.
valid, t2HasSub := isValidTypeSubstitution(m, t1, t2)
if valid {
return true
}
if notReferencedIn(m, t2, t1) {
m.add(t2, t1)
return true
// If t2 is not a valid type sub for t1, and already has a known substitution return false
// since it is not possible for t1 to be a substitution for t2.
if !valid && t2HasSub {
return false
}
// Otherwise, fall through to check whether t1 is a possible substitution for t2.
}
if kind1 == kindTypeParam {
// For the lower type bound, we currently do not perform adjustment. The restricted
// way we use type parameters in lower type bounds, it is not necessary, but may
// become if we generalize type unification.
if t1Sub, found := m.find(t1); found {
// If the types are compatible, pick the more general type and return true
if !internalIsAssignable(m, t1Sub, t2) {
return false
}
m.add(t1, mostGeneral(t1Sub, t2))
return true
}
if notReferencedIn(m, t1, t2) {
m.add(t1, t2)
return true
}
// Return whether t1 is a valid substitution for t2. If not, do no additional checks as the
// possible type substitutions have been searched in both directions.
valid, _ := isValidTypeSubstitution(m, t2, t1)
return valid
}
// Next check for wildcard types.
@@ -262,18 +230,40 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
}
}
// isValidTypeSubstitution returns whether t2 (or its type substitution) is a valid type
// substitution for t1, and whether t2 has a type substitution in mapping m.
//
// The type t2 is a valid substitution for t1 if any of the following statements is true
// - t2 has a type substitition (t2sub) equal to t1
// - t2 has a type substitution (t2sub) assignable to t1
// - t2 does not occur within t1.
func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub bool) {
if t2Sub, found := m.find(t2); found {
kind1, kind2 := kindOf(t1), kindOf(t2)
if kind1 == kind2 && proto.Equal(t1, t2Sub) {
return true, true
}
// If the types are compatible, pick the more general type and return true
if internalIsAssignable(m, t1, t2Sub) {
m.add(t2, mostGeneral(t1, t2Sub))
return true, true
}
return false, true
}
if notReferencedIn(m, t2, t1) {
m.add(t2, t1)
return true, false
}
return false, false
}
// internalIsAssignableAbstractType returns true if the abstract type names agree and all type
// parameters are assignable.
func internalIsAssignableAbstractType(m *mapping,
a1 *exprpb.Type_AbstractType,
a2 *exprpb.Type_AbstractType) bool {
if a1.GetName() != a2.GetName() {
return false
}
if internalIsAssignableList(m, a1.GetParameterTypes(), a2.GetParameterTypes()) {
return true
}
return false
return a1.GetName() == a2.GetName() &&
internalIsAssignableList(m, a1.GetParameterTypes(), a2.GetParameterTypes())
}
// internalIsAssignableFunction returns true if the function return type and arg types are
@@ -421,15 +411,6 @@ func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
}
}
return true
case kindFunction:
fn := withinType.GetFunction()
types := flattenFunctionTypes(fn)
for _, a := range types {
if !notReferencedIn(m, t, a) {
return false
}
}
return true
case kindList:
return notReferencedIn(m, t, withinType.GetListType().ElemType)
case kindMap:
@@ -454,7 +435,6 @@ func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
}
switch kind {
case kindAbstract:
// TODO: implement!
at := t.GetAbstractType()
params := make([]*exprpb.Type, len(at.GetParameterTypes()))
for i, p := range at.GetParameterTypes() {