Bump cel-go to v0.17.7

This commit is contained in:
Cici Huang
2023-10-30 16:46:27 +00:00
parent 16fc00493b
commit 70c1f2143f
27 changed files with 225 additions and 60 deletions

View File

@@ -133,6 +133,7 @@ func PresenceTestHasCost(hasCost bool) CostTrackerOption {
func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) {
tracker := &CostTracker{
Estimator: estimator,
overloadTrackers: map[string]FunctionTracker{},
presenceTestHasCost: true,
}
for _, opt := range opts {
@@ -144,9 +145,24 @@ func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*
return tracker, nil
}
// OverloadCostTracker binds an overload ID to a runtime FunctionTracker implementation.
//
// OverloadCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or
// optional cost tracking changes.
func OverloadCostTracker(overloadID string, fnTracker FunctionTracker) CostTrackerOption {
return func(tracker *CostTracker) error {
tracker.overloadTrackers[overloadID] = fnTracker
return nil
}
}
// FunctionTracker computes the actual cost of evaluating the functions with the given arguments and result.
type FunctionTracker func(args []ref.Val, result ref.Val) *uint64
// CostTracker represents the information needed for tracking runtime cost.
type CostTracker struct {
Estimator ActualCostEstimator
overloadTrackers map[string]FunctionTracker
Limit *uint64
presenceTestHasCost bool
@@ -159,10 +175,19 @@ func (c *CostTracker) ActualCost() uint64 {
return c.cost
}
func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 {
func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result ref.Val) uint64 {
var cost uint64
if len(c.overloadTrackers) != 0 {
if tracker, found := c.overloadTrackers[call.OverloadID()]; found {
callCost := tracker(args, result)
if callCost != nil {
cost += *callCost
return cost
}
}
}
if c.Estimator != nil {
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result)
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), args, result)
if callCost != nil {
cost += *callCost
return cost
@@ -173,11 +198,11 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu
switch call.OverloadID() {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
cost += uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor))
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
// We just assume all list containment checks are O(n).
cost += c.actualSize(argValues[1])
cost += c.actualSize(args[1])
// O(min(m, n)) functions
case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString,
overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes,
@@ -185,8 +210,8 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu
// When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.),
// the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost
// of 1.
lhsSize := c.actualSize(argValues[0])
rhsSize := c.actualSize(argValues[1])
lhsSize := c.actualSize(args[0])
rhsSize := c.actualSize(args[1])
minSize := lhsSize
if rhsSize < minSize {
minSize = rhsSize
@@ -195,23 +220,23 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu
// O(m+n) functions
case overloads.AddString, overloads.AddBytes:
// In the worst case scenario, we would need to reallocate a new backing store and copy both operands over.
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
cost += uint64(math.Ceil(float64(c.actualSize(args[0])+c.actualSize(args[1])) * common.StringTraversalCostFactor))
// O(nm) functions
case overloads.MatchesString:
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor))
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(args[0]))) * common.StringTraversalCostFactor))
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor))
regexCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.RegexStringLengthCostFactor))
cost += strCost * regexCost
case overloads.ContainsString:
strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
strCost := uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.StringTraversalCostFactor))
cost += strCost * substrCost
default: