Improve CEL cost tests to catch unhandled estimates or types
This commit is contained in:
		@@ -17,6 +17,7 @@ limitations under the License.
 | 
			
		||||
package library
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math"
 | 
			
		||||
 | 
			
		||||
	"github.com/google/cel-go/checker"
 | 
			
		||||
@@ -25,9 +26,28 @@ import (
 | 
			
		||||
	"github.com/google/cel-go/common/types"
 | 
			
		||||
	"github.com/google/cel-go/common/types/ref"
 | 
			
		||||
	"github.com/google/cel-go/common/types/traits"
 | 
			
		||||
 | 
			
		||||
	"k8s.io/apiserver/pkg/cel"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// panicOnUnknown makes cost estimate functions panic on unrecognized functions.
 | 
			
		||||
// This is only set to true for unit tests.
 | 
			
		||||
var panicOnUnknown = false
 | 
			
		||||
 | 
			
		||||
// builtInFunctions is a list of functions used in cost tests that are not handled by CostEstimator.
 | 
			
		||||
var knownUnhandledFunctions = map[string]bool{
 | 
			
		||||
	"uint":          true,
 | 
			
		||||
	"duration":      true,
 | 
			
		||||
	"bytes":         true,
 | 
			
		||||
	"timestamp":     true,
 | 
			
		||||
	"value":         true,
 | 
			
		||||
	"_==_":          true,
 | 
			
		||||
	"_&&_":          true,
 | 
			
		||||
	"_>_":           true,
 | 
			
		||||
	"!_":            true,
 | 
			
		||||
	"strings.quote": true,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
 | 
			
		||||
type CostEstimator struct {
 | 
			
		||||
	// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
 | 
			
		||||
@@ -106,7 +126,7 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
 | 
			
		||||
			cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
 | 
			
		||||
			return &cost
 | 
			
		||||
		}
 | 
			
		||||
	case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
 | 
			
		||||
	case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast":
 | 
			
		||||
		// IP and CIDR accessors are nominal cost.
 | 
			
		||||
		cost := uint64(1)
 | 
			
		||||
		return &cost
 | 
			
		||||
@@ -185,6 +205,13 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
 | 
			
		||||
	case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub":
 | 
			
		||||
		cost := uint64(1)
 | 
			
		||||
		return &cost
 | 
			
		||||
	case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
 | 
			
		||||
		// url accessors
 | 
			
		||||
		cost := uint64(1)
 | 
			
		||||
		return &cost
 | 
			
		||||
	}
 | 
			
		||||
	if panicOnUnknown && !knownUnhandledFunctions[function] {
 | 
			
		||||
		panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -359,7 +386,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
 | 
			
		||||
			// So we double the cost of parsing the string.
 | 
			
		||||
			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)}
 | 
			
		||||
		}
 | 
			
		||||
	case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
 | 
			
		||||
	case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast":
 | 
			
		||||
		// IP and CIDR accessors are nominal cost.
 | 
			
		||||
		return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
 | 
			
		||||
	case "containsIP":
 | 
			
		||||
@@ -414,6 +441,12 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
 | 
			
		||||
		return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
 | 
			
		||||
	case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub":
 | 
			
		||||
		return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
 | 
			
		||||
	case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
 | 
			
		||||
		// url accessors
 | 
			
		||||
		return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
 | 
			
		||||
	}
 | 
			
		||||
	if panicOnUnknown && !knownUnhandledFunctions[function] {
 | 
			
		||||
		panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -422,6 +455,10 @@ func actualSize(value ref.Val) uint64 {
 | 
			
		||||
	if sz, ok := value.(traits.Sizer); ok {
 | 
			
		||||
		return uint64(sz.Size().(types.Int))
 | 
			
		||||
	}
 | 
			
		||||
	if panicOnUnknown {
 | 
			
		||||
		// debug.PrintStack()
 | 
			
		||||
		panic(fmt.Errorf("actualSize: non-sizer type %T", value))
 | 
			
		||||
	}
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1053,6 +1053,10 @@ func TestSetsCost(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) {
 | 
			
		||||
	originalPanicOnUnknown := panicOnUnknown
 | 
			
		||||
	panicOnUnknown = true
 | 
			
		||||
	t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown })
 | 
			
		||||
 | 
			
		||||
	est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
 | 
			
		||||
	env, err := cel.NewEnv(
 | 
			
		||||
		ext.Strings(ext.StringsVersion(2)),
 | 
			
		||||
@@ -1168,6 +1172,11 @@ func TestSize(t *testing.T) {
 | 
			
		||||
			expectSize: checker.SizeEstimate{Min: 2, Max: 4},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	originalPanicOnUnknown := panicOnUnknown
 | 
			
		||||
	panicOnUnknown = true
 | 
			
		||||
	t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown })
 | 
			
		||||
 | 
			
		||||
	est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
 | 
			
		||||
	for _, tc := range cases {
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user