generated: ./hack/update-vendor.sh

This commit is contained in:
Jiahui Feng
2024-04-22 10:54:32 -07:00
parent 51af569512
commit 350fcf957e
189 changed files with 14707 additions and 9152 deletions

View File

@@ -1,26 +0,0 @@
Copyright 2021 The ANTLR Project
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,68 +0,0 @@
/*
Package antlr implements the Go version of the ANTLR 4 runtime.
# The ANTLR Tool
ANTLR (ANother Tool for Language Recognition) is a powerful parser generator for reading, processing, executing,
or translating structured text or binary files. It's widely used to build languages, tools, and frameworks.
From a grammar, ANTLR generates a parser that can build parse trees and also generates a listener interface
(or visitor) that makes it easy to respond to the recognition of phrases of interest.
# Code Generation
ANTLR supports the generation of code in a number of [target languages], and the generated code is supported by a
runtime library, written specifically to support the generated code in the target language. This library is the
runtime for the Go target.
To generate code for the go target, it is generally recommended to place the source grammar files in a package of
their own, and use the `.sh` script method of generating code, using the go generate directive. In that same directory
it is usual, though not required, to place the antlr tool that should be used to generate the code. That does mean
that the antlr tool JAR file will be checked in to your source code control though, so you are free to use any other
way of specifying the version of the ANTLR tool to use, such as aliasing in `.zshrc` or equivalent, or a profile in
your IDE, or configuration in your CI system.
Here is a general template for an ANTLR based recognizer in Go:
.
├── myproject
├── parser
│ ├── mygrammar.g4
│ ├── antlr-4.12.0-complete.jar
│ ├── error_listeners.go
│ ├── generate.go
│ ├── generate.sh
├── go.mod
├── go.sum
├── main.go
└── main_test.go
Make sure that the package statement in your grammar file(s) reflects the go package they exist in.
The generate.go file then looks like this:
package parser
//go:generate ./generate.sh
And the generate.sh file will look similar to this:
#!/bin/sh
alias antlr4='java -Xmx500M -cp "./antlr4-4.12.0-complete.jar:$CLASSPATH" org.antlr.v4.Tool'
antlr4 -Dlanguage=Go -no-visitor -package parser *.g4
depending on whether you want visitors or listeners or any other ANTLR options.
From the command line at the root of your package “myproject” you can then simply issue the command:
go generate ./...
# Copyright Notice
Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
Use of this file is governed by the BSD 3-clause license, which can be found in the [LICENSE.txt] file in the project root.
[target languages]: https://github.com/antlr/antlr4/tree/master/runtime
[LICENSE.txt]: https://github.com/antlr/antlr4/blob/master/LICENSE.txt
*/
package antlr

View File

@@ -1,303 +0,0 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"fmt"
)
// ATNConfig is a tuple: (ATN state, predicted alt, syntactic, semantic
// context). The syntactic context is a graph-structured stack node whose
// path(s) to the root is the rule invocation(s) chain used to arrive at the
// state. The semantic context is the tree of semantic predicates encountered
// before reaching an ATN state.
type ATNConfig interface {
Equals(o Collectable[ATNConfig]) bool
Hash() int
GetState() ATNState
GetAlt() int
GetSemanticContext() SemanticContext
GetContext() PredictionContext
SetContext(PredictionContext)
GetReachesIntoOuterContext() int
SetReachesIntoOuterContext(int)
String() string
getPrecedenceFilterSuppressed() bool
setPrecedenceFilterSuppressed(bool)
}
type BaseATNConfig struct {
precedenceFilterSuppressed bool
state ATNState
alt int
context PredictionContext
semanticContext SemanticContext
reachesIntoOuterContext int
}
func NewBaseATNConfig7(old *BaseATNConfig) ATNConfig { // TODO: Dup
return &BaseATNConfig{
state: old.state,
alt: old.alt,
context: old.context,
semanticContext: old.semanticContext,
reachesIntoOuterContext: old.reachesIntoOuterContext,
}
}
func NewBaseATNConfig6(state ATNState, alt int, context PredictionContext) *BaseATNConfig {
return NewBaseATNConfig5(state, alt, context, SemanticContextNone)
}
func NewBaseATNConfig5(state ATNState, alt int, context PredictionContext, semanticContext SemanticContext) *BaseATNConfig {
if semanticContext == nil {
panic("semanticContext cannot be nil") // TODO: Necessary?
}
return &BaseATNConfig{state: state, alt: alt, context: context, semanticContext: semanticContext}
}
func NewBaseATNConfig4(c ATNConfig, state ATNState) *BaseATNConfig {
return NewBaseATNConfig(c, state, c.GetContext(), c.GetSemanticContext())
}
func NewBaseATNConfig3(c ATNConfig, state ATNState, semanticContext SemanticContext) *BaseATNConfig {
return NewBaseATNConfig(c, state, c.GetContext(), semanticContext)
}
func NewBaseATNConfig2(c ATNConfig, semanticContext SemanticContext) *BaseATNConfig {
return NewBaseATNConfig(c, c.GetState(), c.GetContext(), semanticContext)
}
func NewBaseATNConfig1(c ATNConfig, state ATNState, context PredictionContext) *BaseATNConfig {
return NewBaseATNConfig(c, state, context, c.GetSemanticContext())
}
func NewBaseATNConfig(c ATNConfig, state ATNState, context PredictionContext, semanticContext SemanticContext) *BaseATNConfig {
if semanticContext == nil {
panic("semanticContext cannot be nil")
}
return &BaseATNConfig{
state: state,
alt: c.GetAlt(),
context: context,
semanticContext: semanticContext,
reachesIntoOuterContext: c.GetReachesIntoOuterContext(),
precedenceFilterSuppressed: c.getPrecedenceFilterSuppressed(),
}
}
func (b *BaseATNConfig) getPrecedenceFilterSuppressed() bool {
return b.precedenceFilterSuppressed
}
func (b *BaseATNConfig) setPrecedenceFilterSuppressed(v bool) {
b.precedenceFilterSuppressed = v
}
func (b *BaseATNConfig) GetState() ATNState {
return b.state
}
func (b *BaseATNConfig) GetAlt() int {
return b.alt
}
func (b *BaseATNConfig) SetContext(v PredictionContext) {
b.context = v
}
func (b *BaseATNConfig) GetContext() PredictionContext {
return b.context
}
func (b *BaseATNConfig) GetSemanticContext() SemanticContext {
return b.semanticContext
}
func (b *BaseATNConfig) GetReachesIntoOuterContext() int {
return b.reachesIntoOuterContext
}
func (b *BaseATNConfig) SetReachesIntoOuterContext(v int) {
b.reachesIntoOuterContext = v
}
// Equals is the default comparison function for an ATNConfig when no specialist implementation is required
// for a collection.
//
// An ATN configuration is equal to another if both have the same state, they
// predict the same alternative, and syntactic/semantic contexts are the same.
func (b *BaseATNConfig) Equals(o Collectable[ATNConfig]) bool {
if b == o {
return true
} else if o == nil {
return false
}
var other, ok = o.(*BaseATNConfig)
if !ok {
return false
}
var equal bool
if b.context == nil {
equal = other.context == nil
} else {
equal = b.context.Equals(other.context)
}
var (
nums = b.state.GetStateNumber() == other.state.GetStateNumber()
alts = b.alt == other.alt
cons = b.semanticContext.Equals(other.semanticContext)
sups = b.precedenceFilterSuppressed == other.precedenceFilterSuppressed
)
return nums && alts && cons && sups && equal
}
// Hash is the default hash function for BaseATNConfig, when no specialist hash function
// is required for a collection
func (b *BaseATNConfig) Hash() int {
var c int
if b.context != nil {
c = b.context.Hash()
}
h := murmurInit(7)
h = murmurUpdate(h, b.state.GetStateNumber())
h = murmurUpdate(h, b.alt)
h = murmurUpdate(h, c)
h = murmurUpdate(h, b.semanticContext.Hash())
return murmurFinish(h, 4)
}
func (b *BaseATNConfig) String() string {
var s1, s2, s3 string
if b.context != nil {
s1 = ",[" + fmt.Sprint(b.context) + "]"
}
if b.semanticContext != SemanticContextNone {
s2 = "," + fmt.Sprint(b.semanticContext)
}
if b.reachesIntoOuterContext > 0 {
s3 = ",up=" + fmt.Sprint(b.reachesIntoOuterContext)
}
return fmt.Sprintf("(%v,%v%v%v%v)", b.state, b.alt, s1, s2, s3)
}
type LexerATNConfig struct {
*BaseATNConfig
lexerActionExecutor *LexerActionExecutor
passedThroughNonGreedyDecision bool
}
func NewLexerATNConfig6(state ATNState, alt int, context PredictionContext) *LexerATNConfig {
return &LexerATNConfig{BaseATNConfig: NewBaseATNConfig5(state, alt, context, SemanticContextNone)}
}
func NewLexerATNConfig5(state ATNState, alt int, context PredictionContext, lexerActionExecutor *LexerActionExecutor) *LexerATNConfig {
return &LexerATNConfig{
BaseATNConfig: NewBaseATNConfig5(state, alt, context, SemanticContextNone),
lexerActionExecutor: lexerActionExecutor,
}
}
func NewLexerATNConfig4(c *LexerATNConfig, state ATNState) *LexerATNConfig {
return &LexerATNConfig{
BaseATNConfig: NewBaseATNConfig(c, state, c.GetContext(), c.GetSemanticContext()),
lexerActionExecutor: c.lexerActionExecutor,
passedThroughNonGreedyDecision: checkNonGreedyDecision(c, state),
}
}
func NewLexerATNConfig3(c *LexerATNConfig, state ATNState, lexerActionExecutor *LexerActionExecutor) *LexerATNConfig {
return &LexerATNConfig{
BaseATNConfig: NewBaseATNConfig(c, state, c.GetContext(), c.GetSemanticContext()),
lexerActionExecutor: lexerActionExecutor,
passedThroughNonGreedyDecision: checkNonGreedyDecision(c, state),
}
}
func NewLexerATNConfig2(c *LexerATNConfig, state ATNState, context PredictionContext) *LexerATNConfig {
return &LexerATNConfig{
BaseATNConfig: NewBaseATNConfig(c, state, context, c.GetSemanticContext()),
lexerActionExecutor: c.lexerActionExecutor,
passedThroughNonGreedyDecision: checkNonGreedyDecision(c, state),
}
}
func NewLexerATNConfig1(state ATNState, alt int, context PredictionContext) *LexerATNConfig {
return &LexerATNConfig{BaseATNConfig: NewBaseATNConfig5(state, alt, context, SemanticContextNone)}
}
// Hash is the default hash function for LexerATNConfig objects, it can be used directly or via
// the default comparator [ObjEqComparator].
func (l *LexerATNConfig) Hash() int {
var f int
if l.passedThroughNonGreedyDecision {
f = 1
} else {
f = 0
}
h := murmurInit(7)
h = murmurUpdate(h, l.state.GetStateNumber())
h = murmurUpdate(h, l.alt)
h = murmurUpdate(h, l.context.Hash())
h = murmurUpdate(h, l.semanticContext.Hash())
h = murmurUpdate(h, f)
h = murmurUpdate(h, l.lexerActionExecutor.Hash())
h = murmurFinish(h, 6)
return h
}
// Equals is the default comparison function for LexerATNConfig objects, it can be used directly or via
// the default comparator [ObjEqComparator].
func (l *LexerATNConfig) Equals(other Collectable[ATNConfig]) bool {
if l == other {
return true
}
var othert, ok = other.(*LexerATNConfig)
if l == other {
return true
} else if !ok {
return false
} else if l.passedThroughNonGreedyDecision != othert.passedThroughNonGreedyDecision {
return false
}
var b bool
if l.lexerActionExecutor != nil {
b = !l.lexerActionExecutor.Equals(othert.lexerActionExecutor)
} else {
b = othert.lexerActionExecutor != nil
}
if b {
return false
}
return l.BaseATNConfig.Equals(othert.BaseATNConfig)
}
func checkNonGreedyDecision(source *LexerATNConfig, target ATNState) bool {
var ds, ok = target.(DecisionState)
return source.passedThroughNonGreedyDecision || (ok && ds.getNonGreedy())
}

View File

@@ -1,441 +0,0 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"fmt"
)
type ATNConfigSet interface {
Hash() int
Equals(o Collectable[ATNConfig]) bool
Add(ATNConfig, *DoubleDict) bool
AddAll([]ATNConfig) bool
GetStates() *JStore[ATNState, Comparator[ATNState]]
GetPredicates() []SemanticContext
GetItems() []ATNConfig
OptimizeConfigs(interpreter *BaseATNSimulator)
Length() int
IsEmpty() bool
Contains(ATNConfig) bool
ContainsFast(ATNConfig) bool
Clear()
String() string
HasSemanticContext() bool
SetHasSemanticContext(v bool)
ReadOnly() bool
SetReadOnly(bool)
GetConflictingAlts() *BitSet
SetConflictingAlts(*BitSet)
Alts() *BitSet
FullContext() bool
GetUniqueAlt() int
SetUniqueAlt(int)
GetDipsIntoOuterContext() bool
SetDipsIntoOuterContext(bool)
}
// BaseATNConfigSet is a specialized set of ATNConfig that tracks information
// about its elements and can combine similar configurations using a
// graph-structured stack.
type BaseATNConfigSet struct {
cachedHash int
// configLookup is used to determine whether two BaseATNConfigSets are equal. We
// need all configurations with the same (s, i, _, semctx) to be equal. A key
// effectively doubles the number of objects associated with ATNConfigs. All
// keys are hashed by (s, i, _, pi), not including the context. Wiped out when
// read-only because a set becomes a DFA state.
configLookup *JStore[ATNConfig, Comparator[ATNConfig]]
// configs is the added elements.
configs []ATNConfig
// TODO: These fields make me pretty uncomfortable, but it is nice to pack up
// info together because it saves recomputation. Can we track conflicts as they
// are added to save scanning configs later?
conflictingAlts *BitSet
// dipsIntoOuterContext is used by parsers and lexers. In a lexer, it indicates
// we hit a pred while computing a closure operation. Do not make a DFA state
// from the BaseATNConfigSet in this case. TODO: How is this used by parsers?
dipsIntoOuterContext bool
// fullCtx is whether it is part of a full context LL prediction. Used to
// determine how to merge $. It is a wildcard with SLL, but not for an LL
// context merge.
fullCtx bool
// Used in parser and lexer. In lexer, it indicates we hit a pred
// while computing a closure operation. Don't make a DFA state from a.
hasSemanticContext bool
// readOnly is whether it is read-only. Do not
// allow any code to manipulate the set if true because DFA states will point at
// sets and those must not change. It not, protect other fields; conflictingAlts
// in particular, which is assigned after readOnly.
readOnly bool
// TODO: These fields make me pretty uncomfortable, but it is nice to pack up
// info together because it saves recomputation. Can we track conflicts as they
// are added to save scanning configs later?
uniqueAlt int
}
func (b *BaseATNConfigSet) Alts() *BitSet {
alts := NewBitSet()
for _, it := range b.configs {
alts.add(it.GetAlt())
}
return alts
}
func NewBaseATNConfigSet(fullCtx bool) *BaseATNConfigSet {
return &BaseATNConfigSet{
cachedHash: -1,
configLookup: NewJStore[ATNConfig, Comparator[ATNConfig]](aConfCompInst),
fullCtx: fullCtx,
}
}
// Add merges contexts with existing configs for (s, i, pi, _), where s is the
// ATNConfig.state, i is the ATNConfig.alt, and pi is the
// ATNConfig.semanticContext. We use (s,i,pi) as the key. Updates
// dipsIntoOuterContext and hasSemanticContext when necessary.
func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
if b.readOnly {
panic("set is read-only")
}
if config.GetSemanticContext() != SemanticContextNone {
b.hasSemanticContext = true
}
if config.GetReachesIntoOuterContext() > 0 {
b.dipsIntoOuterContext = true
}
existing, present := b.configLookup.Put(config)
// The config was not already in the set
//
if !present {
b.cachedHash = -1
b.configs = append(b.configs, config) // Track order here
return true
}
// Merge a previous (s, i, pi, _) with it and save the result
rootIsWildcard := !b.fullCtx
merged := merge(existing.GetContext(), config.GetContext(), rootIsWildcard, mergeCache)
// No need to check for existing.context because config.context is in the cache,
// since the only way to create new graphs is the "call rule" and here. We cache
// at both places.
existing.SetReachesIntoOuterContext(intMax(existing.GetReachesIntoOuterContext(), config.GetReachesIntoOuterContext()))
// Preserve the precedence filter suppression during the merge
if config.getPrecedenceFilterSuppressed() {
existing.setPrecedenceFilterSuppressed(true)
}
// Replace the context because there is no need to do alt mapping
existing.SetContext(merged)
return true
}
func (b *BaseATNConfigSet) GetStates() *JStore[ATNState, Comparator[ATNState]] {
// states uses the standard comparator provided by the ATNState instance
//
states := NewJStore[ATNState, Comparator[ATNState]](aStateEqInst)
for i := 0; i < len(b.configs); i++ {
states.Put(b.configs[i].GetState())
}
return states
}
func (b *BaseATNConfigSet) HasSemanticContext() bool {
return b.hasSemanticContext
}
func (b *BaseATNConfigSet) SetHasSemanticContext(v bool) {
b.hasSemanticContext = v
}
func (b *BaseATNConfigSet) GetPredicates() []SemanticContext {
preds := make([]SemanticContext, 0)
for i := 0; i < len(b.configs); i++ {
c := b.configs[i].GetSemanticContext()
if c != SemanticContextNone {
preds = append(preds, c)
}
}
return preds
}
func (b *BaseATNConfigSet) GetItems() []ATNConfig {
return b.configs
}
func (b *BaseATNConfigSet) OptimizeConfigs(interpreter *BaseATNSimulator) {
if b.readOnly {
panic("set is read-only")
}
if b.configLookup.Len() == 0 {
return
}
for i := 0; i < len(b.configs); i++ {
config := b.configs[i]
config.SetContext(interpreter.getCachedContext(config.GetContext()))
}
}
func (b *BaseATNConfigSet) AddAll(coll []ATNConfig) bool {
for i := 0; i < len(coll); i++ {
b.Add(coll[i], nil)
}
return false
}
// Compare is a hack function just to verify that adding DFAstares to the known
// set works, so long as comparison of ATNConfigSet s works. For that to work, we
// need to make sure that the set of ATNConfigs in two sets are equivalent. We can't
// know the order, so we do this inefficient hack. If this proves the point, then
// we can change the config set to a better structure.
func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool {
if len(b.configs) != len(bs.configs) {
return false
}
for _, c := range b.configs {
found := false
for _, c2 := range bs.configs {
if c.Equals(c2) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
func (b *BaseATNConfigSet) Equals(other Collectable[ATNConfig]) bool {
if b == other {
return true
} else if _, ok := other.(*BaseATNConfigSet); !ok {
return false
}
other2 := other.(*BaseATNConfigSet)
return b.configs != nil &&
b.fullCtx == other2.fullCtx &&
b.uniqueAlt == other2.uniqueAlt &&
b.conflictingAlts == other2.conflictingAlts &&
b.hasSemanticContext == other2.hasSemanticContext &&
b.dipsIntoOuterContext == other2.dipsIntoOuterContext &&
b.Compare(other2)
}
func (b *BaseATNConfigSet) Hash() int {
if b.readOnly {
if b.cachedHash == -1 {
b.cachedHash = b.hashCodeConfigs()
}
return b.cachedHash
}
return b.hashCodeConfigs()
}
func (b *BaseATNConfigSet) hashCodeConfigs() int {
h := 1
for _, config := range b.configs {
h = 31*h + config.Hash()
}
return h
}
func (b *BaseATNConfigSet) Length() int {
return len(b.configs)
}
func (b *BaseATNConfigSet) IsEmpty() bool {
return len(b.configs) == 0
}
func (b *BaseATNConfigSet) Contains(item ATNConfig) bool {
if b.configLookup == nil {
panic("not implemented for read-only sets")
}
return b.configLookup.Contains(item)
}
func (b *BaseATNConfigSet) ContainsFast(item ATNConfig) bool {
if b.configLookup == nil {
panic("not implemented for read-only sets")
}
return b.configLookup.Contains(item) // TODO: containsFast is not implemented for Set
}
func (b *BaseATNConfigSet) Clear() {
if b.readOnly {
panic("set is read-only")
}
b.configs = make([]ATNConfig, 0)
b.cachedHash = -1
b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](atnConfCompInst)
}
func (b *BaseATNConfigSet) FullContext() bool {
return b.fullCtx
}
func (b *BaseATNConfigSet) GetDipsIntoOuterContext() bool {
return b.dipsIntoOuterContext
}
func (b *BaseATNConfigSet) SetDipsIntoOuterContext(v bool) {
b.dipsIntoOuterContext = v
}
func (b *BaseATNConfigSet) GetUniqueAlt() int {
return b.uniqueAlt
}
func (b *BaseATNConfigSet) SetUniqueAlt(v int) {
b.uniqueAlt = v
}
func (b *BaseATNConfigSet) GetConflictingAlts() *BitSet {
return b.conflictingAlts
}
func (b *BaseATNConfigSet) SetConflictingAlts(v *BitSet) {
b.conflictingAlts = v
}
func (b *BaseATNConfigSet) ReadOnly() bool {
return b.readOnly
}
func (b *BaseATNConfigSet) SetReadOnly(readOnly bool) {
b.readOnly = readOnly
if readOnly {
b.configLookup = nil // Read only, so no need for the lookup cache
}
}
func (b *BaseATNConfigSet) String() string {
s := "["
for i, c := range b.configs {
s += c.String()
if i != len(b.configs)-1 {
s += ", "
}
}
s += "]"
if b.hasSemanticContext {
s += ",hasSemanticContext=" + fmt.Sprint(b.hasSemanticContext)
}
if b.uniqueAlt != ATNInvalidAltNumber {
s += ",uniqueAlt=" + fmt.Sprint(b.uniqueAlt)
}
if b.conflictingAlts != nil {
s += ",conflictingAlts=" + b.conflictingAlts.String()
}
if b.dipsIntoOuterContext {
s += ",dipsIntoOuterContext"
}
return s
}
type OrderedATNConfigSet struct {
*BaseATNConfigSet
}
func NewOrderedATNConfigSet() *OrderedATNConfigSet {
b := NewBaseATNConfigSet(false)
// This set uses the standard Hash() and Equals() from ATNConfig
b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst)
return &OrderedATNConfigSet{BaseATNConfigSet: b}
}
func hashATNConfig(i interface{}) int {
o := i.(ATNConfig)
hash := 7
hash = 31*hash + o.GetState().GetStateNumber()
hash = 31*hash + o.GetAlt()
hash = 31*hash + o.GetSemanticContext().Hash()
return hash
}
func equalATNConfigs(a, b interface{}) bool {
if a == nil || b == nil {
return false
}
if a == b {
return true
}
var ai, ok = a.(ATNConfig)
var bi, ok1 = b.(ATNConfig)
if !ok || !ok1 {
return false
}
if ai.GetState().GetStateNumber() != bi.GetState().GetStateNumber() {
return false
}
if ai.GetAlt() != bi.GetAlt() {
return false
}
return ai.GetSemanticContext().Equals(bi.GetSemanticContext())
}

View File

@@ -1,113 +0,0 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
type InputStream struct {
name string
index int
data []rune
size int
}
func NewInputStream(data string) *InputStream {
is := new(InputStream)
is.name = "<empty>"
is.index = 0
is.data = []rune(data)
is.size = len(is.data) // number of runes
return is
}
func (is *InputStream) reset() {
is.index = 0
}
func (is *InputStream) Consume() {
if is.index >= is.size {
// assert is.LA(1) == TokenEOF
panic("cannot consume EOF")
}
is.index++
}
func (is *InputStream) LA(offset int) int {
if offset == 0 {
return 0 // nil
}
if offset < 0 {
offset++ // e.g., translate LA(-1) to use offset=0
}
pos := is.index + offset - 1
if pos < 0 || pos >= is.size { // invalid
return TokenEOF
}
return int(is.data[pos])
}
func (is *InputStream) LT(offset int) int {
return is.LA(offset)
}
func (is *InputStream) Index() int {
return is.index
}
func (is *InputStream) Size() int {
return is.size
}
// mark/release do nothing we have entire buffer
func (is *InputStream) Mark() int {
return -1
}
func (is *InputStream) Release(marker int) {
}
func (is *InputStream) Seek(index int) {
if index <= is.index {
is.index = index // just jump don't update stream state (line,...)
return
}
// seek forward
is.index = intMin(index, is.size)
}
func (is *InputStream) GetText(start int, stop int) string {
if stop >= is.size {
stop = is.size - 1
}
if start >= is.size {
return ""
}
return string(is.data[start : stop+1])
}
func (is *InputStream) GetTextFromTokens(start, stop Token) string {
if start != nil && stop != nil {
return is.GetTextFromInterval(NewInterval(start.GetTokenIndex(), stop.GetTokenIndex()))
}
return ""
}
func (is *InputStream) GetTextFromInterval(i *Interval) string {
return is.GetText(i.Start, i.Stop)
}
func (*InputStream) GetSourceName() string {
return "Obtained from string"
}
func (is *InputStream) String() string {
return string(is.data)
}

View File

@@ -1,198 +0,0 @@
package antlr
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
import (
"sort"
)
// Collectable is an interface that a struct should implement if it is to be
// usable as a key in these collections.
type Collectable[T any] interface {
Hash() int
Equals(other Collectable[T]) bool
}
type Comparator[T any] interface {
Hash1(o T) int
Equals2(T, T) bool
}
// JStore implements a container that allows the use of a struct to calculate the key
// for a collection of values akin to map. This is not meant to be a full-blown HashMap but just
// serve the needs of the ANTLR Go runtime.
//
// For ease of porting the logic of the runtime from the master target (Java), this collection
// operates in a similar way to Java, in that it can use any struct that supplies a Hash() and Equals()
// function as the key. The values are stored in a standard go map which internally is a form of hashmap
// itself, the key for the go map is the hash supplied by the key object. The collection is able to deal with
// hash conflicts by using a simple slice of values associated with the hash code indexed bucket. That isn't
// particularly efficient, but it is simple, and it works. As this is specifically for the ANTLR runtime, and
// we understand the requirements, then this is fine - this is not a general purpose collection.
type JStore[T any, C Comparator[T]] struct {
store map[int][]T
len int
comparator Comparator[T]
}
func NewJStore[T any, C Comparator[T]](comparator Comparator[T]) *JStore[T, C] {
if comparator == nil {
panic("comparator cannot be nil")
}
s := &JStore[T, C]{
store: make(map[int][]T, 1),
comparator: comparator,
}
return s
}
// Put will store given value in the collection. Note that the key for storage is generated from
// the value itself - this is specifically because that is what ANTLR needs - this would not be useful
// as any kind of general collection.
//
// If the key has a hash conflict, then the value will be added to the slice of values associated with the
// hash, unless the value is already in the slice, in which case the existing value is returned. Value equivalence is
// tested by calling the equals() method on the key.
//
// # If the given value is already present in the store, then the existing value is returned as v and exists is set to true
//
// If the given value is not present in the store, then the value is added to the store and returned as v and exists is set to false.
func (s *JStore[T, C]) Put(value T) (v T, exists bool) { //nolint:ireturn
kh := s.comparator.Hash1(value)
for _, v1 := range s.store[kh] {
if s.comparator.Equals2(value, v1) {
return v1, true
}
}
s.store[kh] = append(s.store[kh], value)
s.len++
return value, false
}
// Get will return the value associated with the key - the type of the key is the same type as the value
// which would not generally be useful, but this is a specific thing for ANTLR where the key is
// generated using the object we are going to store.
func (s *JStore[T, C]) Get(key T) (T, bool) { //nolint:ireturn
kh := s.comparator.Hash1(key)
for _, v := range s.store[kh] {
if s.comparator.Equals2(key, v) {
return v, true
}
}
return key, false
}
// Contains returns true if the given key is present in the store
func (s *JStore[T, C]) Contains(key T) bool { //nolint:ireturn
_, present := s.Get(key)
return present
}
func (s *JStore[T, C]) SortedSlice(less func(i, j T) bool) []T {
vs := make([]T, 0, len(s.store))
for _, v := range s.store {
vs = append(vs, v...)
}
sort.Slice(vs, func(i, j int) bool {
return less(vs[i], vs[j])
})
return vs
}
func (s *JStore[T, C]) Each(f func(T) bool) {
for _, e := range s.store {
for _, v := range e {
f(v)
}
}
}
func (s *JStore[T, C]) Len() int {
return s.len
}
func (s *JStore[T, C]) Values() []T {
vs := make([]T, 0, len(s.store))
for _, e := range s.store {
for _, v := range e {
vs = append(vs, v)
}
}
return vs
}
type entry[K, V any] struct {
key K
val V
}
type JMap[K, V any, C Comparator[K]] struct {
store map[int][]*entry[K, V]
len int
comparator Comparator[K]
}
func NewJMap[K, V any, C Comparator[K]](comparator Comparator[K]) *JMap[K, V, C] {
return &JMap[K, V, C]{
store: make(map[int][]*entry[K, V], 1),
comparator: comparator,
}
}
func (m *JMap[K, V, C]) Put(key K, val V) {
kh := m.comparator.Hash1(key)
m.store[kh] = append(m.store[kh], &entry[K, V]{key, val})
m.len++
}
func (m *JMap[K, V, C]) Values() []V {
vs := make([]V, 0, len(m.store))
for _, e := range m.store {
for _, v := range e {
vs = append(vs, v.val)
}
}
return vs
}
func (m *JMap[K, V, C]) Get(key K) (V, bool) {
var none V
kh := m.comparator.Hash1(key)
for _, e := range m.store[kh] {
if m.comparator.Equals2(e.key, key) {
return e.val, true
}
}
return none, false
}
func (m *JMap[K, V, C]) Len() int {
return len(m.store)
}
func (m *JMap[K, V, C]) Delete(key K) {
kh := m.comparator.Hash1(key)
for i, e := range m.store[kh] {
if m.comparator.Equals2(e.key, key) {
m.store[kh] = append(m.store[kh][:i], m.store[kh][i+1:]...)
m.len--
return
}
}
}
func (m *JMap[K, V, C]) Clear() {
m.store = make(map[int][]*entry[K, V])
}

View File

@@ -1,806 +0,0 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"fmt"
"golang.org/x/exp/slices"
"strconv"
)
// Represents {@code $} in local context prediction, which means wildcard.
// {@code//+x =//}.
// /
const (
BasePredictionContextEmptyReturnState = 0x7FFFFFFF
)
// Represents {@code $} in an array in full context mode, when {@code $}
// doesn't mean wildcard: {@code $ + x = [$,x]}. Here,
// {@code $} = {@link //EmptyReturnState}.
// /
var (
BasePredictionContextglobalNodeCount = 1
BasePredictionContextid = BasePredictionContextglobalNodeCount
)
type PredictionContext interface {
Hash() int
Equals(interface{}) bool
GetParent(int) PredictionContext
getReturnState(int) int
length() int
isEmpty() bool
hasEmptyPath() bool
String() string
}
type BasePredictionContext struct {
cachedHash int
}
func NewBasePredictionContext(cachedHash int) *BasePredictionContext {
pc := new(BasePredictionContext)
pc.cachedHash = cachedHash
return pc
}
func (b *BasePredictionContext) isEmpty() bool {
return false
}
func calculateHash(parent PredictionContext, returnState int) int {
h := murmurInit(1)
h = murmurUpdate(h, parent.Hash())
h = murmurUpdate(h, returnState)
return murmurFinish(h, 2)
}
var _emptyPredictionContextHash int
func init() {
_emptyPredictionContextHash = murmurInit(1)
_emptyPredictionContextHash = murmurFinish(_emptyPredictionContextHash, 0)
}
func calculateEmptyHash() int {
return _emptyPredictionContextHash
}
// Used to cache {@link BasePredictionContext} objects. Its used for the shared
// context cash associated with contexts in DFA states. This cache
// can be used for both lexers and parsers.
type PredictionContextCache struct {
cache map[PredictionContext]PredictionContext
}
func NewPredictionContextCache() *PredictionContextCache {
t := new(PredictionContextCache)
t.cache = make(map[PredictionContext]PredictionContext)
return t
}
// Add a context to the cache and return it. If the context already exists,
// return that one instead and do not add a Newcontext to the cache.
// Protect shared cache from unsafe thread access.
func (p *PredictionContextCache) add(ctx PredictionContext) PredictionContext {
if ctx == BasePredictionContextEMPTY {
return BasePredictionContextEMPTY
}
existing := p.cache[ctx]
if existing != nil {
return existing
}
p.cache[ctx] = ctx
return ctx
}
func (p *PredictionContextCache) Get(ctx PredictionContext) PredictionContext {
return p.cache[ctx]
}
func (p *PredictionContextCache) length() int {
return len(p.cache)
}
type SingletonPredictionContext interface {
PredictionContext
}
type BaseSingletonPredictionContext struct {
*BasePredictionContext
parentCtx PredictionContext
returnState int
}
func NewBaseSingletonPredictionContext(parent PredictionContext, returnState int) *BaseSingletonPredictionContext {
var cachedHash int
if parent != nil {
cachedHash = calculateHash(parent, returnState)
} else {
cachedHash = calculateEmptyHash()
}
s := new(BaseSingletonPredictionContext)
s.BasePredictionContext = NewBasePredictionContext(cachedHash)
s.parentCtx = parent
s.returnState = returnState
return s
}
func SingletonBasePredictionContextCreate(parent PredictionContext, returnState int) PredictionContext {
if returnState == BasePredictionContextEmptyReturnState && parent == nil {
// someone can pass in the bits of an array ctx that mean $
return BasePredictionContextEMPTY
}
return NewBaseSingletonPredictionContext(parent, returnState)
}
func (b *BaseSingletonPredictionContext) length() int {
return 1
}
func (b *BaseSingletonPredictionContext) GetParent(index int) PredictionContext {
return b.parentCtx
}
func (b *BaseSingletonPredictionContext) getReturnState(index int) int {
return b.returnState
}
func (b *BaseSingletonPredictionContext) hasEmptyPath() bool {
return b.returnState == BasePredictionContextEmptyReturnState
}
func (b *BaseSingletonPredictionContext) Hash() int {
return b.cachedHash
}
func (b *BaseSingletonPredictionContext) Equals(other interface{}) bool {
if b == other {
return true
}
if _, ok := other.(*BaseSingletonPredictionContext); !ok {
return false
}
otherP := other.(*BaseSingletonPredictionContext)
if b.returnState != otherP.getReturnState(0) {
return false
}
if b.parentCtx == nil {
return otherP.parentCtx == nil
}
return b.parentCtx.Equals(otherP.parentCtx)
}
func (b *BaseSingletonPredictionContext) String() string {
var up string
if b.parentCtx == nil {
up = ""
} else {
up = b.parentCtx.String()
}
if len(up) == 0 {
if b.returnState == BasePredictionContextEmptyReturnState {
return "$"
}
return strconv.Itoa(b.returnState)
}
return strconv.Itoa(b.returnState) + " " + up
}
var BasePredictionContextEMPTY = NewEmptyPredictionContext()
type EmptyPredictionContext struct {
*BaseSingletonPredictionContext
}
func NewEmptyPredictionContext() *EmptyPredictionContext {
p := new(EmptyPredictionContext)
p.BaseSingletonPredictionContext = NewBaseSingletonPredictionContext(nil, BasePredictionContextEmptyReturnState)
p.cachedHash = calculateEmptyHash()
return p
}
func (e *EmptyPredictionContext) isEmpty() bool {
return true
}
func (e *EmptyPredictionContext) GetParent(index int) PredictionContext {
return nil
}
func (e *EmptyPredictionContext) getReturnState(index int) int {
return e.returnState
}
func (e *EmptyPredictionContext) Hash() int {
return e.cachedHash
}
func (e *EmptyPredictionContext) Equals(other interface{}) bool {
return e == other
}
func (e *EmptyPredictionContext) String() string {
return "$"
}
type ArrayPredictionContext struct {
*BasePredictionContext
parents []PredictionContext
returnStates []int
}
func NewArrayPredictionContext(parents []PredictionContext, returnStates []int) *ArrayPredictionContext {
// Parent can be nil only if full ctx mode and we make an array
// from {@link //EMPTY} and non-empty. We merge {@link //EMPTY} by using
// nil parent and
// returnState == {@link //EmptyReturnState}.
hash := murmurInit(1)
for _, parent := range parents {
hash = murmurUpdate(hash, parent.Hash())
}
for _, returnState := range returnStates {
hash = murmurUpdate(hash, returnState)
}
hash = murmurFinish(hash, len(parents)<<1)
c := new(ArrayPredictionContext)
c.BasePredictionContext = NewBasePredictionContext(hash)
c.parents = parents
c.returnStates = returnStates
return c
}
func (a *ArrayPredictionContext) GetReturnStates() []int {
return a.returnStates
}
func (a *ArrayPredictionContext) hasEmptyPath() bool {
return a.getReturnState(a.length()-1) == BasePredictionContextEmptyReturnState
}
func (a *ArrayPredictionContext) isEmpty() bool {
// since EmptyReturnState can only appear in the last position, we
// don't need to verify that size==1
return a.returnStates[0] == BasePredictionContextEmptyReturnState
}
func (a *ArrayPredictionContext) length() int {
return len(a.returnStates)
}
func (a *ArrayPredictionContext) GetParent(index int) PredictionContext {
return a.parents[index]
}
func (a *ArrayPredictionContext) getReturnState(index int) int {
return a.returnStates[index]
}
// Equals is the default comparison function for ArrayPredictionContext when no specialized
// implementation is needed for a collection
func (a *ArrayPredictionContext) Equals(o interface{}) bool {
if a == o {
return true
}
other, ok := o.(*ArrayPredictionContext)
if !ok {
return false
}
if a.cachedHash != other.Hash() {
return false // can't be same if hash is different
}
// Must compare the actual array elements and not just the array address
//
return slices.Equal(a.returnStates, other.returnStates) &&
slices.EqualFunc(a.parents, other.parents, func(x, y PredictionContext) bool {
return x.Equals(y)
})
}
// Hash is the default hash function for ArrayPredictionContext when no specialized
// implementation is needed for a collection
func (a *ArrayPredictionContext) Hash() int {
return a.BasePredictionContext.cachedHash
}
func (a *ArrayPredictionContext) String() string {
if a.isEmpty() {
return "[]"
}
s := "["
for i := 0; i < len(a.returnStates); i++ {
if i > 0 {
s = s + ", "
}
if a.returnStates[i] == BasePredictionContextEmptyReturnState {
s = s + "$"
continue
}
s = s + strconv.Itoa(a.returnStates[i])
if a.parents[i] != nil {
s = s + " " + a.parents[i].String()
} else {
s = s + "nil"
}
}
return s + "]"
}
// Convert a {@link RuleContext} tree to a {@link BasePredictionContext} graph.
// Return {@link //EMPTY} if {@code outerContext} is empty or nil.
// /
func predictionContextFromRuleContext(a *ATN, outerContext RuleContext) PredictionContext {
if outerContext == nil {
outerContext = ParserRuleContextEmpty
}
// if we are in RuleContext of start rule, s, then BasePredictionContext
// is EMPTY. Nobody called us. (if we are empty, return empty)
if outerContext.GetParent() == nil || outerContext == ParserRuleContextEmpty {
return BasePredictionContextEMPTY
}
// If we have a parent, convert it to a BasePredictionContext graph
parent := predictionContextFromRuleContext(a, outerContext.GetParent().(RuleContext))
state := a.states[outerContext.GetInvokingState()]
transition := state.GetTransitions()[0]
return SingletonBasePredictionContextCreate(parent, transition.(*RuleTransition).followState.GetStateNumber())
}
func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
// Share same graph if both same
//
if a == b || a.Equals(b) {
return a
}
// In Java, EmptyPredictionContext inherits from SingletonPredictionContext, and so the test
// in java for SingletonPredictionContext will succeed and a new ArrayPredictionContext will be created
// from it.
// In go, EmptyPredictionContext does not equate to SingletonPredictionContext and so that conversion
// will fail. We need to test for both Empty and Singleton and create an ArrayPredictionContext from
// either of them.
ac, ok1 := a.(*BaseSingletonPredictionContext)
bc, ok2 := b.(*BaseSingletonPredictionContext)
if ok1 && ok2 {
return mergeSingletons(ac, bc, rootIsWildcard, mergeCache)
}
// At least one of a or b is array
// If one is $ and rootIsWildcard, return $ as// wildcard
if rootIsWildcard {
if _, ok := a.(*EmptyPredictionContext); ok {
return a
}
if _, ok := b.(*EmptyPredictionContext); ok {
return b
}
}
// Convert Singleton or Empty so both are arrays to normalize - We should not use the existing parameters
// here.
//
// TODO: I think that maybe the Prediction Context structs should be redone as there is a chance we will see this mess again - maybe redo the logic here
var arp, arb *ArrayPredictionContext
var ok bool
if arp, ok = a.(*ArrayPredictionContext); ok {
} else if _, ok = a.(*BaseSingletonPredictionContext); ok {
arp = NewArrayPredictionContext([]PredictionContext{a.GetParent(0)}, []int{a.getReturnState(0)})
} else if _, ok = a.(*EmptyPredictionContext); ok {
arp = NewArrayPredictionContext([]PredictionContext{}, []int{})
}
if arb, ok = b.(*ArrayPredictionContext); ok {
} else if _, ok = b.(*BaseSingletonPredictionContext); ok {
arb = NewArrayPredictionContext([]PredictionContext{b.GetParent(0)}, []int{b.getReturnState(0)})
} else if _, ok = b.(*EmptyPredictionContext); ok {
arb = NewArrayPredictionContext([]PredictionContext{}, []int{})
}
// Both arp and arb
return mergeArrays(arp, arb, rootIsWildcard, mergeCache)
}
// Merge two {@link SingletonBasePredictionContext} instances.
//
// <p>Stack tops equal, parents merge is same return left graph.<br>
// <embed src="images/SingletonMerge_SameRootSamePar.svg"
// type="image/svg+xml"/></p>
//
// <p>Same stack top, parents differ merge parents giving array node, then
// remainders of those graphs. A Newroot node is created to point to the
// merged parents.<br>
// <embed src="images/SingletonMerge_SameRootDiffPar.svg"
// type="image/svg+xml"/></p>
//
// <p>Different stack tops pointing to same parent. Make array node for the
// root where both element in the root point to the same (original)
// parent.<br>
// <embed src="images/SingletonMerge_DiffRootSamePar.svg"
// type="image/svg+xml"/></p>
//
// <p>Different stack tops pointing to different parents. Make array node for
// the root where each element points to the corresponding original
// parent.<br>
// <embed src="images/SingletonMerge_DiffRootDiffPar.svg"
// type="image/svg+xml"/></p>
//
// @param a the first {@link SingletonBasePredictionContext}
// @param b the second {@link SingletonBasePredictionContext}
// @param rootIsWildcard {@code true} if this is a local-context merge,
// otherwise false to indicate a full-context merge
// @param mergeCache
// /
func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
if mergeCache != nil {
previous := mergeCache.Get(a.Hash(), b.Hash())
if previous != nil {
return previous.(PredictionContext)
}
previous = mergeCache.Get(b.Hash(), a.Hash())
if previous != nil {
return previous.(PredictionContext)
}
}
rootMerge := mergeRoot(a, b, rootIsWildcard)
if rootMerge != nil {
if mergeCache != nil {
mergeCache.set(a.Hash(), b.Hash(), rootMerge)
}
return rootMerge
}
if a.returnState == b.returnState {
parent := merge(a.parentCtx, b.parentCtx, rootIsWildcard, mergeCache)
// if parent is same as existing a or b parent or reduced to a parent,
// return it
if parent == a.parentCtx {
return a // ax + bx = ax, if a=b
}
if parent == b.parentCtx {
return b // ax + bx = bx, if a=b
}
// else: ax + ay = a'[x,y]
// merge parents x and y, giving array node with x,y then remainders
// of those graphs. dup a, a' points at merged array
// Newjoined parent so create Newsingleton pointing to it, a'
spc := SingletonBasePredictionContextCreate(parent, a.returnState)
if mergeCache != nil {
mergeCache.set(a.Hash(), b.Hash(), spc)
}
return spc
}
// a != b payloads differ
// see if we can collapse parents due to $+x parents if local ctx
var singleParent PredictionContext
if a == b || (a.parentCtx != nil && a.parentCtx == b.parentCtx) { // ax +
// bx =
// [a,b]x
singleParent = a.parentCtx
}
if singleParent != nil { // parents are same
// sort payloads and use same parent
payloads := []int{a.returnState, b.returnState}
if a.returnState > b.returnState {
payloads[0] = b.returnState
payloads[1] = a.returnState
}
parents := []PredictionContext{singleParent, singleParent}
apc := NewArrayPredictionContext(parents, payloads)
if mergeCache != nil {
mergeCache.set(a.Hash(), b.Hash(), apc)
}
return apc
}
// parents differ and can't merge them. Just pack together
// into array can't merge.
// ax + by = [ax,by]
payloads := []int{a.returnState, b.returnState}
parents := []PredictionContext{a.parentCtx, b.parentCtx}
if a.returnState > b.returnState { // sort by payload
payloads[0] = b.returnState
payloads[1] = a.returnState
parents = []PredictionContext{b.parentCtx, a.parentCtx}
}
apc := NewArrayPredictionContext(parents, payloads)
if mergeCache != nil {
mergeCache.set(a.Hash(), b.Hash(), apc)
}
return apc
}
// Handle case where at least one of {@code a} or {@code b} is
// {@link //EMPTY}. In the following diagrams, the symbol {@code $} is used
// to represent {@link //EMPTY}.
//
// <h2>Local-Context Merges</h2>
//
// <p>These local-context merge operations are used when {@code rootIsWildcard}
// is true.</p>
//
// <p>{@link //EMPTY} is superset of any graph return {@link //EMPTY}.<br>
// <embed src="images/LocalMerge_EmptyRoot.svg" type="image/svg+xml"/></p>
//
// <p>{@link //EMPTY} and anything is {@code //EMPTY}, so merged parent is
// {@code //EMPTY} return left graph.<br>
// <embed src="images/LocalMerge_EmptyParent.svg" type="image/svg+xml"/></p>
//
// <p>Special case of last merge if local context.<br>
// <embed src="images/LocalMerge_DiffRoots.svg" type="image/svg+xml"/></p>
//
// <h2>Full-Context Merges</h2>
//
// <p>These full-context merge operations are used when {@code rootIsWildcard}
// is false.</p>
//
// <p><embed src="images/FullMerge_EmptyRoots.svg" type="image/svg+xml"/></p>
//
// <p>Must keep all contexts {@link //EMPTY} in array is a special value (and
// nil parent).<br>
// <embed src="images/FullMerge_EmptyRoot.svg" type="image/svg+xml"/></p>
//
// <p><embed src="images/FullMerge_SameRoot.svg" type="image/svg+xml"/></p>
//
// @param a the first {@link SingletonBasePredictionContext}
// @param b the second {@link SingletonBasePredictionContext}
// @param rootIsWildcard {@code true} if this is a local-context merge,
// otherwise false to indicate a full-context merge
// /
func mergeRoot(a, b SingletonPredictionContext, rootIsWildcard bool) PredictionContext {
if rootIsWildcard {
if a == BasePredictionContextEMPTY {
return BasePredictionContextEMPTY // // + b =//
}
if b == BasePredictionContextEMPTY {
return BasePredictionContextEMPTY // a +// =//
}
} else {
if a == BasePredictionContextEMPTY && b == BasePredictionContextEMPTY {
return BasePredictionContextEMPTY // $ + $ = $
} else if a == BasePredictionContextEMPTY { // $ + x = [$,x]
payloads := []int{b.getReturnState(-1), BasePredictionContextEmptyReturnState}
parents := []PredictionContext{b.GetParent(-1), nil}
return NewArrayPredictionContext(parents, payloads)
} else if b == BasePredictionContextEMPTY { // x + $ = [$,x] ($ is always first if present)
payloads := []int{a.getReturnState(-1), BasePredictionContextEmptyReturnState}
parents := []PredictionContext{a.GetParent(-1), nil}
return NewArrayPredictionContext(parents, payloads)
}
}
return nil
}
// Merge two {@link ArrayBasePredictionContext} instances.
//
// <p>Different tops, different parents.<br>
// <embed src="images/ArrayMerge_DiffTopDiffPar.svg" type="image/svg+xml"/></p>
//
// <p>Shared top, same parents.<br>
// <embed src="images/ArrayMerge_ShareTopSamePar.svg" type="image/svg+xml"/></p>
//
// <p>Shared top, different parents.<br>
// <embed src="images/ArrayMerge_ShareTopDiffPar.svg" type="image/svg+xml"/></p>
//
// <p>Shared top, all shared parents.<br>
// <embed src="images/ArrayMerge_ShareTopSharePar.svg"
// type="image/svg+xml"/></p>
//
// <p>Equal tops, merge parents and reduce top to
// {@link SingletonBasePredictionContext}.<br>
// <embed src="images/ArrayMerge_EqualTop.svg" type="image/svg+xml"/></p>
// /
func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
if mergeCache != nil {
previous := mergeCache.Get(a.Hash(), b.Hash())
if previous != nil {
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> previous")
}
return previous.(PredictionContext)
}
previous = mergeCache.Get(b.Hash(), a.Hash())
if previous != nil {
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> previous")
}
return previous.(PredictionContext)
}
}
// merge sorted payloads a + b => M
i := 0 // walks a
j := 0 // walks b
k := 0 // walks target M array
mergedReturnStates := make([]int, len(a.returnStates)+len(b.returnStates))
mergedParents := make([]PredictionContext, len(a.returnStates)+len(b.returnStates))
// walk and merge to yield mergedParents, mergedReturnStates
for i < len(a.returnStates) && j < len(b.returnStates) {
aParent := a.parents[i]
bParent := b.parents[j]
if a.returnStates[i] == b.returnStates[j] {
// same payload (stack tops are equal), must yield merged singleton
payload := a.returnStates[i]
// $+$ = $
bothDollars := payload == BasePredictionContextEmptyReturnState && aParent == nil && bParent == nil
axAX := aParent != nil && bParent != nil && aParent == bParent // ax+ax
// ->
// ax
if bothDollars || axAX {
mergedParents[k] = aParent // choose left
mergedReturnStates[k] = payload
} else { // ax+ay -> a'[x,y]
mergedParent := merge(aParent, bParent, rootIsWildcard, mergeCache)
mergedParents[k] = mergedParent
mergedReturnStates[k] = payload
}
i++ // hop over left one as usual
j++ // but also Skip one in right side since we merge
} else if a.returnStates[i] < b.returnStates[j] { // copy a[i] to M
mergedParents[k] = aParent
mergedReturnStates[k] = a.returnStates[i]
i++
} else { // b > a, copy b[j] to M
mergedParents[k] = bParent
mergedReturnStates[k] = b.returnStates[j]
j++
}
k++
}
// copy over any payloads remaining in either array
if i < len(a.returnStates) {
for p := i; p < len(a.returnStates); p++ {
mergedParents[k] = a.parents[p]
mergedReturnStates[k] = a.returnStates[p]
k++
}
} else {
for p := j; p < len(b.returnStates); p++ {
mergedParents[k] = b.parents[p]
mergedReturnStates[k] = b.returnStates[p]
k++
}
}
// trim merged if we combined a few that had same stack tops
if k < len(mergedParents) { // write index < last position trim
if k == 1 { // for just one merged element, return singleton top
pc := SingletonBasePredictionContextCreate(mergedParents[0], mergedReturnStates[0])
if mergeCache != nil {
mergeCache.set(a.Hash(), b.Hash(), pc)
}
return pc
}
mergedParents = mergedParents[0:k]
mergedReturnStates = mergedReturnStates[0:k]
}
M := NewArrayPredictionContext(mergedParents, mergedReturnStates)
// if we created same array as a or b, return that instead
// TODO: track whether this is possible above during merge sort for speed
// TODO: In go, I do not think we can just do M == xx as M is a brand new allocation. This could be causing allocation problems
if M == a {
if mergeCache != nil {
mergeCache.set(a.Hash(), b.Hash(), a)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> a")
}
return a
}
if M == b {
if mergeCache != nil {
mergeCache.set(a.Hash(), b.Hash(), b)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> b")
}
return b
}
combineCommonParents(mergedParents)
if mergeCache != nil {
mergeCache.set(a.Hash(), b.Hash(), M)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> " + M.String())
}
return M
}
// Make pass over all <em>M</em> {@code parents} merge any {@code equals()}
// ones.
// /
func combineCommonParents(parents []PredictionContext) {
uniqueParents := make(map[PredictionContext]PredictionContext)
for p := 0; p < len(parents); p++ {
parent := parents[p]
if uniqueParents[parent] == nil {
uniqueParents[parent] = parent
}
}
for q := 0; q < len(parents); q++ {
parents[q] = uniqueParents[parents[q]]
}
}
func getCachedBasePredictionContext(context PredictionContext, contextCache *PredictionContextCache, visited map[PredictionContext]PredictionContext) PredictionContext {
if context.isEmpty() {
return context
}
existing := visited[context]
if existing != nil {
return existing
}
existing = contextCache.Get(context)
if existing != nil {
visited[context] = existing
return existing
}
changed := false
parents := make([]PredictionContext, context.length())
for i := 0; i < len(parents); i++ {
parent := getCachedBasePredictionContext(context.GetParent(i), contextCache, visited)
if changed || parent != context.GetParent(i) {
if !changed {
parents = make([]PredictionContext, context.length())
for j := 0; j < context.length(); j++ {
parents[j] = context.GetParent(j)
}
changed = true
}
parents[i] = parent
}
}
if !changed {
contextCache.add(context)
visited[context] = context
return context
}
var updated PredictionContext
if len(parents) == 0 {
updated = BasePredictionContextEMPTY
} else if len(parents) == 1 {
updated = SingletonBasePredictionContextCreate(parents[0], context.getReturnState(0))
} else {
updated = NewArrayPredictionContext(parents, context.(*ArrayPredictionContext).GetReturnStates())
}
contextCache.add(updated)
visited[updated] = updated
visited[context] = updated
return updated
}

View File

@@ -1,529 +0,0 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
// This enumeration defines the prediction modes available in ANTLR 4 along with
// utility methods for analyzing configuration sets for conflicts and/or
// ambiguities.
const (
//
// The SLL(*) prediction mode. This prediction mode ignores the current
// parser context when making predictions. This is the fastest prediction
// mode, and provides correct results for many grammars. This prediction
// mode is more powerful than the prediction mode provided by ANTLR 3, but
// may result in syntax errors for grammar and input combinations which are
// not SLL.
//
// <p>
// When using this prediction mode, the parser will either return a correct
// parse tree (i.e. the same parse tree that would be returned with the
// {@link //LL} prediction mode), or it will Report a syntax error. If a
// syntax error is encountered when using the {@link //SLL} prediction mode,
// it may be due to either an actual syntax error in the input or indicate
// that the particular combination of grammar and input requires the more
// powerful {@link //LL} prediction abilities to complete successfully.</p>
//
// <p>
// This prediction mode does not provide any guarantees for prediction
// behavior for syntactically-incorrect inputs.</p>
//
PredictionModeSLL = 0
//
// The LL(*) prediction mode. This prediction mode allows the current parser
// context to be used for resolving SLL conflicts that occur during
// prediction. This is the fastest prediction mode that guarantees correct
// parse results for all combinations of grammars with syntactically correct
// inputs.
//
// <p>
// When using this prediction mode, the parser will make correct decisions
// for all syntactically-correct grammar and input combinations. However, in
// cases where the grammar is truly ambiguous this prediction mode might not
// Report a precise answer for <em>exactly which</em> alternatives are
// ambiguous.</p>
//
// <p>
// This prediction mode does not provide any guarantees for prediction
// behavior for syntactically-incorrect inputs.</p>
//
PredictionModeLL = 1
//
// The LL(*) prediction mode with exact ambiguity detection. In addition to
// the correctness guarantees provided by the {@link //LL} prediction mode,
// this prediction mode instructs the prediction algorithm to determine the
// complete and exact set of ambiguous alternatives for every ambiguous
// decision encountered while parsing.
//
// <p>
// This prediction mode may be used for diagnosing ambiguities during
// grammar development. Due to the performance overhead of calculating sets
// of ambiguous alternatives, this prediction mode should be avoided when
// the exact results are not necessary.</p>
//
// <p>
// This prediction mode does not provide any guarantees for prediction
// behavior for syntactically-incorrect inputs.</p>
//
PredictionModeLLExactAmbigDetection = 2
)
// Computes the SLL prediction termination condition.
//
// <p>
// This method computes the SLL prediction termination condition for both of
// the following cases.</p>
//
// <ul>
// <li>The usual SLL+LL fallback upon SLL conflict</li>
// <li>Pure SLL without LL fallback</li>
// </ul>
//
// <p><strong>COMBINED SLL+LL PARSING</strong></p>
//
// <p>When LL-fallback is enabled upon SLL conflict, correct predictions are
// ensured regardless of how the termination condition is computed by this
// method. Due to the substantially higher cost of LL prediction, the
// prediction should only fall back to LL when the additional lookahead
// cannot lead to a unique SLL prediction.</p>
//
// <p>Assuming combined SLL+LL parsing, an SLL configuration set with only
// conflicting subsets should fall back to full LL, even if the
// configuration sets don't resolve to the same alternative (e.g.
// {@code {1,2}} and {@code {3,4}}. If there is at least one non-conflicting
// configuration, SLL could continue with the hopes that more lookahead will
// resolve via one of those non-conflicting configurations.</p>
//
// <p>Here's the prediction termination rule them: SLL (for SLL+LL parsing)
// stops when it sees only conflicting configuration subsets. In contrast,
// full LL keeps going when there is uncertainty.</p>
//
// <p><strong>HEURISTIC</strong></p>
//
// <p>As a heuristic, we stop prediction when we see any conflicting subset
// unless we see a state that only has one alternative associated with it.
// The single-alt-state thing lets prediction continue upon rules like
// (otherwise, it would admit defeat too soon):</p>
//
// <p>{@code [12|1|[], 6|2|[], 12|2|[]]. s : (ID | ID ID?) ” }</p>
//
// <p>When the ATN simulation reaches the state before {@code ”}, it has a
// DFA state that looks like: {@code [12|1|[], 6|2|[], 12|2|[]]}. Naturally
// {@code 12|1|[]} and {@code 12|2|[]} conflict, but we cannot stop
// processing this node because alternative to has another way to continue,
// via {@code [6|2|[]]}.</p>
//
// <p>It also let's us continue for this rule:</p>
//
// <p>{@code [1|1|[], 1|2|[], 8|3|[]] a : A | A | A B }</p>
//
// <p>After Matching input A, we reach the stop state for rule A, state 1.
// State 8 is the state right before B. Clearly alternatives 1 and 2
// conflict and no amount of further lookahead will separate the two.
// However, alternative 3 will be able to continue and so we do not stop
// working on this state. In the previous example, we're concerned with
// states associated with the conflicting alternatives. Here alt 3 is not
// associated with the conflicting configs, but since we can continue
// looking for input reasonably, don't declare the state done.</p>
//
// <p><strong>PURE SLL PARSING</strong></p>
//
// <p>To handle pure SLL parsing, all we have to do is make sure that we
// combine stack contexts for configurations that differ only by semantic
// predicate. From there, we can do the usual SLL termination heuristic.</p>
//
// <p><strong>PREDICATES IN SLL+LL PARSING</strong></p>
//
// <p>SLL decisions don't evaluate predicates until after they reach DFA stop
// states because they need to create the DFA cache that works in all
// semantic situations. In contrast, full LL evaluates predicates collected
// during start state computation so it can ignore predicates thereafter.
// This means that SLL termination detection can totally ignore semantic
// predicates.</p>
//
// <p>Implementation-wise, {@link ATNConfigSet} combines stack contexts but not
// semantic predicate contexts so we might see two configurations like the
// following.</p>
//
// <p>{@code (s, 1, x, {}), (s, 1, x', {p})}</p>
//
// <p>Before testing these configurations against others, we have to merge
// {@code x} and {@code x'} (without modifying the existing configurations).
// For example, we test {@code (x+x')==x”} when looking for conflicts in
// the following configurations.</p>
//
// <p>{@code (s, 1, x, {}), (s, 1, x', {p}), (s, 2, x”, {})}</p>
//
// <p>If the configuration set has predicates (as indicated by
// {@link ATNConfigSet//hasSemanticContext}), this algorithm makes a copy of
// the configurations to strip out all of the predicates so that a standard
// {@link ATNConfigSet} will merge everything ignoring predicates.</p>
func PredictionModehasSLLConflictTerminatingPrediction(mode int, configs ATNConfigSet) bool {
// Configs in rule stop states indicate reaching the end of the decision
// rule (local context) or end of start rule (full context). If all
// configs meet this condition, then none of the configurations is able
// to Match additional input so we terminate prediction.
//
if PredictionModeallConfigsInRuleStopStates(configs) {
return true
}
// pure SLL mode parsing
if mode == PredictionModeSLL {
// Don't bother with combining configs from different semantic
// contexts if we can fail over to full LL costs more time
// since we'll often fail over anyway.
if configs.HasSemanticContext() {
// dup configs, tossing out semantic predicates
dup := NewBaseATNConfigSet(false)
for _, c := range configs.GetItems() {
// NewBaseATNConfig({semanticContext:}, c)
c = NewBaseATNConfig2(c, SemanticContextNone)
dup.Add(c, nil)
}
configs = dup
}
// now we have combined contexts for configs with dissimilar preds
}
// pure SLL or combined SLL+LL mode parsing
altsets := PredictionModegetConflictingAltSubsets(configs)
return PredictionModehasConflictingAltSet(altsets) && !PredictionModehasStateAssociatedWithOneAlt(configs)
}
// Checks if any configuration in {@code configs} is in a
// {@link RuleStopState}. Configurations meeting this condition have reached
// the end of the decision rule (local context) or end of start rule (full
// context).
//
// @param configs the configuration set to test
// @return {@code true} if any configuration in {@code configs} is in a
// {@link RuleStopState}, otherwise {@code false}
func PredictionModehasConfigInRuleStopState(configs ATNConfigSet) bool {
for _, c := range configs.GetItems() {
if _, ok := c.GetState().(*RuleStopState); ok {
return true
}
}
return false
}
// Checks if all configurations in {@code configs} are in a
// {@link RuleStopState}. Configurations meeting this condition have reached
// the end of the decision rule (local context) or end of start rule (full
// context).
//
// @param configs the configuration set to test
// @return {@code true} if all configurations in {@code configs} are in a
// {@link RuleStopState}, otherwise {@code false}
func PredictionModeallConfigsInRuleStopStates(configs ATNConfigSet) bool {
for _, c := range configs.GetItems() {
if _, ok := c.GetState().(*RuleStopState); !ok {
return false
}
}
return true
}
// Full LL prediction termination.
//
// <p>Can we stop looking ahead during ATN simulation or is there some
// uncertainty as to which alternative we will ultimately pick, after
// consuming more input? Even if there are partial conflicts, we might know
// that everything is going to resolve to the same minimum alternative. That
// means we can stop since no more lookahead will change that fact. On the
// other hand, there might be multiple conflicts that resolve to different
// minimums. That means we need more look ahead to decide which of those
// alternatives we should predict.</p>
//
// <p>The basic idea is to split the set of configurations {@code C}, into
// conflicting subsets {@code (s, _, ctx, _)} and singleton subsets with
// non-conflicting configurations. Two configurations conflict if they have
// identical {@link ATNConfig//state} and {@link ATNConfig//context} values
// but different {@link ATNConfig//alt} value, e.g. {@code (s, i, ctx, _)}
// and {@code (s, j, ctx, _)} for {@code i!=j}.</p>
//
// <p>Reduce these configuration subsets to the set of possible alternatives.
// You can compute the alternative subsets in one pass as follows:</p>
//
// <p>{@code A_s,ctx = {i | (s, i, ctx, _)}} for each configuration in
// {@code C} holding {@code s} and {@code ctx} fixed.</p>
//
// <p>Or in pseudo-code, for each configuration {@code c} in {@code C}:</p>
//
// <pre>
// map[c] U= c.{@link ATNConfig//alt alt} // map hash/equals uses s and x, not
// alt and not pred
// </pre>
//
// <p>The values in {@code map} are the set of {@code A_s,ctx} sets.</p>
//
// <p>If {@code |A_s,ctx|=1} then there is no conflict associated with
// {@code s} and {@code ctx}.</p>
//
// <p>Reduce the subsets to singletons by choosing a minimum of each subset. If
// the union of these alternative subsets is a singleton, then no amount of
// more lookahead will help us. We will always pick that alternative. If,
// however, there is more than one alternative, then we are uncertain which
// alternative to predict and must continue looking for resolution. We may
// or may not discover an ambiguity in the future, even if there are no
// conflicting subsets this round.</p>
//
// <p>The biggest sin is to terminate early because it means we've made a
// decision but were uncertain as to the eventual outcome. We haven't used
// enough lookahead. On the other hand, announcing a conflict too late is no
// big deal you will still have the conflict. It's just inefficient. It
// might even look until the end of file.</p>
//
// <p>No special consideration for semantic predicates is required because
// predicates are evaluated on-the-fly for full LL prediction, ensuring that
// no configuration contains a semantic context during the termination
// check.</p>
//
// <p><strong>CONFLICTING CONFIGS</strong></p>
//
// <p>Two configurations {@code (s, i, x)} and {@code (s, j, x')}, conflict
// when {@code i!=j} but {@code x=x'}. Because we merge all
// {@code (s, i, _)} configurations together, that means that there are at
// most {@code n} configurations associated with state {@code s} for
// {@code n} possible alternatives in the decision. The merged stacks
// complicate the comparison of configuration contexts {@code x} and
// {@code x'}. Sam checks to see if one is a subset of the other by calling
// merge and checking to see if the merged result is either {@code x} or
// {@code x'}. If the {@code x} associated with lowest alternative {@code i}
// is the superset, then {@code i} is the only possible prediction since the
// others resolve to {@code min(i)} as well. However, if {@code x} is
// associated with {@code j>i} then at least one stack configuration for
// {@code j} is not in conflict with alternative {@code i}. The algorithm
// should keep going, looking for more lookahead due to the uncertainty.</p>
//
// <p>For simplicity, I'm doing a equality check between {@code x} and
// {@code x'} that lets the algorithm continue to consume lookahead longer
// than necessary. The reason I like the equality is of course the
// simplicity but also because that is the test you need to detect the
// alternatives that are actually in conflict.</p>
//
// <p><strong>CONTINUE/STOP RULE</strong></p>
//
// <p>Continue if union of resolved alternative sets from non-conflicting and
// conflicting alternative subsets has more than one alternative. We are
// uncertain about which alternative to predict.</p>
//
// <p>The complete set of alternatives, {@code [i for (_,i,_)]}, tells us which
// alternatives are still in the running for the amount of input we've
// consumed at this point. The conflicting sets let us to strip away
// configurations that won't lead to more states because we resolve
// conflicts to the configuration with a minimum alternate for the
// conflicting set.</p>
//
// <p><strong>CASES</strong></p>
//
// <ul>
//
// <li>no conflicts and more than 1 alternative in set =&gt continue</li>
//
// <li> {@code (s, 1, x)}, {@code (s, 2, x)}, {@code (s, 3, z)},
// {@code (s', 1, y)}, {@code (s', 2, y)} yields non-conflicting set
// {@code {3}} U conflicting sets {@code min({1,2})} U {@code min({1,2})} =
// {@code {1,3}} =&gt continue
// </li>
//
// <li>{@code (s, 1, x)}, {@code (s, 2, x)}, {@code (s', 1, y)},
// {@code (s', 2, y)}, {@code (s”, 1, z)} yields non-conflicting set
// {@code {1}} U conflicting sets {@code min({1,2})} U {@code min({1,2})} =
// {@code {1}} =&gt stop and predict 1</li>
//
// <li>{@code (s, 1, x)}, {@code (s, 2, x)}, {@code (s', 1, y)},
// {@code (s', 2, y)} yields conflicting, reduced sets {@code {1}} U
// {@code {1}} = {@code {1}} =&gt stop and predict 1, can announce
// ambiguity {@code {1,2}}</li>
//
// <li>{@code (s, 1, x)}, {@code (s, 2, x)}, {@code (s', 2, y)},
// {@code (s', 3, y)} yields conflicting, reduced sets {@code {1}} U
// {@code {2}} = {@code {1,2}} =&gt continue</li>
//
// <li>{@code (s, 1, x)}, {@code (s, 2, x)}, {@code (s', 3, y)},
// {@code (s', 4, y)} yields conflicting, reduced sets {@code {1}} U
// {@code {3}} = {@code {1,3}} =&gt continue</li>
//
// </ul>
//
// <p><strong>EXACT AMBIGUITY DETECTION</strong></p>
//
// <p>If all states Report the same conflicting set of alternatives, then we
// know we have the exact ambiguity set.</p>
//
// <p><code>|A_<em>i</em>|&gt1</code> and
// <code>A_<em>i</em> = A_<em>j</em></code> for all <em>i</em>, <em>j</em>.</p>
//
// <p>In other words, we continue examining lookahead until all {@code A_i}
// have more than one alternative and all {@code A_i} are the same. If
// {@code A={{1,2}, {1,3}}}, then regular LL prediction would terminate
// because the resolved set is {@code {1}}. To determine what the real
// ambiguity is, we have to know whether the ambiguity is between one and
// two or one and three so we keep going. We can only stop prediction when
// we need exact ambiguity detection when the sets look like
// {@code A={{1,2}}} or {@code {{1,2},{1,2}}}, etc...</p>
func PredictionModeresolvesToJustOneViableAlt(altsets []*BitSet) int {
return PredictionModegetSingleViableAlt(altsets)
}
// Determines if every alternative subset in {@code altsets} contains more
// than one alternative.
//
// @param altsets a collection of alternative subsets
// @return {@code true} if every {@link BitSet} in {@code altsets} has
// {@link BitSet//cardinality cardinality} &gt 1, otherwise {@code false}
func PredictionModeallSubsetsConflict(altsets []*BitSet) bool {
return !PredictionModehasNonConflictingAltSet(altsets)
}
// Determines if any single alternative subset in {@code altsets} contains
// exactly one alternative.
//
// @param altsets a collection of alternative subsets
// @return {@code true} if {@code altsets} contains a {@link BitSet} with
// {@link BitSet//cardinality cardinality} 1, otherwise {@code false}
func PredictionModehasNonConflictingAltSet(altsets []*BitSet) bool {
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
if alts.length() == 1 {
return true
}
}
return false
}
// Determines if any single alternative subset in {@code altsets} contains
// more than one alternative.
//
// @param altsets a collection of alternative subsets
// @return {@code true} if {@code altsets} contains a {@link BitSet} with
// {@link BitSet//cardinality cardinality} &gt 1, otherwise {@code false}
func PredictionModehasConflictingAltSet(altsets []*BitSet) bool {
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
if alts.length() > 1 {
return true
}
}
return false
}
// Determines if every alternative subset in {@code altsets} is equivalent.
//
// @param altsets a collection of alternative subsets
// @return {@code true} if every member of {@code altsets} is equal to the
// others, otherwise {@code false}
func PredictionModeallSubsetsEqual(altsets []*BitSet) bool {
var first *BitSet
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
if first == nil {
first = alts
} else if alts != first {
return false
}
}
return true
}
// Returns the unique alternative predicted by all alternative subsets in
// {@code altsets}. If no such alternative exists, this method returns
// {@link ATN//INVALID_ALT_NUMBER}.
//
// @param altsets a collection of alternative subsets
func PredictionModegetUniqueAlt(altsets []*BitSet) int {
all := PredictionModeGetAlts(altsets)
if all.length() == 1 {
return all.minValue()
}
return ATNInvalidAltNumber
}
// Gets the complete set of represented alternatives for a collection of
// alternative subsets. This method returns the union of each {@link BitSet}
// in {@code altsets}.
//
// @param altsets a collection of alternative subsets
// @return the set of represented alternatives in {@code altsets}
func PredictionModeGetAlts(altsets []*BitSet) *BitSet {
all := NewBitSet()
for _, alts := range altsets {
all.or(alts)
}
return all
}
// PredictionModegetConflictingAltSubsets gets the conflicting alt subsets from a configuration set.
// For each configuration {@code c} in {@code configs}:
//
// <pre>
// map[c] U= c.{@link ATNConfig//alt alt} // map hash/equals uses s and x, not
// alt and not pred
// </pre>
func PredictionModegetConflictingAltSubsets(configs ATNConfigSet) []*BitSet {
configToAlts := NewJMap[ATNConfig, *BitSet, *ATNAltConfigComparator[ATNConfig]](atnAltCfgEqInst)
for _, c := range configs.GetItems() {
alts, ok := configToAlts.Get(c)
if !ok {
alts = NewBitSet()
configToAlts.Put(c, alts)
}
alts.add(c.GetAlt())
}
return configToAlts.Values()
}
// PredictionModeGetStateToAltMap gets a map from state to alt subset from a configuration set. For each
// configuration {@code c} in {@code configs}:
//
// <pre>
// map[c.{@link ATNConfig//state state}] U= c.{@link ATNConfig//alt alt}
// </pre>
func PredictionModeGetStateToAltMap(configs ATNConfigSet) *AltDict {
m := NewAltDict()
for _, c := range configs.GetItems() {
alts := m.Get(c.GetState().String())
if alts == nil {
alts = NewBitSet()
m.put(c.GetState().String(), alts)
}
alts.(*BitSet).add(c.GetAlt())
}
return m
}
func PredictionModehasStateAssociatedWithOneAlt(configs ATNConfigSet) bool {
values := PredictionModeGetStateToAltMap(configs).values()
for i := 0; i < len(values); i++ {
if values[i].(*BitSet).length() == 1 {
return true
}
}
return false
}
func PredictionModegetSingleViableAlt(altsets []*BitSet) int {
result := ATNInvalidAltNumber
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
minAlt := alts.minValue()
if result == ATNInvalidAltNumber {
result = minAlt
} else if result != minAlt { // more than 1 viable alt
return ATNInvalidAltNumber
}
}
return result
}

View File

@@ -1,114 +0,0 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
// A rule context is a record of a single rule invocation. It knows
// which context invoked it, if any. If there is no parent context, then
// naturally the invoking state is not valid. The parent link
// provides a chain upwards from the current rule invocation to the root
// of the invocation tree, forming a stack. We actually carry no
// information about the rule associated with b context (except
// when parsing). We keep only the state number of the invoking state from
// the ATN submachine that invoked b. Contrast b with the s
// pointer inside ParserRuleContext that tracks the current state
// being "executed" for the current rule.
//
// The parent contexts are useful for computing lookahead sets and
// getting error information.
//
// These objects are used during parsing and prediction.
// For the special case of parsers, we use the subclass
// ParserRuleContext.
//
// @see ParserRuleContext
//
type RuleContext interface {
RuleNode
GetInvokingState() int
SetInvokingState(int)
GetRuleIndex() int
IsEmpty() bool
GetAltNumber() int
SetAltNumber(altNumber int)
String([]string, RuleContext) string
}
type BaseRuleContext struct {
parentCtx RuleContext
invokingState int
RuleIndex int
}
func NewBaseRuleContext(parent RuleContext, invokingState int) *BaseRuleContext {
rn := new(BaseRuleContext)
// What context invoked b rule?
rn.parentCtx = parent
// What state invoked the rule associated with b context?
// The "return address" is the followState of invokingState
// If parent is nil, b should be -1.
if parent == nil {
rn.invokingState = -1
} else {
rn.invokingState = invokingState
}
return rn
}
func (b *BaseRuleContext) GetBaseRuleContext() *BaseRuleContext {
return b
}
func (b *BaseRuleContext) SetParent(v Tree) {
if v == nil {
b.parentCtx = nil
} else {
b.parentCtx = v.(RuleContext)
}
}
func (b *BaseRuleContext) GetInvokingState() int {
return b.invokingState
}
func (b *BaseRuleContext) SetInvokingState(t int) {
b.invokingState = t
}
func (b *BaseRuleContext) GetRuleIndex() int {
return b.RuleIndex
}
func (b *BaseRuleContext) GetAltNumber() int {
return ATNInvalidAltNumber
}
func (b *BaseRuleContext) SetAltNumber(altNumber int) {}
// A context is empty if there is no invoking state meaning nobody call
// current context.
func (b *BaseRuleContext) IsEmpty() bool {
return b.invokingState == -1
}
// Return the combined text of all child nodes. This method only considers
// tokens which have been added to the parse tree.
// <p>
// Since tokens on hidden channels (e.g. whitespace or comments) are not
// added to the parse trees, they will not appear in the output of b
// method.
//
func (b *BaseRuleContext) GetParent() Tree {
return b.parentCtx
}

View File

@@ -1,235 +0,0 @@
package antlr
import "math"
const (
_initalCapacity = 16
_initalBucketCapacity = 8
_loadFactor = 0.75
)
type Set interface {
Add(value interface{}) (added interface{})
Len() int
Get(value interface{}) (found interface{})
Contains(value interface{}) bool
Values() []interface{}
Each(f func(interface{}) bool)
}
type array2DHashSet struct {
buckets [][]Collectable[any]
hashcodeFunction func(interface{}) int
equalsFunction func(Collectable[any], Collectable[any]) bool
n int // How many elements in set
threshold int // when to expand
currentPrime int // jump by 4 primes each expand or whatever
initialBucketCapacity int
}
func (as *array2DHashSet) Each(f func(interface{}) bool) {
if as.Len() < 1 {
return
}
for _, bucket := range as.buckets {
for _, o := range bucket {
if o == nil {
break
}
if !f(o) {
return
}
}
}
}
func (as *array2DHashSet) Values() []interface{} {
if as.Len() < 1 {
return nil
}
values := make([]interface{}, 0, as.Len())
as.Each(func(i interface{}) bool {
values = append(values, i)
return true
})
return values
}
func (as *array2DHashSet) Contains(value Collectable[any]) bool {
return as.Get(value) != nil
}
func (as *array2DHashSet) Add(value Collectable[any]) interface{} {
if as.n > as.threshold {
as.expand()
}
return as.innerAdd(value)
}
func (as *array2DHashSet) expand() {
old := as.buckets
as.currentPrime += 4
var (
newCapacity = len(as.buckets) << 1
newTable = as.createBuckets(newCapacity)
newBucketLengths = make([]int, len(newTable))
)
as.buckets = newTable
as.threshold = int(float64(newCapacity) * _loadFactor)
for _, bucket := range old {
if bucket == nil {
continue
}
for _, o := range bucket {
if o == nil {
break
}
b := as.getBuckets(o)
bucketLength := newBucketLengths[b]
var newBucket []Collectable[any]
if bucketLength == 0 {
// new bucket
newBucket = as.createBucket(as.initialBucketCapacity)
newTable[b] = newBucket
} else {
newBucket = newTable[b]
if bucketLength == len(newBucket) {
// expand
newBucketCopy := make([]Collectable[any], len(newBucket)<<1)
copy(newBucketCopy[:bucketLength], newBucket)
newBucket = newBucketCopy
newTable[b] = newBucket
}
}
newBucket[bucketLength] = o
newBucketLengths[b]++
}
}
}
func (as *array2DHashSet) Len() int {
return as.n
}
func (as *array2DHashSet) Get(o Collectable[any]) interface{} {
if o == nil {
return nil
}
b := as.getBuckets(o)
bucket := as.buckets[b]
if bucket == nil { // no bucket
return nil
}
for _, e := range bucket {
if e == nil {
return nil // empty slot; not there
}
if as.equalsFunction(e, o) {
return e
}
}
return nil
}
func (as *array2DHashSet) innerAdd(o Collectable[any]) interface{} {
b := as.getBuckets(o)
bucket := as.buckets[b]
// new bucket
if bucket == nil {
bucket = as.createBucket(as.initialBucketCapacity)
bucket[0] = o
as.buckets[b] = bucket
as.n++
return o
}
// look for it in bucket
for i := 0; i < len(bucket); i++ {
existing := bucket[i]
if existing == nil { // empty slot; not there, add.
bucket[i] = o
as.n++
return o
}
if as.equalsFunction(existing, o) { // found existing, quit
return existing
}
}
// full bucket, expand and add to end
oldLength := len(bucket)
bucketCopy := make([]Collectable[any], oldLength<<1)
copy(bucketCopy[:oldLength], bucket)
bucket = bucketCopy
as.buckets[b] = bucket
bucket[oldLength] = o
as.n++
return o
}
func (as *array2DHashSet) getBuckets(value Collectable[any]) int {
hash := as.hashcodeFunction(value)
return hash & (len(as.buckets) - 1)
}
func (as *array2DHashSet) createBuckets(cap int) [][]Collectable[any] {
return make([][]Collectable[any], cap)
}
func (as *array2DHashSet) createBucket(cap int) []Collectable[any] {
return make([]Collectable[any], cap)
}
func newArray2DHashSetWithCap(
hashcodeFunction func(interface{}) int,
equalsFunction func(Collectable[any], Collectable[any]) bool,
initCap int,
initBucketCap int,
) *array2DHashSet {
if hashcodeFunction == nil {
hashcodeFunction = standardHashFunction
}
if equalsFunction == nil {
equalsFunction = standardEqualsFunction
}
ret := &array2DHashSet{
hashcodeFunction: hashcodeFunction,
equalsFunction: equalsFunction,
n: 0,
threshold: int(math.Floor(_initalCapacity * _loadFactor)),
currentPrime: 1,
initialBucketCapacity: initBucketCap,
}
ret.buckets = ret.createBuckets(initCap)
return ret
}
func newArray2DHashSet(
hashcodeFunction func(interface{}) int,
equalsFunction func(Collectable[any], Collectable[any]) bool,
) *array2DHashSet {
return newArray2DHashSetWithCap(hashcodeFunction, equalsFunction, _initalCapacity, _initalBucketCapacity)
}

18
vendor/github.com/antlr4-go/antlr/v4/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,18 @@
### Go template
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Go workspace file
go.work
# No Goland stuff in this repo
.idea

28
vendor/github.com/antlr4-go/antlr/v4/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,28 @@
Copyright (c) 2012-2023 The ANTLR Project. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither name of copyright holders nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

54
vendor/github.com/antlr4-go/antlr/v4/README.md generated vendored Normal file
View File

@@ -0,0 +1,54 @@
[![Go Report Card](https://goreportcard.com/badge/github.com/antlr4-go/antlr?style=flat-square)](https://goreportcard.com/report/github.com/antlr4-go/antlr)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/github.com/antlr4-go/antlr)](https://pkg.go.dev/github.com/antlr4-go/antlr)
[![Release](https://img.shields.io/github/v/release/antlr4-go/antlr?sort=semver&style=flat-square)](https://github.com/antlr4-go/antlr/releases/latest)
[![Release](https://img.shields.io/github/go-mod/go-version/antlr4-go/antlr?style=flat-square)](https://github.com/antlr4-go/antlr/releases/latest)
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg?style=flat-square)](https://github.com/antlr4-go/antlr/commit-activity)
[![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)
[![GitHub stars](https://img.shields.io/github/stars/antlr4-go/antlr?style=flat-square&label=Star&maxAge=2592000)](https://GitHub.com/Naereen/StrapDown.js/stargazers/)
# ANTLR4 Go Runtime Module Repo
IMPORTANT: Please submit PRs via a clone of the https://github.com/antlr/antlr4 repo, and not here.
- Do not submit PRs or any change requests to this repo
- This repo is read only and is updated by the ANTLR team to create a new release of the Go Runtime for ANTLR
- This repo contains the Go runtime that your generated projects should import
## Introduction
This repo contains the official modules for the Go Runtime for ANTLR. It is a copy of the runtime maintained
at: https://github.com/antlr/antlr4/tree/master/runtime/Go/antlr and is automatically updated by the ANTLR team to create
the official Go runtime release only. No development work is carried out in this repo and PRs are not accepted here.
The dev branch of this repo is kept in sync with the dev branch of the main ANTLR repo and is updated periodically.
### Why?
The `go get` command is unable to retrieve the Go runtime when it is embedded so
deeply in the main repo. A `go get` against the `antlr/antlr4` repo, while retrieving the correct source code for the runtime,
does not correctly resolve tags and will create a reference in your `go.mod` file that is unclear, will not upgrade smoothly and
causes confusion.
For instance, the current Go runtime release, which is tagged with v4.13.0 in `antlr/antlr4` is retrieved by go get as:
```sh
require (
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230219212500-1f9a474cc2dc
)
```
Where you would expect to see:
```sh
require (
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.13.0
)
```
The decision was taken to create a separate org in a separate repo to hold the official Go runtime for ANTLR and
from whence users can expect `go get` to behave as expected.
# Documentation
Please read the official documentation at: https://github.com/antlr/antlr4/blob/master/doc/index.md for tips on
migrating existing projects to use the new module location and for information on how to use the Go runtime in
general.

102
vendor/github.com/antlr4-go/antlr/v4/antlrdoc.go generated vendored Normal file
View File

@@ -0,0 +1,102 @@
/*
Package antlr implements the Go version of the ANTLR 4 runtime.
# The ANTLR Tool
ANTLR (ANother Tool for Language Recognition) is a powerful parser generator for reading, processing, executing,
or translating structured text or binary files. It's widely used to build languages, tools, and frameworks.
From a grammar, ANTLR generates a parser that can build parse trees and also generates a listener interface
(or visitor) that makes it easy to respond to the recognition of phrases of interest.
# Go Runtime
At version 4.11.x and prior, the Go runtime was not properly versioned for go modules. After this point, the runtime
source code to be imported was held in the `runtime/Go/antlr/v4` directory, and the go.mod file was updated to reflect the version of
ANTLR4 that it is compatible with (I.E. uses the /v4 path).
However, this was found to be problematic, as it meant that with the runtime embedded so far underneath the root
of the repo, the `go get` and related commands could not properly resolve the location of the go runtime source code.
This meant that the reference to the runtime in your `go.mod` file would refer to the correct source code, but would not
list the release tag such as @4.12.0 - this was confusing, to say the least.
As of 4.12.1, the runtime is now available as a go module in its own repo, and can be imported as `github.com/antlr4-go/antlr`
(the go get command should also be used with this path). See the main documentation for the ANTLR4 project for more information,
which is available at [ANTLR docs]. The documentation for using the Go runtime is available at [Go runtime docs].
This means that if you are using the source code without modules, you should also use the source code in the [new repo].
Though we highly recommend that you use go modules, as they are now idiomatic for Go.
I am aware that this change will prove Hyrum's Law, but am prepared to live with it for the common good.
Go runtime author: [Jim Idle] jimi@idle.ws
# Code Generation
ANTLR supports the generation of code in a number of [target languages], and the generated code is supported by a
runtime library, written specifically to support the generated code in the target language. This library is the
runtime for the Go target.
To generate code for the go target, it is generally recommended to place the source grammar files in a package of
their own, and use the `.sh` script method of generating code, using the go generate directive. In that same directory
it is usual, though not required, to place the antlr tool that should be used to generate the code. That does mean
that the antlr tool JAR file will be checked in to your source code control though, so you are, of course, free to use any other
way of specifying the version of the ANTLR tool to use, such as aliasing in `.zshrc` or equivalent, or a profile in
your IDE, or configuration in your CI system. Checking in the jar does mean that it is easy to reproduce the build as
it was at any point in its history.
Here is a general/recommended template for an ANTLR based recognizer in Go:
.
├── parser
│ ├── mygrammar.g4
│ ├── antlr-4.12.1-complete.jar
│ ├── generate.go
│ └── generate.sh
├── parsing - generated code goes here
│ └── error_listeners.go
├── go.mod
├── go.sum
├── main.go
└── main_test.go
Make sure that the package statement in your grammar file(s) reflects the go package the generated code will exist in.
The generate.go file then looks like this:
package parser
//go:generate ./generate.sh
And the generate.sh file will look similar to this:
#!/bin/sh
alias antlr4='java -Xmx500M -cp "./antlr4-4.12.1-complete.jar:$CLASSPATH" org.antlr.v4.Tool'
antlr4 -Dlanguage=Go -no-visitor -package parsing *.g4
depending on whether you want visitors or listeners or any other ANTLR options. Not that another option here
is to generate the code into a
From the command line at the root of your source package (location of go.mo)d) you can then simply issue the command:
go generate ./...
Which will generate the code for the parser, and place it in the parsing package. You can then use the generated code
by importing the parsing package.
There are no hard and fast rules on this. It is just a recommendation. You can generate the code in any way and to anywhere you like.
# Copyright Notice
Copyright (c) 2012-2023 The ANTLR Project. All rights reserved.
Use of this file is governed by the BSD 3-clause license, which can be found in the [LICENSE.txt] file in the project root.
[target languages]: https://github.com/antlr/antlr4/tree/master/runtime
[LICENSE.txt]: https://github.com/antlr/antlr4/blob/master/LICENSE.txt
[ANTLR docs]: https://github.com/antlr/antlr4/blob/master/doc/index.md
[new repo]: https://github.com/antlr4-go/antlr
[Jim Idle]: https://github.com/jimidle
[Go runtime docs]: https://github.com/antlr/antlr4/blob/master/doc/go-target.md
*/
package antlr

View File

@@ -20,10 +20,11 @@ var ATNInvalidAltNumber int
// [ALL(*)]: https://www.antlr.org/papers/allstar-techreport.pdf
// [Recursive Transition Network]: https://en.wikipedia.org/wiki/Recursive_transition_network
type ATN struct {
// DecisionToState is the decision points for all rules, subrules, optional
// blocks, ()+, ()*, etc. Each subrule/rule is a decision point, and we must track them so we
// DecisionToState is the decision points for all rules, sub-rules, optional
// blocks, ()+, ()*, etc. Each sub-rule/rule is a decision point, and we must track them, so we
// can go back later and build DFA predictors for them. This includes
// all the rules, subrules, optional blocks, ()+, ()* etc...
// all the rules, sub-rules, optional blocks, ()+, ()* etc...
DecisionToState []DecisionState
// grammarType is the ATN type and is used for deserializing ATNs from strings.
@@ -51,6 +52,8 @@ type ATN struct {
// specified, and otherwise is nil.
ruleToTokenType []int
// ATNStates is a list of all states in the ATN, ordered by state number.
//
states []ATNState
mu sync.Mutex

335
vendor/github.com/antlr4-go/antlr/v4/atn_config.go generated vendored Normal file
View File

@@ -0,0 +1,335 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"fmt"
)
const (
lexerConfig = iota // Indicates that this ATNConfig is for a lexer
parserConfig // Indicates that this ATNConfig is for a parser
)
// ATNConfig is a tuple: (ATN state, predicted alt, syntactic, semantic
// context). The syntactic context is a graph-structured stack node whose
// path(s) to the root is the rule invocation(s) chain used to arrive in the
// state. The semantic context is the tree of semantic predicates encountered
// before reaching an ATN state.
type ATNConfig struct {
precedenceFilterSuppressed bool
state ATNState
alt int
context *PredictionContext
semanticContext SemanticContext
reachesIntoOuterContext int
cType int // lexerConfig or parserConfig
lexerActionExecutor *LexerActionExecutor
passedThroughNonGreedyDecision bool
}
// NewATNConfig6 creates a new ATNConfig instance given a state, alt and context only
func NewATNConfig6(state ATNState, alt int, context *PredictionContext) *ATNConfig {
return NewATNConfig5(state, alt, context, SemanticContextNone)
}
// NewATNConfig5 creates a new ATNConfig instance given a state, alt, context and semantic context
func NewATNConfig5(state ATNState, alt int, context *PredictionContext, semanticContext SemanticContext) *ATNConfig {
if semanticContext == nil {
panic("semanticContext cannot be nil") // TODO: Necessary?
}
pac := &ATNConfig{}
pac.state = state
pac.alt = alt
pac.context = context
pac.semanticContext = semanticContext
pac.cType = parserConfig
return pac
}
// NewATNConfig4 creates a new ATNConfig instance given an existing config, and a state only
func NewATNConfig4(c *ATNConfig, state ATNState) *ATNConfig {
return NewATNConfig(c, state, c.GetContext(), c.GetSemanticContext())
}
// NewATNConfig3 creates a new ATNConfig instance given an existing config, a state and a semantic context
func NewATNConfig3(c *ATNConfig, state ATNState, semanticContext SemanticContext) *ATNConfig {
return NewATNConfig(c, state, c.GetContext(), semanticContext)
}
// NewATNConfig2 creates a new ATNConfig instance given an existing config, and a context only
func NewATNConfig2(c *ATNConfig, semanticContext SemanticContext) *ATNConfig {
return NewATNConfig(c, c.GetState(), c.GetContext(), semanticContext)
}
// NewATNConfig1 creates a new ATNConfig instance given an existing config, a state, and a context only
func NewATNConfig1(c *ATNConfig, state ATNState, context *PredictionContext) *ATNConfig {
return NewATNConfig(c, state, context, c.GetSemanticContext())
}
// NewATNConfig creates a new ATNConfig instance given an existing config, a state, a context and a semantic context, other 'constructors'
// are just wrappers around this one.
func NewATNConfig(c *ATNConfig, state ATNState, context *PredictionContext, semanticContext SemanticContext) *ATNConfig {
if semanticContext == nil {
panic("semanticContext cannot be nil") // TODO: Remove this - probably put here for some bug that is now fixed
}
b := &ATNConfig{}
b.InitATNConfig(c, state, c.GetAlt(), context, semanticContext)
b.cType = parserConfig
return b
}
func (a *ATNConfig) InitATNConfig(c *ATNConfig, state ATNState, alt int, context *PredictionContext, semanticContext SemanticContext) {
a.state = state
a.alt = alt
a.context = context
a.semanticContext = semanticContext
a.reachesIntoOuterContext = c.GetReachesIntoOuterContext()
a.precedenceFilterSuppressed = c.getPrecedenceFilterSuppressed()
}
func (a *ATNConfig) getPrecedenceFilterSuppressed() bool {
return a.precedenceFilterSuppressed
}
func (a *ATNConfig) setPrecedenceFilterSuppressed(v bool) {
a.precedenceFilterSuppressed = v
}
// GetState returns the ATN state associated with this configuration
func (a *ATNConfig) GetState() ATNState {
return a.state
}
// GetAlt returns the alternative associated with this configuration
func (a *ATNConfig) GetAlt() int {
return a.alt
}
// SetContext sets the rule invocation stack associated with this configuration
func (a *ATNConfig) SetContext(v *PredictionContext) {
a.context = v
}
// GetContext returns the rule invocation stack associated with this configuration
func (a *ATNConfig) GetContext() *PredictionContext {
return a.context
}
// GetSemanticContext returns the semantic context associated with this configuration
func (a *ATNConfig) GetSemanticContext() SemanticContext {
return a.semanticContext
}
// GetReachesIntoOuterContext returns the count of references to an outer context from this configuration
func (a *ATNConfig) GetReachesIntoOuterContext() int {
return a.reachesIntoOuterContext
}
// SetReachesIntoOuterContext sets the count of references to an outer context from this configuration
func (a *ATNConfig) SetReachesIntoOuterContext(v int) {
a.reachesIntoOuterContext = v
}
// Equals is the default comparison function for an ATNConfig when no specialist implementation is required
// for a collection.
//
// An ATN configuration is equal to another if both have the same state, they
// predict the same alternative, and syntactic/semantic contexts are the same.
func (a *ATNConfig) Equals(o Collectable[*ATNConfig]) bool {
switch a.cType {
case lexerConfig:
return a.LEquals(o)
case parserConfig:
return a.PEquals(o)
default:
panic("Invalid ATNConfig type")
}
}
// PEquals is the default comparison function for a Parser ATNConfig when no specialist implementation is required
// for a collection.
//
// An ATN configuration is equal to another if both have the same state, they
// predict the same alternative, and syntactic/semantic contexts are the same.
func (a *ATNConfig) PEquals(o Collectable[*ATNConfig]) bool {
var other, ok = o.(*ATNConfig)
if !ok {
return false
}
if a == other {
return true
} else if other == nil {
return false
}
var equal bool
if a.context == nil {
equal = other.context == nil
} else {
equal = a.context.Equals(other.context)
}
var (
nums = a.state.GetStateNumber() == other.state.GetStateNumber()
alts = a.alt == other.alt
cons = a.semanticContext.Equals(other.semanticContext)
sups = a.precedenceFilterSuppressed == other.precedenceFilterSuppressed
)
return nums && alts && cons && sups && equal
}
// Hash is the default hash function for a parser ATNConfig, when no specialist hash function
// is required for a collection
func (a *ATNConfig) Hash() int {
switch a.cType {
case lexerConfig:
return a.LHash()
case parserConfig:
return a.PHash()
default:
panic("Invalid ATNConfig type")
}
}
// PHash is the default hash function for a parser ATNConfig, when no specialist hash function
// is required for a collection
func (a *ATNConfig) PHash() int {
var c int
if a.context != nil {
c = a.context.Hash()
}
h := murmurInit(7)
h = murmurUpdate(h, a.state.GetStateNumber())
h = murmurUpdate(h, a.alt)
h = murmurUpdate(h, c)
h = murmurUpdate(h, a.semanticContext.Hash())
return murmurFinish(h, 4)
}
// String returns a string representation of the ATNConfig, usually used for debugging purposes
func (a *ATNConfig) String() string {
var s1, s2, s3 string
if a.context != nil {
s1 = ",[" + fmt.Sprint(a.context) + "]"
}
if a.semanticContext != SemanticContextNone {
s2 = "," + fmt.Sprint(a.semanticContext)
}
if a.reachesIntoOuterContext > 0 {
s3 = ",up=" + fmt.Sprint(a.reachesIntoOuterContext)
}
return fmt.Sprintf("(%v,%v%v%v%v)", a.state, a.alt, s1, s2, s3)
}
func NewLexerATNConfig6(state ATNState, alt int, context *PredictionContext) *ATNConfig {
lac := &ATNConfig{}
lac.state = state
lac.alt = alt
lac.context = context
lac.semanticContext = SemanticContextNone
lac.cType = lexerConfig
return lac
}
func NewLexerATNConfig4(c *ATNConfig, state ATNState) *ATNConfig {
lac := &ATNConfig{}
lac.lexerActionExecutor = c.lexerActionExecutor
lac.passedThroughNonGreedyDecision = checkNonGreedyDecision(c, state)
lac.InitATNConfig(c, state, c.GetAlt(), c.GetContext(), c.GetSemanticContext())
lac.cType = lexerConfig
return lac
}
func NewLexerATNConfig3(c *ATNConfig, state ATNState, lexerActionExecutor *LexerActionExecutor) *ATNConfig {
lac := &ATNConfig{}
lac.lexerActionExecutor = lexerActionExecutor
lac.passedThroughNonGreedyDecision = checkNonGreedyDecision(c, state)
lac.InitATNConfig(c, state, c.GetAlt(), c.GetContext(), c.GetSemanticContext())
lac.cType = lexerConfig
return lac
}
func NewLexerATNConfig2(c *ATNConfig, state ATNState, context *PredictionContext) *ATNConfig {
lac := &ATNConfig{}
lac.lexerActionExecutor = c.lexerActionExecutor
lac.passedThroughNonGreedyDecision = checkNonGreedyDecision(c, state)
lac.InitATNConfig(c, state, c.GetAlt(), context, c.GetSemanticContext())
lac.cType = lexerConfig
return lac
}
//goland:noinspection GoUnusedExportedFunction
func NewLexerATNConfig1(state ATNState, alt int, context *PredictionContext) *ATNConfig {
lac := &ATNConfig{}
lac.state = state
lac.alt = alt
lac.context = context
lac.semanticContext = SemanticContextNone
lac.cType = lexerConfig
return lac
}
// LHash is the default hash function for Lexer ATNConfig objects, it can be used directly or via
// the default comparator [ObjEqComparator].
func (a *ATNConfig) LHash() int {
var f int
if a.passedThroughNonGreedyDecision {
f = 1
} else {
f = 0
}
h := murmurInit(7)
h = murmurUpdate(h, a.state.GetStateNumber())
h = murmurUpdate(h, a.alt)
h = murmurUpdate(h, a.context.Hash())
h = murmurUpdate(h, a.semanticContext.Hash())
h = murmurUpdate(h, f)
h = murmurUpdate(h, a.lexerActionExecutor.Hash())
h = murmurFinish(h, 6)
return h
}
// LEquals is the default comparison function for Lexer ATNConfig objects, it can be used directly or via
// the default comparator [ObjEqComparator].
func (a *ATNConfig) LEquals(other Collectable[*ATNConfig]) bool {
var otherT, ok = other.(*ATNConfig)
if !ok {
return false
} else if a == otherT {
return true
} else if a.passedThroughNonGreedyDecision != otherT.passedThroughNonGreedyDecision {
return false
}
switch {
case a.lexerActionExecutor == nil && otherT.lexerActionExecutor == nil:
return true
case a.lexerActionExecutor != nil && otherT.lexerActionExecutor != nil:
if !a.lexerActionExecutor.Equals(otherT.lexerActionExecutor) {
return false
}
default:
return false // One but not both, are nil
}
return a.PEquals(otherT)
}
func checkNonGreedyDecision(source *ATNConfig, target ATNState) bool {
var ds, ok = target.(DecisionState)
return source.passedThroughNonGreedyDecision || (ok && ds.getNonGreedy())
}

301
vendor/github.com/antlr4-go/antlr/v4/atn_config_set.go generated vendored Normal file
View File

@@ -0,0 +1,301 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"fmt"
)
// ATNConfigSet is a specialized set of ATNConfig that tracks information
// about its elements and can combine similar configurations using a
// graph-structured stack.
type ATNConfigSet struct {
cachedHash int
// configLookup is used to determine whether two ATNConfigSets are equal. We
// need all configurations with the same (s, i, _, semctx) to be equal. A key
// effectively doubles the number of objects associated with ATNConfigs. All
// keys are hashed by (s, i, _, pi), not including the context. Wiped out when
// read-only because a set becomes a DFA state.
configLookup *JStore[*ATNConfig, Comparator[*ATNConfig]]
// configs is the added elements that did not match an existing key in configLookup
configs []*ATNConfig
// TODO: These fields make me pretty uncomfortable, but it is nice to pack up
// info together because it saves re-computation. Can we track conflicts as they
// are added to save scanning configs later?
conflictingAlts *BitSet
// dipsIntoOuterContext is used by parsers and lexers. In a lexer, it indicates
// we hit a pred while computing a closure operation. Do not make a DFA state
// from the ATNConfigSet in this case. TODO: How is this used by parsers?
dipsIntoOuterContext bool
// fullCtx is whether it is part of a full context LL prediction. Used to
// determine how to merge $. It is a wildcard with SLL, but not for an LL
// context merge.
fullCtx bool
// Used in parser and lexer. In lexer, it indicates we hit a pred
// while computing a closure operation. Don't make a DFA state from this set.
hasSemanticContext bool
// readOnly is whether it is read-only. Do not
// allow any code to manipulate the set if true because DFA states will point at
// sets and those must not change. It not, protect other fields; conflictingAlts
// in particular, which is assigned after readOnly.
readOnly bool
// TODO: These fields make me pretty uncomfortable, but it is nice to pack up
// info together because it saves re-computation. Can we track conflicts as they
// are added to save scanning configs later?
uniqueAlt int
}
// Alts returns the combined set of alts for all the configurations in this set.
func (b *ATNConfigSet) Alts() *BitSet {
alts := NewBitSet()
for _, it := range b.configs {
alts.add(it.GetAlt())
}
return alts
}
// NewATNConfigSet creates a new ATNConfigSet instance.
func NewATNConfigSet(fullCtx bool) *ATNConfigSet {
return &ATNConfigSet{
cachedHash: -1,
configLookup: NewJStore[*ATNConfig, Comparator[*ATNConfig]](aConfCompInst, ATNConfigLookupCollection, "NewATNConfigSet()"),
fullCtx: fullCtx,
}
}
// Add merges contexts with existing configs for (s, i, pi, _),
// where 's' is the ATNConfig.state, 'i' is the ATNConfig.alt, and
// 'pi' is the [ATNConfig].semanticContext.
//
// We use (s,i,pi) as the key.
// Updates dipsIntoOuterContext and hasSemanticContext when necessary.
func (b *ATNConfigSet) Add(config *ATNConfig, mergeCache *JPCMap) bool {
if b.readOnly {
panic("set is read-only")
}
if config.GetSemanticContext() != SemanticContextNone {
b.hasSemanticContext = true
}
if config.GetReachesIntoOuterContext() > 0 {
b.dipsIntoOuterContext = true
}
existing, present := b.configLookup.Put(config)
// The config was not already in the set
//
if !present {
b.cachedHash = -1
b.configs = append(b.configs, config) // Track order here
return true
}
// Merge a previous (s, i, pi, _) with it and save the result
rootIsWildcard := !b.fullCtx
merged := merge(existing.GetContext(), config.GetContext(), rootIsWildcard, mergeCache)
// No need to check for existing.context because config.context is in the cache,
// since the only way to create new graphs is the "call rule" and here. We cache
// at both places.
existing.SetReachesIntoOuterContext(intMax(existing.GetReachesIntoOuterContext(), config.GetReachesIntoOuterContext()))
// Preserve the precedence filter suppression during the merge
if config.getPrecedenceFilterSuppressed() {
existing.setPrecedenceFilterSuppressed(true)
}
// Replace the context because there is no need to do alt mapping
existing.SetContext(merged)
return true
}
// GetStates returns the set of states represented by all configurations in this config set
func (b *ATNConfigSet) GetStates() *JStore[ATNState, Comparator[ATNState]] {
// states uses the standard comparator and Hash() provided by the ATNState instance
//
states := NewJStore[ATNState, Comparator[ATNState]](aStateEqInst, ATNStateCollection, "ATNConfigSet.GetStates()")
for i := 0; i < len(b.configs); i++ {
states.Put(b.configs[i].GetState())
}
return states
}
func (b *ATNConfigSet) GetPredicates() []SemanticContext {
predicates := make([]SemanticContext, 0)
for i := 0; i < len(b.configs); i++ {
c := b.configs[i].GetSemanticContext()
if c != SemanticContextNone {
predicates = append(predicates, c)
}
}
return predicates
}
func (b *ATNConfigSet) OptimizeConfigs(interpreter *BaseATNSimulator) {
if b.readOnly {
panic("set is read-only")
}
// Empty indicate no optimization is possible
if b.configLookup == nil || b.configLookup.Len() == 0 {
return
}
for i := 0; i < len(b.configs); i++ {
config := b.configs[i]
config.SetContext(interpreter.getCachedContext(config.GetContext()))
}
}
func (b *ATNConfigSet) AddAll(coll []*ATNConfig) bool {
for i := 0; i < len(coll); i++ {
b.Add(coll[i], nil)
}
return false
}
// Compare The configs are only equal if they are in the same order and their Equals function returns true.
// Java uses ArrayList.equals(), which requires the same order.
func (b *ATNConfigSet) Compare(bs *ATNConfigSet) bool {
if len(b.configs) != len(bs.configs) {
return false
}
for i := 0; i < len(b.configs); i++ {
if !b.configs[i].Equals(bs.configs[i]) {
return false
}
}
return true
}
func (b *ATNConfigSet) Equals(other Collectable[ATNConfig]) bool {
if b == other {
return true
} else if _, ok := other.(*ATNConfigSet); !ok {
return false
}
other2 := other.(*ATNConfigSet)
var eca bool
switch {
case b.conflictingAlts == nil && other2.conflictingAlts == nil:
eca = true
case b.conflictingAlts != nil && other2.conflictingAlts != nil:
eca = b.conflictingAlts.equals(other2.conflictingAlts)
}
return b.configs != nil &&
b.fullCtx == other2.fullCtx &&
b.uniqueAlt == other2.uniqueAlt &&
eca &&
b.hasSemanticContext == other2.hasSemanticContext &&
b.dipsIntoOuterContext == other2.dipsIntoOuterContext &&
b.Compare(other2)
}
func (b *ATNConfigSet) Hash() int {
if b.readOnly {
if b.cachedHash == -1 {
b.cachedHash = b.hashCodeConfigs()
}
return b.cachedHash
}
return b.hashCodeConfigs()
}
func (b *ATNConfigSet) hashCodeConfigs() int {
h := 1
for _, config := range b.configs {
h = 31*h + config.Hash()
}
return h
}
func (b *ATNConfigSet) Contains(item *ATNConfig) bool {
if b.readOnly {
panic("not implemented for read-only sets")
}
if b.configLookup == nil {
return false
}
return b.configLookup.Contains(item)
}
func (b *ATNConfigSet) ContainsFast(item *ATNConfig) bool {
return b.Contains(item)
}
func (b *ATNConfigSet) Clear() {
if b.readOnly {
panic("set is read-only")
}
b.configs = make([]*ATNConfig, 0)
b.cachedHash = -1
b.configLookup = NewJStore[*ATNConfig, Comparator[*ATNConfig]](aConfCompInst, ATNConfigLookupCollection, "NewATNConfigSet()")
}
func (b *ATNConfigSet) String() string {
s := "["
for i, c := range b.configs {
s += c.String()
if i != len(b.configs)-1 {
s += ", "
}
}
s += "]"
if b.hasSemanticContext {
s += ",hasSemanticContext=" + fmt.Sprint(b.hasSemanticContext)
}
if b.uniqueAlt != ATNInvalidAltNumber {
s += ",uniqueAlt=" + fmt.Sprint(b.uniqueAlt)
}
if b.conflictingAlts != nil {
s += ",conflictingAlts=" + b.conflictingAlts.String()
}
if b.dipsIntoOuterContext {
s += ",dipsIntoOuterContext"
}
return s
}
// NewOrderedATNConfigSet creates a config set with a slightly different Hash/Equal pair
// for use in lexers.
func NewOrderedATNConfigSet() *ATNConfigSet {
return &ATNConfigSet{
cachedHash: -1,
// This set uses the standard Hash() and Equals() from ATNConfig
configLookup: NewJStore[*ATNConfig, Comparator[*ATNConfig]](aConfEqInst, ATNConfigCollection, "ATNConfigSet.NewOrderedATNConfigSet()"),
fullCtx: false,
}
}

View File

@@ -20,7 +20,7 @@ func (opts *ATNDeserializationOptions) ReadOnly() bool {
func (opts *ATNDeserializationOptions) SetReadOnly(readOnly bool) {
if opts.readOnly {
panic(errors.New("Cannot mutate read only ATNDeserializationOptions"))
panic(errors.New("cannot mutate read only ATNDeserializationOptions"))
}
opts.readOnly = readOnly
}
@@ -31,7 +31,7 @@ func (opts *ATNDeserializationOptions) VerifyATN() bool {
func (opts *ATNDeserializationOptions) SetVerifyATN(verifyATN bool) {
if opts.readOnly {
panic(errors.New("Cannot mutate read only ATNDeserializationOptions"))
panic(errors.New("cannot mutate read only ATNDeserializationOptions"))
}
opts.verifyATN = verifyATN
}
@@ -42,11 +42,12 @@ func (opts *ATNDeserializationOptions) GenerateRuleBypassTransitions() bool {
func (opts *ATNDeserializationOptions) SetGenerateRuleBypassTransitions(generateRuleBypassTransitions bool) {
if opts.readOnly {
panic(errors.New("Cannot mutate read only ATNDeserializationOptions"))
panic(errors.New("cannot mutate read only ATNDeserializationOptions"))
}
opts.generateRuleBypassTransitions = generateRuleBypassTransitions
}
//goland:noinspection GoUnusedExportedFunction
func DefaultATNDeserializationOptions() *ATNDeserializationOptions {
return NewATNDeserializationOptions(&defaultATNDeserializationOptions)
}

View File

@@ -35,6 +35,7 @@ func NewATNDeserializer(options *ATNDeserializationOptions) *ATNDeserializer {
return &ATNDeserializer{options: options}
}
//goland:noinspection GoUnusedFunction
func stringInSlice(a string, list []string) int {
for i, b := range list {
if b == a {
@@ -193,7 +194,7 @@ func (a *ATNDeserializer) readModes(atn *ATN) {
}
}
func (a *ATNDeserializer) readSets(atn *ATN, sets []*IntervalSet) []*IntervalSet {
func (a *ATNDeserializer) readSets(_ *ATN, sets []*IntervalSet) []*IntervalSet {
m := a.readInt()
// Preallocate the needed capacity.
@@ -350,7 +351,7 @@ func (a *ATNDeserializer) generateRuleBypassTransition(atn *ATN, idx int) {
bypassStart.endState = bypassStop
atn.defineDecisionState(bypassStart.BaseDecisionState)
atn.defineDecisionState(&bypassStart.BaseDecisionState)
bypassStop.startState = bypassStart
@@ -450,7 +451,7 @@ func (a *ATNDeserializer) markPrecedenceDecisions(atn *ATN) {
continue
}
// We analyze the ATN to determine if a ATN decision state is the
// We analyze the [ATN] to determine if an ATN decision state is the
// decision for the closure block that determines whether a
// precedence rule should continue or complete.
if atn.ruleToStartState[state.GetRuleIndex()].isPrecedenceRule {
@@ -553,7 +554,7 @@ func (a *ATNDeserializer) readInt() int {
return int(v) // data is 32 bits but int is at least that big
}
func (a *ATNDeserializer) edgeFactory(atn *ATN, typeIndex, src, trg, arg1, arg2, arg3 int, sets []*IntervalSet) Transition {
func (a *ATNDeserializer) edgeFactory(atn *ATN, typeIndex, _, trg, arg1, arg2, arg3 int, sets []*IntervalSet) Transition {
target := atn.states[trg]
switch typeIndex {

View File

@@ -4,7 +4,7 @@
package antlr
var ATNSimulatorError = NewDFAState(0x7FFFFFFF, NewBaseATNConfigSet(false))
var ATNSimulatorError = NewDFAState(0x7FFFFFFF, NewATNConfigSet(false))
type IATNSimulator interface {
SharedContextCache() *PredictionContextCache
@@ -18,22 +18,13 @@ type BaseATNSimulator struct {
decisionToDFA []*DFA
}
func NewBaseATNSimulator(atn *ATN, sharedContextCache *PredictionContextCache) *BaseATNSimulator {
b := new(BaseATNSimulator)
b.atn = atn
b.sharedContextCache = sharedContextCache
return b
}
func (b *BaseATNSimulator) getCachedContext(context PredictionContext) PredictionContext {
func (b *BaseATNSimulator) getCachedContext(context *PredictionContext) *PredictionContext {
if b.sharedContextCache == nil {
return context
}
visited := make(map[PredictionContext]PredictionContext)
//visited := NewJMap[*PredictionContext, *PredictionContext, Comparator[*PredictionContext]](pContextEqInst, PredictionVisitedCollection, "Visit map in getCachedContext()")
visited := NewVisitRecord()
return getCachedBasePredictionContext(context, b.sharedContextCache, visited)
}

View File

@@ -4,7 +4,11 @@
package antlr
import "strconv"
import (
"fmt"
"os"
"strconv"
)
// Constants for serialization.
const (
@@ -25,6 +29,7 @@ const (
ATNStateInvalidStateNumber = -1
)
//goland:noinspection GoUnusedGlobalVariable
var ATNStateInitialNumTransitions = 4
type ATNState interface {
@@ -73,7 +78,7 @@ type BaseATNState struct {
transitions []Transition
}
func NewBaseATNState() *BaseATNState {
func NewATNState() *BaseATNState {
return &BaseATNState{stateNumber: ATNStateInvalidStateNumber, stateType: ATNStateInvalidType}
}
@@ -148,27 +153,46 @@ func (as *BaseATNState) AddTransition(trans Transition, index int) {
if len(as.transitions) == 0 {
as.epsilonOnlyTransitions = trans.getIsEpsilon()
} else if as.epsilonOnlyTransitions != trans.getIsEpsilon() {
_, _ = fmt.Fprintf(os.Stdin, "ATN state %d has both epsilon and non-epsilon transitions.\n", as.stateNumber)
as.epsilonOnlyTransitions = false
}
// TODO: Check code for already present compared to the Java equivalent
//alreadyPresent := false
//for _, t := range as.transitions {
// if t.getTarget().GetStateNumber() == trans.getTarget().GetStateNumber() {
// if t.getLabel() != nil && trans.getLabel() != nil && trans.getLabel().Equals(t.getLabel()) {
// alreadyPresent = true
// break
// }
// } else if t.getIsEpsilon() && trans.getIsEpsilon() {
// alreadyPresent = true
// break
// }
//}
//if !alreadyPresent {
if index == -1 {
as.transitions = append(as.transitions, trans)
} else {
as.transitions = append(as.transitions[:index], append([]Transition{trans}, as.transitions[index:]...)...)
// TODO: as.transitions.splice(index, 1, trans)
}
//} else {
// _, _ = fmt.Fprintf(os.Stderr, "Transition already present in state %d\n", as.stateNumber)
//}
}
type BasicState struct {
*BaseATNState
BaseATNState
}
func NewBasicState() *BasicState {
b := NewBaseATNState()
b.stateType = ATNStateBasic
return &BasicState{BaseATNState: b}
return &BasicState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateBasic,
},
}
}
type DecisionState interface {
@@ -182,13 +206,19 @@ type DecisionState interface {
}
type BaseDecisionState struct {
*BaseATNState
BaseATNState
decision int
nonGreedy bool
}
func NewBaseDecisionState() *BaseDecisionState {
return &BaseDecisionState{BaseATNState: NewBaseATNState(), decision: -1}
return &BaseDecisionState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateBasic,
},
decision: -1,
}
}
func (s *BaseDecisionState) getDecision() int {
@@ -216,12 +246,20 @@ type BlockStartState interface {
// BaseBlockStartState is the start of a regular (...) block.
type BaseBlockStartState struct {
*BaseDecisionState
BaseDecisionState
endState *BlockEndState
}
func NewBlockStartState() *BaseBlockStartState {
return &BaseBlockStartState{BaseDecisionState: NewBaseDecisionState()}
return &BaseBlockStartState{
BaseDecisionState: BaseDecisionState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateBasic,
},
decision: -1,
},
}
}
func (s *BaseBlockStartState) getEndState() *BlockEndState {
@@ -233,31 +271,38 @@ func (s *BaseBlockStartState) setEndState(b *BlockEndState) {
}
type BasicBlockStartState struct {
*BaseBlockStartState
BaseBlockStartState
}
func NewBasicBlockStartState() *BasicBlockStartState {
b := NewBlockStartState()
b.stateType = ATNStateBlockStart
return &BasicBlockStartState{BaseBlockStartState: b}
return &BasicBlockStartState{
BaseBlockStartState: BaseBlockStartState{
BaseDecisionState: BaseDecisionState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateBlockStart,
},
},
},
}
}
var _ BlockStartState = &BasicBlockStartState{}
// BlockEndState is a terminal node of a simple (a|b|c) block.
type BlockEndState struct {
*BaseATNState
BaseATNState
startState ATNState
}
func NewBlockEndState() *BlockEndState {
b := NewBaseATNState()
b.stateType = ATNStateBlockEnd
return &BlockEndState{BaseATNState: b}
return &BlockEndState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateBlockEnd,
},
startState: nil,
}
}
// RuleStopState is the last node in the ATN for a rule, unless that rule is the
@@ -265,43 +310,48 @@ func NewBlockEndState() *BlockEndState {
// encode references to all calls to this rule to compute FOLLOW sets for error
// handling.
type RuleStopState struct {
*BaseATNState
BaseATNState
}
func NewRuleStopState() *RuleStopState {
b := NewBaseATNState()
b.stateType = ATNStateRuleStop
return &RuleStopState{BaseATNState: b}
return &RuleStopState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateRuleStop,
},
}
}
type RuleStartState struct {
*BaseATNState
BaseATNState
stopState ATNState
isPrecedenceRule bool
}
func NewRuleStartState() *RuleStartState {
b := NewBaseATNState()
b.stateType = ATNStateRuleStart
return &RuleStartState{BaseATNState: b}
return &RuleStartState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateRuleStart,
},
}
}
// PlusLoopbackState is a decision state for A+ and (A|B)+. It has two
// transitions: one to the loop back to start of the block, and one to exit.
type PlusLoopbackState struct {
*BaseDecisionState
BaseDecisionState
}
func NewPlusLoopbackState() *PlusLoopbackState {
b := NewBaseDecisionState()
b.stateType = ATNStatePlusLoopBack
return &PlusLoopbackState{BaseDecisionState: b}
return &PlusLoopbackState{
BaseDecisionState: BaseDecisionState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStatePlusLoopBack,
},
},
}
}
// PlusBlockStartState is the start of a (A|B|...)+ loop. Technically it is a
@@ -309,85 +359,103 @@ func NewPlusLoopbackState() *PlusLoopbackState {
// it is included for completeness. In reality, PlusLoopbackState is the real
// decision-making node for A+.
type PlusBlockStartState struct {
*BaseBlockStartState
BaseBlockStartState
loopBackState ATNState
}
func NewPlusBlockStartState() *PlusBlockStartState {
b := NewBlockStartState()
b.stateType = ATNStatePlusBlockStart
return &PlusBlockStartState{BaseBlockStartState: b}
return &PlusBlockStartState{
BaseBlockStartState: BaseBlockStartState{
BaseDecisionState: BaseDecisionState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStatePlusBlockStart,
},
},
},
}
}
var _ BlockStartState = &PlusBlockStartState{}
// StarBlockStartState is the block that begins a closure loop.
type StarBlockStartState struct {
*BaseBlockStartState
BaseBlockStartState
}
func NewStarBlockStartState() *StarBlockStartState {
b := NewBlockStartState()
b.stateType = ATNStateStarBlockStart
return &StarBlockStartState{BaseBlockStartState: b}
return &StarBlockStartState{
BaseBlockStartState: BaseBlockStartState{
BaseDecisionState: BaseDecisionState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateStarBlockStart,
},
},
},
}
}
var _ BlockStartState = &StarBlockStartState{}
type StarLoopbackState struct {
*BaseATNState
BaseATNState
}
func NewStarLoopbackState() *StarLoopbackState {
b := NewBaseATNState()
b.stateType = ATNStateStarLoopBack
return &StarLoopbackState{BaseATNState: b}
return &StarLoopbackState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateStarLoopBack,
},
}
}
type StarLoopEntryState struct {
*BaseDecisionState
BaseDecisionState
loopBackState ATNState
precedenceRuleDecision bool
}
func NewStarLoopEntryState() *StarLoopEntryState {
b := NewBaseDecisionState()
b.stateType = ATNStateStarLoopEntry
// False precedenceRuleDecision indicates whether s state can benefit from a precedence DFA during SLL decision making.
return &StarLoopEntryState{BaseDecisionState: b}
return &StarLoopEntryState{
BaseDecisionState: BaseDecisionState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateStarLoopEntry,
},
},
}
}
// LoopEndState marks the end of a * or + loop.
type LoopEndState struct {
*BaseATNState
BaseATNState
loopBackState ATNState
}
func NewLoopEndState() *LoopEndState {
b := NewBaseATNState()
b.stateType = ATNStateLoopEnd
return &LoopEndState{BaseATNState: b}
return &LoopEndState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateLoopEnd,
},
}
}
// TokensStartState is the Tokens rule start state linking to each lexer rule start state.
type TokensStartState struct {
*BaseDecisionState
BaseDecisionState
}
func NewTokensStartState() *TokensStartState {
b := NewBaseDecisionState()
b.stateType = ATNStateTokenStart
return &TokensStartState{BaseDecisionState: b}
return &TokensStartState{
BaseDecisionState: BaseDecisionState{
BaseATNState: BaseATNState{
stateNumber: ATNStateInvalidStateNumber,
stateType: ATNStateTokenStart,
},
},
}
}

View File

@@ -8,5 +8,5 @@ type CharStream interface {
IntStream
GetText(int, int) string
GetTextFromTokens(start, end Token) string
GetTextFromInterval(*Interval) string
GetTextFromInterval(Interval) string
}

View File

@@ -28,22 +28,24 @@ type CommonTokenStream struct {
// trivial with bt field.
fetchedEOF bool
// index indexs into tokens of the current token (next token to consume).
// index into [tokens] of the current token (next token to consume).
// tokens[p] should be LT(1). It is set to -1 when the stream is first
// constructed or when SetTokenSource is called, indicating that the first token
// has not yet been fetched from the token source. For additional information,
// see the documentation of IntStream for a description of initializing methods.
// see the documentation of [IntStream] for a description of initializing methods.
index int
// tokenSource is the TokenSource from which tokens for the bt stream are
// tokenSource is the [TokenSource] from which tokens for the bt stream are
// fetched.
tokenSource TokenSource
// tokens is all tokens fetched from the token source. The list is considered a
// tokens contains all tokens fetched from the token source. The list is considered a
// complete view of the input once fetchedEOF is set to true.
tokens []Token
}
// NewCommonTokenStream creates a new CommonTokenStream instance using the supplied lexer to produce
// tokens and will pull tokens from the given lexer channel.
func NewCommonTokenStream(lexer Lexer, channel int) *CommonTokenStream {
return &CommonTokenStream{
channel: channel,
@@ -53,6 +55,7 @@ func NewCommonTokenStream(lexer Lexer, channel int) *CommonTokenStream {
}
}
// GetAllTokens returns all tokens currently pulled from the token source.
func (c *CommonTokenStream) GetAllTokens() []Token {
return c.tokens
}
@@ -61,9 +64,11 @@ func (c *CommonTokenStream) Mark() int {
return 0
}
func (c *CommonTokenStream) Release(marker int) {}
func (c *CommonTokenStream) Release(_ int) {}
func (c *CommonTokenStream) reset() {
func (c *CommonTokenStream) Reset() {
c.fetchedEOF = false
c.tokens = make([]Token, 0)
c.Seek(0)
}
@@ -107,7 +112,7 @@ func (c *CommonTokenStream) Consume() {
// Sync makes sure index i in tokens has a token and returns true if a token is
// located at index i and otherwise false.
func (c *CommonTokenStream) Sync(i int) bool {
n := i - len(c.tokens) + 1 // TODO: How many more elements do we need?
n := i - len(c.tokens) + 1 // How many more elements do we need?
if n > 0 {
fetched := c.fetch(n)
@@ -193,12 +198,13 @@ func (c *CommonTokenStream) SetTokenSource(tokenSource TokenSource) {
c.tokenSource = tokenSource
c.tokens = make([]Token, 0)
c.index = -1
c.fetchedEOF = false
}
// NextTokenOnChannel returns the index of the next token on channel given a
// starting index. Returns i if tokens[i] is on channel. Returns -1 if there are
// no tokens on channel between i and EOF.
func (c *CommonTokenStream) NextTokenOnChannel(i, channel int) int {
// no tokens on channel between 'i' and [TokenEOF].
func (c *CommonTokenStream) NextTokenOnChannel(i, _ int) int {
c.Sync(i)
if i >= len(c.tokens) {
@@ -244,7 +250,7 @@ func (c *CommonTokenStream) GetHiddenTokensToRight(tokenIndex, channel int) []To
nextOnChannel := c.NextTokenOnChannel(tokenIndex+1, LexerDefaultTokenChannel)
from := tokenIndex + 1
// If no onchannel to the right, then nextOnChannel == -1, so set to to last token
// If no onChannel to the right, then nextOnChannel == -1, so set 'to' to the last token
var to int
if nextOnChannel == -1 {
@@ -314,7 +320,8 @@ func (c *CommonTokenStream) Index() int {
}
func (c *CommonTokenStream) GetAllText() string {
return c.GetTextFromInterval(nil)
c.Fill()
return c.GetTextFromInterval(NewInterval(0, len(c.tokens)-1))
}
func (c *CommonTokenStream) GetTextFromTokens(start, end Token) string {
@@ -329,15 +336,9 @@ func (c *CommonTokenStream) GetTextFromRuleContext(interval RuleContext) string
return c.GetTextFromInterval(interval.GetSourceInterval())
}
func (c *CommonTokenStream) GetTextFromInterval(interval *Interval) string {
func (c *CommonTokenStream) GetTextFromInterval(interval Interval) string {
c.lazyInit()
if interval == nil {
c.Fill()
interval = NewInterval(0, len(c.tokens)-1)
} else {
c.Sync(interval.Stop)
}
c.Sync(interval.Stop)
start := interval.Start
stop := interval.Stop

View File

@@ -18,17 +18,20 @@ package antlr
// type safety and avoid having to implement this for every type that we want to perform comparison on.
//
// This comparator works by using the standard Hash() and Equals() methods of the type T that is being compared. Which
// allows us to use it in any collection instance that does nto require a special hash or equals implementation.
// allows us to use it in any collection instance that does not require a special hash or equals implementation.
type ObjEqComparator[T Collectable[T]] struct{}
var (
aStateEqInst = &ObjEqComparator[ATNState]{}
aConfEqInst = &ObjEqComparator[ATNConfig]{}
aConfCompInst = &ATNConfigComparator[ATNConfig]{}
atnConfCompInst = &BaseATNConfigComparator[ATNConfig]{}
aStateEqInst = &ObjEqComparator[ATNState]{}
aConfEqInst = &ObjEqComparator[*ATNConfig]{}
// aConfCompInst is the comparator used for the ATNConfigSet for the configLookup cache
aConfCompInst = &ATNConfigComparator[*ATNConfig]{}
atnConfCompInst = &BaseATNConfigComparator[*ATNConfig]{}
dfaStateEqInst = &ObjEqComparator[*DFAState]{}
semctxEqInst = &ObjEqComparator[SemanticContext]{}
atnAltCfgEqInst = &ATNAltConfigComparator[ATNConfig]{}
atnAltCfgEqInst = &ATNAltConfigComparator[*ATNConfig]{}
pContextEqInst = &ObjEqComparator[*PredictionContext]{}
)
// Equals2 delegates to the Equals() method of type T
@@ -44,14 +47,14 @@ func (c *ObjEqComparator[T]) Hash1(o T) int {
type SemCComparator[T Collectable[T]] struct{}
// ATNConfigComparator is used as the compartor for the configLookup field of an ATNConfigSet
// ATNConfigComparator is used as the comparator for the configLookup field of an ATNConfigSet
// and has a custom Equals() and Hash() implementation, because equality is not based on the
// standard Hash() and Equals() methods of the ATNConfig type.
type ATNConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for configLookup
func (c *ATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
func (c *ATNConfigComparator[T]) Equals2(o1, o2 *ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
@@ -72,7 +75,8 @@ func (c *ATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
}
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup
func (c *ATNConfigComparator[T]) Hash1(o ATNConfig) int {
func (c *ATNConfigComparator[T]) Hash1(o *ATNConfig) int {
hash := 7
hash = 31*hash + o.GetState().GetStateNumber()
hash = 31*hash + o.GetAlt()
@@ -85,7 +89,7 @@ type ATNAltConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for configLookup
func (c *ATNAltConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
func (c *ATNAltConfigComparator[T]) Equals2(o1, o2 *ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
@@ -105,21 +109,21 @@ func (c *ATNAltConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
}
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup
func (c *ATNAltConfigComparator[T]) Hash1(o ATNConfig) int {
func (c *ATNAltConfigComparator[T]) Hash1(o *ATNConfig) int {
h := murmurInit(7)
h = murmurUpdate(h, o.GetState().GetStateNumber())
h = murmurUpdate(h, o.GetContext().Hash())
return murmurFinish(h, 2)
}
// BaseATNConfigComparator is used as the comparator for the configLookup field of a BaseATNConfigSet
// BaseATNConfigComparator is used as the comparator for the configLookup field of a ATNConfigSet
// and has a custom Equals() and Hash() implementation, because equality is not based on the
// standard Hash() and Equals() methods of the ATNConfig type.
type BaseATNConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for baseATNConfigSet
func (c *BaseATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
func (c *BaseATNConfigComparator[T]) Equals2(o1, o2 *ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
@@ -141,7 +145,6 @@ func (c *BaseATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup, but in fact just
// delegates to the standard Hash() method of the ATNConfig type.
func (c *BaseATNConfigComparator[T]) Hash1(o ATNConfig) int {
func (c *BaseATNConfigComparator[T]) Hash1(o *ATNConfig) int {
return o.Hash()
}

214
vendor/github.com/antlr4-go/antlr/v4/configuration.go generated vendored Normal file
View File

@@ -0,0 +1,214 @@
package antlr
type runtimeConfiguration struct {
statsTraceStacks bool
lexerATNSimulatorDebug bool
lexerATNSimulatorDFADebug bool
parserATNSimulatorDebug bool
parserATNSimulatorTraceATNSim bool
parserATNSimulatorDFADebug bool
parserATNSimulatorRetryDebug bool
lRLoopEntryBranchOpt bool
memoryManager bool
}
// Global runtime configuration
var runtimeConfig = runtimeConfiguration{
lRLoopEntryBranchOpt: true,
}
type runtimeOption func(*runtimeConfiguration) error
// ConfigureRuntime allows the runtime to be configured globally setting things like trace and statistics options.
// It uses the functional options pattern for go. This is a package global function as it operates on the runtime
// configuration regardless of the instantiation of anything higher up such as a parser or lexer. Generally this is
// used for debugging/tracing/statistics options, which are usually used by the runtime maintainers (or rather the
// only maintainer). However, it is possible that you might want to use this to set a global option concerning the
// memory allocation type used by the runtime such as sync.Pool or not.
//
// The options are applied in the order they are passed in, so the last option will override any previous options.
//
// For example, if you want to turn on the collection create point stack flag to true, you can do:
//
// antlr.ConfigureRuntime(antlr.WithStatsTraceStacks(true))
//
// If you want to turn it off, you can do:
//
// antlr.ConfigureRuntime(antlr.WithStatsTraceStacks(false))
func ConfigureRuntime(options ...runtimeOption) error {
for _, option := range options {
err := option(&runtimeConfig)
if err != nil {
return err
}
}
return nil
}
// WithStatsTraceStacks sets the global flag indicating whether to collect stack traces at the create-point of
// certain structs, such as collections, or the use point of certain methods such as Put().
// Because this can be expensive, it is turned off by default. However, it
// can be useful to track down exactly where memory is being created and used.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithStatsTraceStacks(true))
//
// You can turn it off at any time using:
//
// antlr.ConfigureRuntime(antlr.WithStatsTraceStacks(false))
func WithStatsTraceStacks(trace bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.statsTraceStacks = trace
return nil
}
}
// WithLexerATNSimulatorDebug sets the global flag indicating whether to log debug information from the lexer [ATN]
// simulator. This is useful for debugging lexer issues by comparing the output with the Java runtime. Only useful
// to the runtime maintainers.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithLexerATNSimulatorDebug(true))
//
// You can turn it off at any time using:
//
// antlr.ConfigureRuntime(antlr.WithLexerATNSimulatorDebug(false))
func WithLexerATNSimulatorDebug(debug bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.lexerATNSimulatorDebug = debug
return nil
}
}
// WithLexerATNSimulatorDFADebug sets the global flag indicating whether to log debug information from the lexer [ATN] [DFA]
// simulator. This is useful for debugging lexer issues by comparing the output with the Java runtime. Only useful
// to the runtime maintainers.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithLexerATNSimulatorDFADebug(true))
//
// You can turn it off at any time using:
//
// antlr.ConfigureRuntime(antlr.WithLexerATNSimulatorDFADebug(false))
func WithLexerATNSimulatorDFADebug(debug bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.lexerATNSimulatorDFADebug = debug
return nil
}
}
// WithParserATNSimulatorDebug sets the global flag indicating whether to log debug information from the parser [ATN]
// simulator. This is useful for debugging parser issues by comparing the output with the Java runtime. Only useful
// to the runtime maintainers.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithParserATNSimulatorDebug(true))
//
// You can turn it off at any time using:
//
// antlr.ConfigureRuntime(antlr.WithParserATNSimulatorDebug(false))
func WithParserATNSimulatorDebug(debug bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.parserATNSimulatorDebug = debug
return nil
}
}
// WithParserATNSimulatorTraceATNSim sets the global flag indicating whether to log trace information from the parser [ATN] simulator
// [DFA]. This is useful for debugging parser issues by comparing the output with the Java runtime. Only useful
// to the runtime maintainers.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithParserATNSimulatorTraceATNSim(true))
//
// You can turn it off at any time using:
//
// antlr.ConfigureRuntime(antlr.WithParserATNSimulatorTraceATNSim(false))
func WithParserATNSimulatorTraceATNSim(trace bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.parserATNSimulatorTraceATNSim = trace
return nil
}
}
// WithParserATNSimulatorDFADebug sets the global flag indicating whether to log debug information from the parser [ATN] [DFA]
// simulator. This is useful for debugging parser issues by comparing the output with the Java runtime. Only useful
// to the runtime maintainers.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithParserATNSimulatorDFADebug(true))
//
// You can turn it off at any time using:
//
// antlr.ConfigureRuntime(antlr.WithParserATNSimulatorDFADebug(false))
func WithParserATNSimulatorDFADebug(debug bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.parserATNSimulatorDFADebug = debug
return nil
}
}
// WithParserATNSimulatorRetryDebug sets the global flag indicating whether to log debug information from the parser [ATN] [DFA]
// simulator when retrying a decision. This is useful for debugging parser issues by comparing the output with the Java runtime.
// Only useful to the runtime maintainers.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithParserATNSimulatorRetryDebug(true))
//
// You can turn it off at any time using:
//
// antlr.ConfigureRuntime(antlr.WithParserATNSimulatorRetryDebug(false))
func WithParserATNSimulatorRetryDebug(debug bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.parserATNSimulatorRetryDebug = debug
return nil
}
}
// WithLRLoopEntryBranchOpt sets the global flag indicating whether let recursive loop operations should be
// optimized or not. This is useful for debugging parser issues by comparing the output with the Java runtime.
// It turns off the functionality of [canDropLoopEntryEdgeInLeftRecursiveRule] in [ParserATNSimulator].
//
// Note that default is to use this optimization.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithLRLoopEntryBranchOpt(true))
//
// You can turn it off at any time using:
//
// antlr.ConfigureRuntime(antlr.WithLRLoopEntryBranchOpt(false))
func WithLRLoopEntryBranchOpt(off bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.lRLoopEntryBranchOpt = off
return nil
}
}
// WithMemoryManager sets the global flag indicating whether to use the memory manager or not. This is useful
// for poorly constructed grammars that create a lot of garbage. It turns on the functionality of [memoryManager], which
// will intercept garbage collection and cause available memory to be reused. At the end of the day, this is no substitute
// for fixing your grammar by ridding yourself of extreme ambiguity. BUt if you are just trying to reuse an opensource
// grammar, this may help make it more practical.
//
// Note that default is to use normal Go memory allocation and not pool memory.
//
// Use:
//
// antlr.ConfigureRuntime(antlr.WithMemoryManager(true))
//
// Note that if you turn this on, you should probably leave it on. You should use only one memory strategy or the other
// and should remember to nil out any references to the parser or lexer when you are done with them.
func WithMemoryManager(use bool) runtimeOption {
return func(config *runtimeConfiguration) error {
config.memoryManager = use
return nil
}
}

View File

@@ -4,6 +4,8 @@
package antlr
// DFA represents the Deterministic Finite Automaton used by the recognizer, including all the states it can
// reach and the transitions between them.
type DFA struct {
// atnStartState is the ATN state in which this was created
atnStartState DecisionState
@@ -12,10 +14,9 @@ type DFA struct {
// states is all the DFA states. Use Map to get the old state back; Set can only
// indicate whether it is there. Go maps implement key hash collisions and so on and are very
// good, but the DFAState is an object and can't be used directly as the key as it can in say JAva
// good, but the DFAState is an object and can't be used directly as the key as it can in say Java
// amd C#, whereby if the hashcode is the same for two objects, then Equals() is called against them
// to see if they really are the same object.
//
// to see if they really are the same object. Hence, we have our own map storage.
//
states *JStore[*DFAState, *ObjEqComparator[*DFAState]]
@@ -32,11 +33,11 @@ func NewDFA(atnStartState DecisionState, decision int) *DFA {
dfa := &DFA{
atnStartState: atnStartState,
decision: decision,
states: NewJStore[*DFAState, *ObjEqComparator[*DFAState]](dfaStateEqInst),
states: nil, // Lazy initialize
}
if s, ok := atnStartState.(*StarLoopEntryState); ok && s.precedenceRuleDecision {
dfa.precedenceDfa = true
dfa.s0 = NewDFAState(-1, NewBaseATNConfigSet(false))
dfa.s0 = NewDFAState(-1, NewATNConfigSet(false))
dfa.s0.isAcceptState = false
dfa.s0.requiresFullContext = false
}
@@ -95,12 +96,11 @@ func (d *DFA) getPrecedenceDfa() bool {
// true or nil otherwise, and d.precedenceDfa is updated.
func (d *DFA) setPrecedenceDfa(precedenceDfa bool) {
if d.getPrecedenceDfa() != precedenceDfa {
d.states = NewJStore[*DFAState, *ObjEqComparator[*DFAState]](dfaStateEqInst)
d.states = nil // Lazy initialize
d.numstates = 0
if precedenceDfa {
precedenceState := NewDFAState(-1, NewBaseATNConfigSet(false))
precedenceState := NewDFAState(-1, NewATNConfigSet(false))
precedenceState.setEdges(make([]*DFAState, 0))
precedenceState.isAcceptState = false
precedenceState.requiresFullContext = false
@@ -113,6 +113,31 @@ func (d *DFA) setPrecedenceDfa(precedenceDfa bool) {
}
}
// Len returns the number of states in d. We use this instead of accessing states directly so that we can implement lazy
// instantiation of the states JMap.
func (d *DFA) Len() int {
if d.states == nil {
return 0
}
return d.states.Len()
}
// Get returns a state that matches s if it is present in the DFA state set. We defer to this
// function instead of accessing states directly so that we can implement lazy instantiation of the states JMap.
func (d *DFA) Get(s *DFAState) (*DFAState, bool) {
if d.states == nil {
return nil, false
}
return d.states.Get(s)
}
func (d *DFA) Put(s *DFAState) (*DFAState, bool) {
if d.states == nil {
d.states = NewJStore[*DFAState, *ObjEqComparator[*DFAState]](dfaStateEqInst, DFAStateCollection, "DFA via DFA.Put")
}
return d.states.Put(s)
}
func (d *DFA) getS0() *DFAState {
return d.s0
}
@@ -121,9 +146,11 @@ func (d *DFA) setS0(s *DFAState) {
d.s0 = s
}
// sortedStates returns the states in d sorted by their state number.
// sortedStates returns the states in d sorted by their state number, or an empty set if d.states is nil.
func (d *DFA) sortedStates() []*DFAState {
if d.states == nil {
return []*DFAState{}
}
vs := d.states.SortedSlice(func(i, j *DFAState) bool {
return i.stateNumber < j.stateNumber
})

View File

@@ -10,7 +10,7 @@ import (
"strings"
)
// DFASerializer is a DFA walker that knows how to dump them to serialized
// DFASerializer is a DFA walker that knows how to dump the DFA states to serialized
// strings.
type DFASerializer struct {
dfa *DFA

View File

@@ -22,30 +22,31 @@ func (p *PredPrediction) String() string {
return "(" + fmt.Sprint(p.pred) + ", " + fmt.Sprint(p.alt) + ")"
}
// DFAState represents a set of possible ATN configurations. As Aho, Sethi,
// DFAState represents a set of possible [ATN] configurations. As Aho, Sethi,
// Ullman p. 117 says: "The DFA uses its state to keep track of all possible
// states the ATN can be in after reading each input symbol. That is to say,
// after reading input a1a2..an, the DFA is in a state that represents the
// after reading input a1, a2,..an, the DFA is in a state that represents the
// subset T of the states of the ATN that are reachable from the ATN's start
// state along some path labeled a1a2..an." In conventional NFA-to-DFA
// conversion, therefore, the subset T would be a bitset representing the set of
// states the ATN could be in. We need to track the alt predicted by each state
// state along some path labeled a1a2..an."
//
// In conventional NFA-to-DFA conversion, therefore, the subset T would be a bitset representing the set of
// states the [ATN] could be in. We need to track the alt predicted by each state
// as well, however. More importantly, we need to maintain a stack of states,
// tracking the closure operations as they jump from rule to rule, emulating
// rule invocations (method calls). I have to add a stack to simulate the proper
// lookahead sequences for the underlying LL grammar from which the ATN was
// derived.
//
// I use a set of ATNConfig objects, not simple states. An ATNConfig is both a
// state (ala normal conversion) and a RuleContext describing the chain of rules
// I use a set of [ATNConfig] objects, not simple states. An [ATNConfig] is both a
// state (ala normal conversion) and a [RuleContext] describing the chain of rules
// (if any) followed to arrive at that state.
//
// A DFAState may have multiple references to a particular state, but with
// different ATN contexts (with same or different alts) meaning that state was
// A [DFAState] may have multiple references to a particular state, but with
// different [ATN] contexts (with same or different alts) meaning that state was
// reached via a different set of rule invocations.
type DFAState struct {
stateNumber int
configs ATNConfigSet
configs *ATNConfigSet
// edges elements point to the target of the symbol. Shift up by 1 so (-1)
// Token.EOF maps to the first element.
@@ -53,7 +54,7 @@ type DFAState struct {
isAcceptState bool
// prediction is the ttype we match or alt we predict if the state is accept.
// prediction is the 'ttype' we match or alt we predict if the state is 'accept'.
// Set to ATN.INVALID_ALT_NUMBER when predicates != nil or
// requiresFullContext.
prediction int
@@ -81,9 +82,9 @@ type DFAState struct {
predicates []*PredPrediction
}
func NewDFAState(stateNumber int, configs ATNConfigSet) *DFAState {
func NewDFAState(stateNumber int, configs *ATNConfigSet) *DFAState {
if configs == nil {
configs = NewBaseATNConfigSet(false)
configs = NewATNConfigSet(false)
}
return &DFAState{configs: configs, stateNumber: stateNumber}
@@ -94,7 +95,7 @@ func (d *DFAState) GetAltSet() []int {
var alts []int
if d.configs != nil {
for _, c := range d.configs.GetItems() {
for _, c := range d.configs.configs {
alts = append(alts, c.GetAlt())
}
}

View File

@@ -33,6 +33,7 @@ type DiagnosticErrorListener struct {
exactOnly bool
}
//goland:noinspection GoUnusedExportedFunction
func NewDiagnosticErrorListener(exactOnly bool) *DiagnosticErrorListener {
n := new(DiagnosticErrorListener)
@@ -42,7 +43,7 @@ func NewDiagnosticErrorListener(exactOnly bool) *DiagnosticErrorListener {
return n
}
func (d *DiagnosticErrorListener) ReportAmbiguity(recognizer Parser, dfa *DFA, startIndex, stopIndex int, exact bool, ambigAlts *BitSet, configs ATNConfigSet) {
func (d *DiagnosticErrorListener) ReportAmbiguity(recognizer Parser, dfa *DFA, startIndex, stopIndex int, exact bool, ambigAlts *BitSet, configs *ATNConfigSet) {
if d.exactOnly && !exact {
return
}
@@ -55,7 +56,7 @@ func (d *DiagnosticErrorListener) ReportAmbiguity(recognizer Parser, dfa *DFA, s
recognizer.NotifyErrorListeners(msg, nil, nil)
}
func (d *DiagnosticErrorListener) ReportAttemptingFullContext(recognizer Parser, dfa *DFA, startIndex, stopIndex int, conflictingAlts *BitSet, configs ATNConfigSet) {
func (d *DiagnosticErrorListener) ReportAttemptingFullContext(recognizer Parser, dfa *DFA, startIndex, stopIndex int, _ *BitSet, _ *ATNConfigSet) {
msg := "reportAttemptingFullContext d=" +
d.getDecisionDescription(recognizer, dfa) +
@@ -64,7 +65,7 @@ func (d *DiagnosticErrorListener) ReportAttemptingFullContext(recognizer Parser,
recognizer.NotifyErrorListeners(msg, nil, nil)
}
func (d *DiagnosticErrorListener) ReportContextSensitivity(recognizer Parser, dfa *DFA, startIndex, stopIndex, prediction int, configs ATNConfigSet) {
func (d *DiagnosticErrorListener) ReportContextSensitivity(recognizer Parser, dfa *DFA, startIndex, stopIndex, _ int, _ *ATNConfigSet) {
msg := "reportContextSensitivity d=" +
d.getDecisionDescription(recognizer, dfa) +
", input='" +
@@ -96,12 +97,12 @@ func (d *DiagnosticErrorListener) getDecisionDescription(recognizer Parser, dfa
// @param configs The conflicting or ambiguous configuration set.
// @return Returns {@code ReportedAlts} if it is not {@code nil}, otherwise
// returns the set of alternatives represented in {@code configs}.
func (d *DiagnosticErrorListener) getConflictingAlts(ReportedAlts *BitSet, set ATNConfigSet) *BitSet {
func (d *DiagnosticErrorListener) getConflictingAlts(ReportedAlts *BitSet, set *ATNConfigSet) *BitSet {
if ReportedAlts != nil {
return ReportedAlts
}
result := NewBitSet()
for _, c := range set.GetItems() {
for _, c := range set.configs {
result.add(c.GetAlt())
}

View File

@@ -16,28 +16,29 @@ import (
type ErrorListener interface {
SyntaxError(recognizer Recognizer, offendingSymbol interface{}, line, column int, msg string, e RecognitionException)
ReportAmbiguity(recognizer Parser, dfa *DFA, startIndex, stopIndex int, exact bool, ambigAlts *BitSet, configs ATNConfigSet)
ReportAttemptingFullContext(recognizer Parser, dfa *DFA, startIndex, stopIndex int, conflictingAlts *BitSet, configs ATNConfigSet)
ReportContextSensitivity(recognizer Parser, dfa *DFA, startIndex, stopIndex, prediction int, configs ATNConfigSet)
ReportAmbiguity(recognizer Parser, dfa *DFA, startIndex, stopIndex int, exact bool, ambigAlts *BitSet, configs *ATNConfigSet)
ReportAttemptingFullContext(recognizer Parser, dfa *DFA, startIndex, stopIndex int, conflictingAlts *BitSet, configs *ATNConfigSet)
ReportContextSensitivity(recognizer Parser, dfa *DFA, startIndex, stopIndex, prediction int, configs *ATNConfigSet)
}
type DefaultErrorListener struct {
}
//goland:noinspection GoUnusedExportedFunction
func NewDefaultErrorListener() *DefaultErrorListener {
return new(DefaultErrorListener)
}
func (d *DefaultErrorListener) SyntaxError(recognizer Recognizer, offendingSymbol interface{}, line, column int, msg string, e RecognitionException) {
func (d *DefaultErrorListener) SyntaxError(_ Recognizer, _ interface{}, _, _ int, _ string, _ RecognitionException) {
}
func (d *DefaultErrorListener) ReportAmbiguity(recognizer Parser, dfa *DFA, startIndex, stopIndex int, exact bool, ambigAlts *BitSet, configs ATNConfigSet) {
func (d *DefaultErrorListener) ReportAmbiguity(_ Parser, _ *DFA, _, _ int, _ bool, _ *BitSet, _ *ATNConfigSet) {
}
func (d *DefaultErrorListener) ReportAttemptingFullContext(recognizer Parser, dfa *DFA, startIndex, stopIndex int, conflictingAlts *BitSet, configs ATNConfigSet) {
func (d *DefaultErrorListener) ReportAttemptingFullContext(_ Parser, _ *DFA, _, _ int, _ *BitSet, _ *ATNConfigSet) {
}
func (d *DefaultErrorListener) ReportContextSensitivity(recognizer Parser, dfa *DFA, startIndex, stopIndex, prediction int, configs ATNConfigSet) {
func (d *DefaultErrorListener) ReportContextSensitivity(_ Parser, _ *DFA, _, _, _ int, _ *ATNConfigSet) {
}
type ConsoleErrorListener struct {
@@ -48,21 +49,16 @@ func NewConsoleErrorListener() *ConsoleErrorListener {
return new(ConsoleErrorListener)
}
// Provides a default instance of {@link ConsoleErrorListener}.
// ConsoleErrorListenerINSTANCE provides a default instance of {@link ConsoleErrorListener}.
var ConsoleErrorListenerINSTANCE = NewConsoleErrorListener()
// {@inheritDoc}
// SyntaxError prints messages to System.err containing the
// values of line, charPositionInLine, and msg using
// the following format:
//
// <p>
// This implementation prints messages to {@link System//err} containing the
// values of {@code line}, {@code charPositionInLine}, and {@code msg} using
// the following format.</p>
//
// <pre>
// line <em>line</em>:<em>charPositionInLine</em> <em>msg</em>
// </pre>
func (c *ConsoleErrorListener) SyntaxError(recognizer Recognizer, offendingSymbol interface{}, line, column int, msg string, e RecognitionException) {
fmt.Fprintln(os.Stderr, "line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg)
// line <line>:<charPositionInLine> <msg>
func (c *ConsoleErrorListener) SyntaxError(_ Recognizer, _ interface{}, line, column int, msg string, _ RecognitionException) {
_, _ = fmt.Fprintln(os.Stderr, "line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg)
}
type ProxyErrorListener struct {
@@ -85,19 +81,19 @@ func (p *ProxyErrorListener) SyntaxError(recognizer Recognizer, offendingSymbol
}
}
func (p *ProxyErrorListener) ReportAmbiguity(recognizer Parser, dfa *DFA, startIndex, stopIndex int, exact bool, ambigAlts *BitSet, configs ATNConfigSet) {
func (p *ProxyErrorListener) ReportAmbiguity(recognizer Parser, dfa *DFA, startIndex, stopIndex int, exact bool, ambigAlts *BitSet, configs *ATNConfigSet) {
for _, d := range p.delegates {
d.ReportAmbiguity(recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs)
}
}
func (p *ProxyErrorListener) ReportAttemptingFullContext(recognizer Parser, dfa *DFA, startIndex, stopIndex int, conflictingAlts *BitSet, configs ATNConfigSet) {
func (p *ProxyErrorListener) ReportAttemptingFullContext(recognizer Parser, dfa *DFA, startIndex, stopIndex int, conflictingAlts *BitSet, configs *ATNConfigSet) {
for _, d := range p.delegates {
d.ReportAttemptingFullContext(recognizer, dfa, startIndex, stopIndex, conflictingAlts, configs)
}
}
func (p *ProxyErrorListener) ReportContextSensitivity(recognizer Parser, dfa *DFA, startIndex, stopIndex, prediction int, configs ATNConfigSet) {
func (p *ProxyErrorListener) ReportContextSensitivity(recognizer Parser, dfa *DFA, startIndex, stopIndex, prediction int, configs *ATNConfigSet) {
for _, d := range p.delegates {
d.ReportContextSensitivity(recognizer, dfa, startIndex, stopIndex, prediction, configs)
}

View File

@@ -21,8 +21,8 @@ type ErrorStrategy interface {
ReportMatch(Parser)
}
// This is the default implementation of {@link ANTLRErrorStrategy} used for
// error Reporting and recovery in ANTLR parsers.
// DefaultErrorStrategy is the default implementation of ANTLRErrorStrategy used for
// error reporting and recovery in ANTLR parsers.
type DefaultErrorStrategy struct {
errorRecoveryMode bool
lastErrorIndex int
@@ -46,7 +46,7 @@ func NewDefaultErrorStrategy() *DefaultErrorStrategy {
// The index into the input stream where the last error occurred.
// This is used to prevent infinite loops where an error is found
// but no token is consumed during recovery...another error is found,
// ad nauseum. This is a failsafe mechanism to guarantee that at least
// ad nauseam. This is a failsafe mechanism to guarantee that at least
// one token/tree node is consumed for two errors.
//
d.lastErrorIndex = -1
@@ -62,50 +62,37 @@ func (d *DefaultErrorStrategy) reset(recognizer Parser) {
// This method is called to enter error recovery mode when a recognition
// exception is Reported.
//
// @param recognizer the parser instance
func (d *DefaultErrorStrategy) beginErrorCondition(recognizer Parser) {
func (d *DefaultErrorStrategy) beginErrorCondition(_ Parser) {
d.errorRecoveryMode = true
}
func (d *DefaultErrorStrategy) InErrorRecoveryMode(recognizer Parser) bool {
func (d *DefaultErrorStrategy) InErrorRecoveryMode(_ Parser) bool {
return d.errorRecoveryMode
}
// This method is called to leave error recovery mode after recovering from
// a recognition exception.
//
// @param recognizer
func (d *DefaultErrorStrategy) endErrorCondition(recognizer Parser) {
func (d *DefaultErrorStrategy) endErrorCondition(_ Parser) {
d.errorRecoveryMode = false
d.lastErrorStates = nil
d.lastErrorIndex = -1
}
// {@inheritDoc}
//
// <p>The default implementation simply calls {@link //endErrorCondition}.</p>
// ReportMatch is the default implementation of error matching and simply calls endErrorCondition.
func (d *DefaultErrorStrategy) ReportMatch(recognizer Parser) {
d.endErrorCondition(recognizer)
}
// {@inheritDoc}
// ReportError is the default implementation of error reporting.
// It returns immediately if the handler is already
// in error recovery mode. Otherwise, it calls [beginErrorCondition]
// and dispatches the Reporting task based on the runtime type of e
// according to the following table.
//
// <p>The default implementation returns immediately if the handler is already
// in error recovery mode. Otherwise, it calls {@link //beginErrorCondition}
// and dispatches the Reporting task based on the runtime type of {@code e}
// according to the following table.</p>
//
// <ul>
// <li>{@link NoViableAltException}: Dispatches the call to
// {@link //ReportNoViableAlternative}</li>
// <li>{@link InputMisMatchException}: Dispatches the call to
// {@link //ReportInputMisMatch}</li>
// <li>{@link FailedPredicateException}: Dispatches the call to
// {@link //ReportFailedPredicate}</li>
// <li>All other types: calls {@link Parser//NotifyErrorListeners} to Report
// the exception</li>
// </ul>
// [NoViableAltException] : Dispatches the call to [ReportNoViableAlternative]
// [InputMisMatchException] : Dispatches the call to [ReportInputMisMatch]
// [FailedPredicateException] : Dispatches the call to [ReportFailedPredicate]
// All other types : Calls [NotifyErrorListeners] to Report the exception
func (d *DefaultErrorStrategy) ReportError(recognizer Parser, e RecognitionException) {
// if we've already Reported an error and have not Matched a token
// yet successfully, don't Report any errors.
@@ -128,12 +115,10 @@ func (d *DefaultErrorStrategy) ReportError(recognizer Parser, e RecognitionExcep
}
}
// {@inheritDoc}
//
// <p>The default implementation reSynchronizes the parser by consuming tokens
// until we find one in the reSynchronization set--loosely the set of tokens
// that can follow the current rule.</p>
func (d *DefaultErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
// Recover is the default recovery implementation.
// It reSynchronizes the parser by consuming tokens until we find one in the reSynchronization set -
// loosely the set of tokens that can follow the current rule.
func (d *DefaultErrorStrategy) Recover(recognizer Parser, _ RecognitionException) {
if d.lastErrorIndex == recognizer.GetInputStream().Index() &&
d.lastErrorStates != nil && d.lastErrorStates.contains(recognizer.GetState()) {
@@ -148,54 +133,58 @@ func (d *DefaultErrorStrategy) Recover(recognizer Parser, e RecognitionException
d.lastErrorStates = NewIntervalSet()
}
d.lastErrorStates.addOne(recognizer.GetState())
followSet := d.getErrorRecoverySet(recognizer)
followSet := d.GetErrorRecoverySet(recognizer)
d.consumeUntil(recognizer, followSet)
}
// The default implementation of {@link ANTLRErrorStrategy//Sync} makes sure
// that the current lookahead symbol is consistent with what were expecting
// at d point in the ATN. You can call d anytime but ANTLR only
// generates code to check before subrules/loops and each iteration.
// Sync is the default implementation of error strategy synchronization.
//
// <p>Implements Jim Idle's magic Sync mechanism in closures and optional
// subrules. E.g.,</p>
// This Sync makes sure that the current lookahead symbol is consistent with what were expecting
// at this point in the [ATN]. You can call this anytime but ANTLR only
// generates code to check before sub-rules/loops and each iteration.
//
// <pre>
// a : Sync ( stuff Sync )*
// Sync : {consume to what can follow Sync}
// </pre>
// Implements [Jim Idle]'s magic Sync mechanism in closures and optional
// sub-rules. E.g.:
//
// At the start of a sub rule upon error, {@link //Sync} performs single
// a : Sync ( stuff Sync )*
// Sync : {consume to what can follow Sync}
//
// At the start of a sub-rule upon error, Sync performs single
// token deletion, if possible. If it can't do that, it bails on the current
// rule and uses the default error recovery, which consumes until the
// reSynchronization set of the current rule.
//
// <p>If the sub rule is optional ({@code (...)?}, {@code (...)*}, or block
// with an empty alternative), then the expected set includes what follows
// the subrule.</p>
// If the sub-rule is optional
//
// <p>During loop iteration, it consumes until it sees a token that can start a
// sub rule or what follows loop. Yes, that is pretty aggressive. We opt to
// stay in the loop as long as possible.</p>
// ({@code (...)?}, {@code (...)*},
//
// <p><strong>ORIGINS</strong></p>
// or a block with an empty alternative), then the expected set includes what follows
// the sub-rule.
//
// <p>Previous versions of ANTLR did a poor job of their recovery within loops.
// During loop iteration, it consumes until it sees a token that can start a
// sub-rule or what follows loop. Yes, that is pretty aggressive. We opt to
// stay in the loop as long as possible.
//
// # Origins
//
// Previous versions of ANTLR did a poor job of their recovery within loops.
// A single mismatch token or missing token would force the parser to bail
// out of the entire rules surrounding the loop. So, for rule</p>
// out of the entire rules surrounding the loop. So, for rule:
//
// <pre>
// classfunc : 'class' ID '{' member* '}'
// </pre>
// classfunc : 'class' ID '{' member* '}'
//
// input with an extra token between members would force the parser to
// consume until it found the next class definition rather than the next
// member definition of the current class.
//
// <p>This functionality cost a little bit of effort because the parser has to
// compare token set at the start of the loop and at each iteration. If for
// some reason speed is suffering for you, you can turn off d
// functionality by simply overriding d method as a blank { }.</p>
// This functionality cost a bit of effort because the parser has to
// compare the token set at the start of the loop and at each iteration. If for
// some reason speed is suffering for you, you can turn off this
// functionality by simply overriding this method as empty:
//
// { }
//
// [Jim Idle]: https://github.com/jimidle
func (d *DefaultErrorStrategy) Sync(recognizer Parser) {
// If already recovering, don't try to Sync
if d.InErrorRecoveryMode(recognizer) {
@@ -217,25 +206,21 @@ func (d *DefaultErrorStrategy) Sync(recognizer Parser) {
if d.SingleTokenDeletion(recognizer) != nil {
return
}
panic(NewInputMisMatchException(recognizer))
recognizer.SetError(NewInputMisMatchException(recognizer))
case ATNStatePlusLoopBack, ATNStateStarLoopBack:
d.ReportUnwantedToken(recognizer)
expecting := NewIntervalSet()
expecting.addSet(recognizer.GetExpectedTokens())
whatFollowsLoopIterationOrRule := expecting.addSet(d.getErrorRecoverySet(recognizer))
whatFollowsLoopIterationOrRule := expecting.addSet(d.GetErrorRecoverySet(recognizer))
d.consumeUntil(recognizer, whatFollowsLoopIterationOrRule)
default:
// do nothing if we can't identify the exact kind of ATN state
}
}
// This is called by {@link //ReportError} when the exception is a
// {@link NoViableAltException}.
// ReportNoViableAlternative is called by [ReportError] when the exception is a [NoViableAltException].
//
// @see //ReportError
//
// @param recognizer the parser instance
// @param e the recognition exception
// See also [ReportError]
func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *NoViableAltException) {
tokens := recognizer.GetTokenStream()
var input string
@@ -252,48 +237,38 @@ func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *N
recognizer.NotifyErrorListeners(msg, e.offendingToken, e)
}
// This is called by {@link //ReportError} when the exception is an
// {@link InputMisMatchException}.
// ReportInputMisMatch is called by [ReportError] when the exception is an [InputMisMatchException]
//
// @see //ReportError
//
// @param recognizer the parser instance
// @param e the recognition exception
func (this *DefaultErrorStrategy) ReportInputMisMatch(recognizer Parser, e *InputMisMatchException) {
msg := "mismatched input " + this.GetTokenErrorDisplay(e.offendingToken) +
// See also: [ReportError]
func (d *DefaultErrorStrategy) ReportInputMisMatch(recognizer Parser, e *InputMisMatchException) {
msg := "mismatched input " + d.GetTokenErrorDisplay(e.offendingToken) +
" expecting " + e.getExpectedTokens().StringVerbose(recognizer.GetLiteralNames(), recognizer.GetSymbolicNames(), false)
recognizer.NotifyErrorListeners(msg, e.offendingToken, e)
}
// This is called by {@link //ReportError} when the exception is a
// {@link FailedPredicateException}.
// ReportFailedPredicate is called by [ReportError] when the exception is a [FailedPredicateException].
//
// @see //ReportError
//
// @param recognizer the parser instance
// @param e the recognition exception
// See also: [ReportError]
func (d *DefaultErrorStrategy) ReportFailedPredicate(recognizer Parser, e *FailedPredicateException) {
ruleName := recognizer.GetRuleNames()[recognizer.GetParserRuleContext().GetRuleIndex()]
msg := "rule " + ruleName + " " + e.message
recognizer.NotifyErrorListeners(msg, e.offendingToken, e)
}
// This method is called to Report a syntax error which requires the removal
// ReportUnwantedToken is called to report a syntax error that requires the removal
// of a token from the input stream. At the time d method is called, the
// erroneous symbol is current {@code LT(1)} symbol and has not yet been
// removed from the input stream. When d method returns,
// {@code recognizer} is in error recovery mode.
// erroneous symbol is the current LT(1) symbol and has not yet been
// removed from the input stream. When this method returns,
// recognizer is in error recovery mode.
//
// <p>This method is called when {@link //singleTokenDeletion} identifies
// This method is called when singleTokenDeletion identifies
// single-token deletion as a viable recovery strategy for a mismatched
// input error.</p>
// input error.
//
// <p>The default implementation simply returns if the handler is already in
// error recovery mode. Otherwise, it calls {@link //beginErrorCondition} to
// The default implementation simply returns if the handler is already in
// error recovery mode. Otherwise, it calls beginErrorCondition to
// enter error recovery mode, followed by calling
// {@link Parser//NotifyErrorListeners}.</p>
//
// @param recognizer the parser instance
// [NotifyErrorListeners]
func (d *DefaultErrorStrategy) ReportUnwantedToken(recognizer Parser) {
if d.InErrorRecoveryMode(recognizer) {
return
@@ -307,21 +282,18 @@ func (d *DefaultErrorStrategy) ReportUnwantedToken(recognizer Parser) {
recognizer.NotifyErrorListeners(msg, t, nil)
}
// This method is called to Report a syntax error which requires the
// insertion of a missing token into the input stream. At the time d
// method is called, the missing token has not yet been inserted. When d
// method returns, {@code recognizer} is in error recovery mode.
// ReportMissingToken is called to report a syntax error which requires the
// insertion of a missing token into the input stream. At the time this
// method is called, the missing token has not yet been inserted. When this
// method returns, recognizer is in error recovery mode.
//
// <p>This method is called when {@link //singleTokenInsertion} identifies
// This method is called when singleTokenInsertion identifies
// single-token insertion as a viable recovery strategy for a mismatched
// input error.</p>
// input error.
//
// <p>The default implementation simply returns if the handler is already in
// error recovery mode. Otherwise, it calls {@link //beginErrorCondition} to
// enter error recovery mode, followed by calling
// {@link Parser//NotifyErrorListeners}.</p>
//
// @param recognizer the parser instance
// The default implementation simply returns if the handler is already in
// error recovery mode. Otherwise, it calls beginErrorCondition to
// enter error recovery mode, followed by calling [NotifyErrorListeners]
func (d *DefaultErrorStrategy) ReportMissingToken(recognizer Parser) {
if d.InErrorRecoveryMode(recognizer) {
return
@@ -334,54 +306,48 @@ func (d *DefaultErrorStrategy) ReportMissingToken(recognizer Parser) {
recognizer.NotifyErrorListeners(msg, t, nil)
}
// <p>The default implementation attempts to recover from the mismatched input
// The RecoverInline default implementation attempts to recover from the mismatched input
// by using single token insertion and deletion as described below. If the
// recovery attempt fails, d method panics an
// {@link InputMisMatchException}.</p>
// recovery attempt fails, this method panics with [InputMisMatchException}.
// TODO: Not sure that panic() is the right thing to do here - JI
//
// <p><strong>EXTRA TOKEN</strong> (single token deletion)</p>
// # EXTRA TOKEN (single token deletion)
//
// <p>{@code LA(1)} is not what we are looking for. If {@code LA(2)} has the
// right token, however, then assume {@code LA(1)} is some extra spurious
// LA(1) is not what we are looking for. If LA(2) has the
// right token, however, then assume LA(1) is some extra spurious
// token and delete it. Then consume and return the next token (which was
// the {@code LA(2)} token) as the successful result of the Match operation.</p>
// the LA(2) token) as the successful result of the Match operation.
//
// <p>This recovery strategy is implemented by {@link
// //singleTokenDeletion}.</p>
// # This recovery strategy is implemented by singleTokenDeletion
//
// <p><strong>MISSING TOKEN</strong> (single token insertion)</p>
// # MISSING TOKEN (single token insertion)
//
// <p>If current token (at {@code LA(1)}) is consistent with what could come
// after the expected {@code LA(1)} token, then assume the token is missing
// and use the parser's {@link TokenFactory} to create it on the fly. The
// "insertion" is performed by returning the created token as the successful
// result of the Match operation.</p>
// If current token -at LA(1) - is consistent with what could come
// after the expected LA(1) token, then assume the token is missing
// and use the parser's [TokenFactory] to create it on the fly. The
// insertion is performed by returning the created token as the successful
// result of the Match operation.
//
// <p>This recovery strategy is implemented by {@link
// //singleTokenInsertion}.</p>
// This recovery strategy is implemented by [SingleTokenInsertion].
//
// <p><strong>EXAMPLE</strong></p>
// # Example
//
// <p>For example, Input {@code i=(3} is clearly missing the {@code ')'}. When
// the parser returns from the nested call to {@code expr}, it will have
// call chain:</p>
// For example, Input i=(3 is clearly missing the ')'. When
// the parser returns from the nested call to expr, it will have
// call the chain:
//
// <pre>
// stat &rarr expr &rarr atom
// </pre>
// stat → expr → atom
//
// and it will be trying to Match the {@code ')'} at d point in the
// and it will be trying to Match the ')' at this point in the
// derivation:
//
// <pre>
// =&gt ID '=' '(' INT ')' ('+' atom)* ”
// ^
// </pre>
// : ID '=' '(' INT ')' ('+' atom)* ';'
// ^
//
// The attempt to Match {@code ')'} will fail when it sees {@code ”} and
// call {@link //recoverInline}. To recover, it sees that {@code LA(1)==”}
// is in the set of tokens that can follow the {@code ')'} token reference
// in rule {@code atom}. It can assume that you forgot the {@code ')'}.
// The attempt to [Match] ')' will fail when it sees ';' and
// call [RecoverInline]. To recover, it sees that LA(1)==';'
// is in the set of tokens that can follow the ')' token reference
// in rule atom. It can assume that you forgot the ')'.
func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token {
// SINGLE TOKEN DELETION
MatchedSymbol := d.SingleTokenDeletion(recognizer)
@@ -396,24 +362,24 @@ func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token {
return d.GetMissingSymbol(recognizer)
}
// even that didn't work must panic the exception
panic(NewInputMisMatchException(recognizer))
recognizer.SetError(NewInputMisMatchException(recognizer))
return nil
}
// This method implements the single-token insertion inline error recovery
// strategy. It is called by {@link //recoverInline} if the single-token
// SingleTokenInsertion implements the single-token insertion inline error recovery
// strategy. It is called by [RecoverInline] if the single-token
// deletion strategy fails to recover from the mismatched input. If this
// method returns {@code true}, {@code recognizer} will be in error recovery
// mode.
//
// <p>This method determines whether or not single-token insertion is viable by
// checking if the {@code LA(1)} input symbol could be successfully Matched
// if it were instead the {@code LA(2)} symbol. If d method returns
// This method determines whether single-token insertion is viable by
// checking if the LA(1) input symbol could be successfully Matched
// if it were instead the LA(2) symbol. If this method returns
// {@code true}, the caller is responsible for creating and inserting a
// token with the correct type to produce d behavior.</p>
// token with the correct type to produce this behavior.</p>
//
// @param recognizer the parser instance
// @return {@code true} if single-token insertion is a viable recovery
// strategy for the current mismatched input, otherwise {@code false}
// This func returns true if single-token insertion is a viable recovery
// strategy for the current mismatched input.
func (d *DefaultErrorStrategy) SingleTokenInsertion(recognizer Parser) bool {
currentSymbolType := recognizer.GetTokenStream().LA(1)
// if current token is consistent with what could come after current
@@ -431,23 +397,21 @@ func (d *DefaultErrorStrategy) SingleTokenInsertion(recognizer Parser) bool {
return false
}
// This method implements the single-token deletion inline error recovery
// strategy. It is called by {@link //recoverInline} to attempt to recover
// SingleTokenDeletion implements the single-token deletion inline error recovery
// strategy. It is called by [RecoverInline] to attempt to recover
// from mismatched input. If this method returns nil, the parser and error
// handler state will not have changed. If this method returns non-nil,
// {@code recognizer} will <em>not</em> be in error recovery mode since the
// recognizer will not be in error recovery mode since the
// returned token was a successful Match.
//
// <p>If the single-token deletion is successful, d method calls
// {@link //ReportUnwantedToken} to Report the error, followed by
// {@link Parser//consume} to actually "delete" the extraneous token. Then,
// before returning {@link //ReportMatch} is called to signal a successful
// Match.</p>
// If the single-token deletion is successful, this method calls
// [ReportUnwantedToken] to Report the error, followed by
// [Consume] to actually delete the extraneous token. Then,
// before returning, [ReportMatch] is called to signal a successful
// Match.
//
// @param recognizer the parser instance
// @return the successfully Matched {@link Token} instance if single-token
// deletion successfully recovers from the mismatched input, otherwise
// {@code nil}
// The func returns the successfully Matched [Token] instance if single-token
// deletion successfully recovers from the mismatched input, otherwise nil.
func (d *DefaultErrorStrategy) SingleTokenDeletion(recognizer Parser) Token {
NextTokenType := recognizer.GetTokenStream().LA(2)
expecting := d.GetExpectedTokens(recognizer)
@@ -467,24 +431,28 @@ func (d *DefaultErrorStrategy) SingleTokenDeletion(recognizer Parser) Token {
return nil
}
// Conjure up a missing token during error recovery.
// GetMissingSymbol conjures up a missing token during error recovery.
//
// The recognizer attempts to recover from single missing
// symbols. But, actions might refer to that missing symbol.
// For example, x=ID {f($x)}. The action clearly assumes
// For example:
//
// x=ID {f($x)}.
//
// The action clearly assumes
// that there has been an identifier Matched previously and that
// $x points at that token. If that token is missing, but
// the next token in the stream is what we want we assume that
// d token is missing and we keep going. Because we
// this token is missing, and we keep going. Because we
// have to return some token to replace the missing token,
// we have to conjure one up. This method gives the user control
// over the tokens returned for missing tokens. Mostly,
// you will want to create something special for identifier
// tokens. For literals such as '{' and ',', the default
// action in the parser or tree parser works. It simply creates
// a CommonToken of the appropriate type. The text will be the token.
// If you change what tokens must be created by the lexer,
// override d method to create the appropriate tokens.
// a [CommonToken] of the appropriate type. The text will be the token name.
// If you need to change which tokens must be created by the lexer,
// override this method to create the appropriate tokens.
func (d *DefaultErrorStrategy) GetMissingSymbol(recognizer Parser) Token {
currentSymbol := recognizer.GetCurrentToken()
expecting := d.GetExpectedTokens(recognizer)
@@ -498,7 +466,7 @@ func (d *DefaultErrorStrategy) GetMissingSymbol(recognizer Parser) Token {
if expectedTokenType > 0 && expectedTokenType < len(ln) {
tokenText = "<missing " + recognizer.GetLiteralNames()[expectedTokenType] + ">"
} else {
tokenText = "<missing undefined>" // TODO matches the JS impl
tokenText = "<missing undefined>" // TODO: matches the JS impl
}
}
current := currentSymbol
@@ -516,13 +484,13 @@ func (d *DefaultErrorStrategy) GetExpectedTokens(recognizer Parser) *IntervalSet
return recognizer.GetExpectedTokens()
}
// How should a token be displayed in an error message? The default
// is to display just the text, but during development you might
// want to have a lot of information spit out. Override in that case
// to use t.String() (which, for CommonToken, dumps everything about
// GetTokenErrorDisplay determines how a token should be displayed in an error message.
// The default is to display just the text, but during development you might
// want to have a lot of information spit out. Override this func in that case
// to use t.String() (which, for [CommonToken], dumps everything about
// the token). This is better than forcing you to override a method in
// your token objects because you don't have to go modify your lexer
// so that it creates a NewJava type.
// so that it creates a new type.
func (d *DefaultErrorStrategy) GetTokenErrorDisplay(t Token) string {
if t == nil {
return "<no token>"
@@ -545,52 +513,57 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
return "'" + s + "'"
}
// Compute the error recovery set for the current rule. During
// GetErrorRecoverySet computes the error recovery set for the current rule. During
// rule invocation, the parser pushes the set of tokens that can
// follow that rule reference on the stack d amounts to
// follow that rule reference on the stack. This amounts to
// computing FIRST of what follows the rule reference in the
// enclosing rule. See LinearApproximator.FIRST().
//
// This local follow set only includes tokens
// from within the rule i.e., the FIRST computation done by
// ANTLR stops at the end of a rule.
//
// # EXAMPLE
// # Example
//
// When you find a "no viable alt exception", the input is not
// consistent with any of the alternatives for rule r. The best
// thing to do is to consume tokens until you see something that
// can legally follow a call to r//or* any rule that called r.
// can legally follow a call to r or any rule that called r.
// You don't want the exact set of viable next tokens because the
// input might just be missing a token--you might consume the
// rest of the input looking for one of the missing tokens.
//
// Consider grammar:
// Consider the grammar:
//
// a : '[' b ']'
// | '(' b ')'
// a : '[' b ']'
// | '(' b ')'
// ;
//
// b : c '^' INT
// c : ID
// | INT
// b : c '^' INT
// ;
//
// c : ID
// | INT
// ;
//
// At each rule invocation, the set of tokens that could follow
// that rule is pushed on a stack. Here are the various
// context-sensitive follow sets:
//
// FOLLOW(b1_in_a) = FIRST(']') = ']'
// FOLLOW(b2_in_a) = FIRST(')') = ')'
// FOLLOW(c_in_b) = FIRST('^') = '^'
// FOLLOW(b1_in_a) = FIRST(']') = ']'
// FOLLOW(b2_in_a) = FIRST(')') = ')'
// FOLLOW(c_in_b) = FIRST('^') = '^'
//
// Upon erroneous input "[]", the call chain is
// Upon erroneous input [], the call chain is
//
// a -> b -> c
// a b c
//
// and, hence, the follow context stack is:
//
// depth follow set start of rule execution
// 0 <EOF> a (from main())
// 1 ']' b
// 2 '^' c
// Depth Follow set Start of rule execution
// 0 <EOF> a (from main())
// 1 ']' b
// 2 '^' c
//
// Notice that ')' is not included, because b would have to have
// been called from a different context in rule a for ')' to be
@@ -598,11 +571,14 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
//
// For error recovery, we cannot consider FOLLOW(c)
// (context-sensitive or otherwise). We need the combined set of
// all context-sensitive FOLLOW sets--the set of all tokens that
// all context-sensitive FOLLOW sets - the set of all tokens that
// could follow any reference in the call chain. We need to
// reSync to one of those tokens. Note that FOLLOW(c)='^' and if
// we reSync'd to that token, we'd consume until EOF. We need to
// Sync to context-sensitive FOLLOWs for a, b, and c: {']','^'}.
// Sync to context-sensitive FOLLOWs for a, b, and c:
//
// {']','^'}
//
// In this case, for input "[]", LA(1) is ']' and in the set, so we would
// not consume anything. After printing an error, rule c would
// return normally. Rule b would not find the required '^' though.
@@ -620,22 +596,19 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
//
// ANTLR's error recovery mechanism is based upon original ideas:
//
// "Algorithms + Data Structures = Programs" by Niklaus Wirth
// [Algorithms + Data Structures = Programs] by Niklaus Wirth and
// [A note on error recovery in recursive descent parsers].
//
// and
// Later, Josef Grosch had some good ideas in [Efficient and Comfortable Error Recovery in Recursive Descent
// Parsers]
//
// "A note on error recovery in recursive descent parsers":
// http://portal.acm.org/citation.cfm?id=947902.947905
// Like Grosch I implement context-sensitive FOLLOW sets that are combined at run-time upon error to avoid overhead
// during parsing. Later, the runtime Sync was improved for loops/sub-rules see [Sync] docs
//
// Later, Josef Grosch had some good ideas:
//
// "Efficient and Comfortable Error Recovery in Recursive Descent
// Parsers":
// ftp://www.cocolab.com/products/cocktail/doca4.ps/ell.ps.zip
//
// Like Grosch I implement context-sensitive FOLLOW sets that are combined
// at run-time upon error to avoid overhead during parsing.
func (d *DefaultErrorStrategy) getErrorRecoverySet(recognizer Parser) *IntervalSet {
// [A note on error recovery in recursive descent parsers]: http://portal.acm.org/citation.cfm?id=947902.947905
// [Algorithms + Data Structures = Programs]: https://t.ly/5QzgE
// [Efficient and Comfortable Error Recovery in Recursive Descent Parsers]: ftp://www.cocolab.com/products/cocktail/doca4.ps/ell.ps.zip
func (d *DefaultErrorStrategy) GetErrorRecoverySet(recognizer Parser) *IntervalSet {
atn := recognizer.GetInterpreter().atn
ctx := recognizer.GetParserRuleContext()
recoverSet := NewIntervalSet()
@@ -660,40 +633,36 @@ func (d *DefaultErrorStrategy) consumeUntil(recognizer Parser, set *IntervalSet)
}
}
//
// This implementation of {@link ANTLRErrorStrategy} responds to syntax errors
// The BailErrorStrategy implementation of ANTLRErrorStrategy responds to syntax errors
// by immediately canceling the parse operation with a
// {@link ParseCancellationException}. The implementation ensures that the
// {@link ParserRuleContext//exception} field is set for all parse tree nodes
// [ParseCancellationException]. The implementation ensures that the
// [ParserRuleContext//exception] field is set for all parse tree nodes
// that were not completed prior to encountering the error.
//
// <p>
// This error strategy is useful in the following scenarios.</p>
// This error strategy is useful in the following scenarios.
//
// <ul>
// <li><strong>Two-stage parsing:</strong> This error strategy allows the first
// stage of two-stage parsing to immediately terminate if an error is
// encountered, and immediately fall back to the second stage. In addition to
// avoiding wasted work by attempting to recover from errors here, the empty
// implementation of {@link BailErrorStrategy//Sync} improves the performance of
// the first stage.</li>
// <li><strong>Silent validation:</strong> When syntax errors are not being
// Reported or logged, and the parse result is simply ignored if errors occur,
// the {@link BailErrorStrategy} avoids wasting work on recovering from errors
// when the result will be ignored either way.</li>
// </ul>
// - Two-stage parsing: This error strategy allows the first
// stage of two-stage parsing to immediately terminate if an error is
// encountered, and immediately fall back to the second stage. In addition to
// avoiding wasted work by attempting to recover from errors here, the empty
// implementation of [BailErrorStrategy.Sync] improves the performance of
// the first stage.
//
// <p>
// {@code myparser.setErrorHandler(NewBailErrorStrategy())}</p>
// - Silent validation: When syntax errors are not being
// Reported or logged, and the parse result is simply ignored if errors occur,
// the [BailErrorStrategy] avoids wasting work on recovering from errors
// when the result will be ignored either way.
//
// @see Parser//setErrorHandler(ANTLRErrorStrategy)
// myparser.SetErrorHandler(NewBailErrorStrategy())
//
// See also: [Parser.SetErrorHandler(ANTLRErrorStrategy)]
type BailErrorStrategy struct {
*DefaultErrorStrategy
}
var _ ErrorStrategy = &BailErrorStrategy{}
//goland:noinspection GoUnusedExportedFunction
func NewBailErrorStrategy() *BailErrorStrategy {
b := new(BailErrorStrategy)
@@ -703,10 +672,10 @@ func NewBailErrorStrategy() *BailErrorStrategy {
return b
}
// Instead of recovering from exception {@code e}, re-panic it wrapped
// in a {@link ParseCancellationException} so it is not caught by the
// rule func catches. Use {@link Exception//getCause()} to get the
// original {@link RecognitionException}.
// Recover Instead of recovering from exception e, re-panic it wrapped
// in a [ParseCancellationException] so it is not caught by the
// rule func catches. Use Exception.GetCause() to get the
// original [RecognitionException].
func (b *BailErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
context := recognizer.GetParserRuleContext()
for context != nil {
@@ -717,10 +686,10 @@ func (b *BailErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
context = nil
}
}
panic(NewParseCancellationException()) // TODO we don't emit e properly
recognizer.SetError(NewParseCancellationException()) // TODO: we don't emit e properly
}
// Make sure we don't attempt to recover inline if the parser
// RecoverInline makes sure we don't attempt to recover inline if the parser
// successfully recovers, it won't panic an exception.
func (b *BailErrorStrategy) RecoverInline(recognizer Parser) Token {
b.Recover(recognizer, NewInputMisMatchException(recognizer))
@@ -728,7 +697,6 @@ func (b *BailErrorStrategy) RecoverInline(recognizer Parser) Token {
return nil
}
// Make sure we don't attempt to recover from problems in subrules.//
func (b *BailErrorStrategy) Sync(recognizer Parser) {
// pass
// Sync makes sure we don't attempt to recover from problems in sub-rules.
func (b *BailErrorStrategy) Sync(_ Parser) {
}

View File

@@ -35,7 +35,7 @@ func NewBaseRecognitionException(message string, recognizer Recognizer, input In
// } else {
// stack := NewError().stack
// }
// TODO may be able to use - "runtime" func Stack(buf []byte, all bool) int
// TODO: may be able to use - "runtime" func Stack(buf []byte, all bool) int
t := new(BaseRecognitionException)
@@ -43,15 +43,17 @@ func NewBaseRecognitionException(message string, recognizer Recognizer, input In
t.recognizer = recognizer
t.input = input
t.ctx = ctx
// The current {@link Token} when an error occurred. Since not all streams
// The current Token when an error occurred. Since not all streams
// support accessing symbols by index, we have to track the {@link Token}
// instance itself.
//
t.offendingToken = nil
// Get the ATN state number the parser was in at the time the error
// occurred. For {@link NoViableAltException} and
// {@link LexerNoViableAltException} exceptions, this is the
// {@link DecisionState} number. For others, it is the state whose outgoing
// edge we couldn't Match.
// occurred. For NoViableAltException and LexerNoViableAltException exceptions, this is the
// DecisionState number. For others, it is the state whose outgoing edge we couldn't Match.
//
t.offendingState = -1
if t.recognizer != nil {
t.offendingState = t.recognizer.GetState()
@@ -74,15 +76,15 @@ func (b *BaseRecognitionException) GetInputStream() IntStream {
// <p>If the state number is not known, b method returns -1.</p>
// Gets the set of input symbols which could potentially follow the
// previously Matched symbol at the time b exception was panicn.
// getExpectedTokens gets the set of input symbols which could potentially follow the
// previously Matched symbol at the time this exception was raised.
//
// <p>If the set of expected tokens is not known and could not be computed,
// b method returns {@code nil}.</p>
// If the set of expected tokens is not known and could not be computed,
// this method returns nil.
//
// @return The set of token types that could potentially follow the current
// state in the ATN, or {@code nil} if the information is not available.
// /
// The func returns the set of token types that could potentially follow the current
// state in the {ATN}, or nil if the information is not available.
func (b *BaseRecognitionException) getExpectedTokens() *IntervalSet {
if b.recognizer != nil {
return b.recognizer.GetATN().getExpectedTokens(b.offendingState, b.ctx)
@@ -99,10 +101,10 @@ type LexerNoViableAltException struct {
*BaseRecognitionException
startIndex int
deadEndConfigs ATNConfigSet
deadEndConfigs *ATNConfigSet
}
func NewLexerNoViableAltException(lexer Lexer, input CharStream, startIndex int, deadEndConfigs ATNConfigSet) *LexerNoViableAltException {
func NewLexerNoViableAltException(lexer Lexer, input CharStream, startIndex int, deadEndConfigs *ATNConfigSet) *LexerNoViableAltException {
l := new(LexerNoViableAltException)
@@ -128,14 +130,16 @@ type NoViableAltException struct {
startToken Token
offendingToken Token
ctx ParserRuleContext
deadEndConfigs ATNConfigSet
deadEndConfigs *ATNConfigSet
}
// Indicates that the parser could not decide which of two or more paths
// NewNoViableAltException creates an exception indicating that the parser could not decide which of two or more paths
// to take based upon the remaining input. It tracks the starting token
// of the offending input and also knows where the parser was
// in the various paths when the error. Reported by ReportNoViableAlternative()
func NewNoViableAltException(recognizer Parser, input TokenStream, startToken Token, offendingToken Token, deadEndConfigs ATNConfigSet, ctx ParserRuleContext) *NoViableAltException {
// in the various paths when the error.
//
// Reported by [ReportNoViableAlternative]
func NewNoViableAltException(recognizer Parser, input TokenStream, startToken Token, offendingToken Token, deadEndConfigs *ATNConfigSet, ctx ParserRuleContext) *NoViableAltException {
if ctx == nil {
ctx = recognizer.GetParserRuleContext()
@@ -157,12 +161,14 @@ func NewNoViableAltException(recognizer Parser, input TokenStream, startToken To
n.BaseRecognitionException = NewBaseRecognitionException("", recognizer, input, ctx)
// Which configurations did we try at input.Index() that couldn't Match
// input.LT(1)?//
// input.LT(1)
n.deadEndConfigs = deadEndConfigs
// The token object at the start index the input stream might
// not be buffering tokens so get a reference to it. (At the
// time the error occurred, of course the stream needs to keep a
// buffer all of the tokens but later we might not have access to those.)
// not be buffering tokens so get a reference to it.
//
// At the time the error occurred, of course the stream needs to keep a
// buffer of all the tokens, but later we might not have access to those.
n.startToken = startToken
n.offendingToken = offendingToken
@@ -173,7 +179,7 @@ type InputMisMatchException struct {
*BaseRecognitionException
}
// This signifies any kind of mismatched input exceptions such as
// NewInputMisMatchException creates an exception that signifies any kind of mismatched input exceptions such as
// when the current input does not Match the expected token.
func NewInputMisMatchException(recognizer Parser) *InputMisMatchException {
@@ -186,11 +192,10 @@ func NewInputMisMatchException(recognizer Parser) *InputMisMatchException {
}
// A semantic predicate failed during validation. Validation of predicates
// FailedPredicateException indicates that a semantic predicate failed during validation. Validation of predicates
// occurs when normally parsing the alternative just like Matching a token.
// Disambiguating predicate evaluation occurs when we test a predicate during
// prediction.
type FailedPredicateException struct {
*BaseRecognitionException
@@ -199,6 +204,7 @@ type FailedPredicateException struct {
predicate string
}
//goland:noinspection GoUnusedExportedFunction
func NewFailedPredicateException(recognizer Parser, predicate string, message string) *FailedPredicateException {
f := new(FailedPredicateException)
@@ -231,6 +237,21 @@ func (f *FailedPredicateException) formatMessage(predicate, message string) stri
type ParseCancellationException struct {
}
func (p ParseCancellationException) GetOffendingToken() Token {
//TODO implement me
panic("implement me")
}
func (p ParseCancellationException) GetMessage() string {
//TODO implement me
panic("implement me")
}
func (p ParseCancellationException) GetInputStream() IntStream {
//TODO implement me
panic("implement me")
}
func NewParseCancellationException() *ParseCancellationException {
// Error.call(this)
// Error.captureStackTrace(this, ParseCancellationException)

View File

@@ -5,8 +5,7 @@
package antlr
import (
"bytes"
"io"
"bufio"
"os"
)
@@ -14,34 +13,53 @@ import (
// when you construct the object.
type FileStream struct {
*InputStream
InputStream
filename string
}
//goland:noinspection GoUnusedExportedFunction
func NewFileStream(fileName string) (*FileStream, error) {
buf := bytes.NewBuffer(nil)
f, err := os.Open(fileName)
if err != nil {
return nil, err
}
defer f.Close()
_, err = io.Copy(buf, f)
defer func(f *os.File) {
errF := f.Close()
if errF != nil {
}
}(f)
reader := bufio.NewReader(f)
fInfo, err := f.Stat()
if err != nil {
return nil, err
}
fs := new(FileStream)
fs := &FileStream{
InputStream: InputStream{
index: 0,
name: fileName,
},
filename: fileName,
}
fs.filename = fileName
s := string(buf.Bytes())
fs.InputStream = NewInputStream(s)
// Pre-build the buffer and read runes efficiently
//
fs.data = make([]rune, 0, fInfo.Size())
for {
r, _, err := reader.ReadRune()
if err != nil {
break
}
fs.data = append(fs.data, r)
}
fs.size = len(fs.data) // Size in runes
// All done.
//
return fs, nil
}
func (f *FileStream) GetSourceName() string {

157
vendor/github.com/antlr4-go/antlr/v4/input_stream.go generated vendored Normal file
View File

@@ -0,0 +1,157 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"bufio"
"io"
)
type InputStream struct {
name string
index int
data []rune
size int
}
// NewIoStream creates a new input stream from the given io.Reader reader.
// Note that the reader is read completely into memory and so it must actually
// have a stopping point - you cannot pass in a reader on an open-ended source such
// as a socket for instance.
func NewIoStream(reader io.Reader) *InputStream {
rReader := bufio.NewReader(reader)
is := &InputStream{
name: "<empty>",
index: 0,
}
// Pre-build the buffer and read runes reasonably efficiently given that
// we don't exactly know how big the input is.
//
is.data = make([]rune, 0, 512)
for {
r, _, err := rReader.ReadRune()
if err != nil {
break
}
is.data = append(is.data, r)
}
is.size = len(is.data) // number of runes
return is
}
// NewInputStream creates a new input stream from the given string
func NewInputStream(data string) *InputStream {
is := &InputStream{
name: "<empty>",
index: 0,
data: []rune(data), // This is actually the most efficient way
}
is.size = len(is.data) // number of runes, but we could also use len(data), which is efficient too
return is
}
func (is *InputStream) reset() {
is.index = 0
}
// Consume moves the input pointer to the next character in the input stream
func (is *InputStream) Consume() {
if is.index >= is.size {
// assert is.LA(1) == TokenEOF
panic("cannot consume EOF")
}
is.index++
}
// LA returns the character at the given offset from the start of the input stream
func (is *InputStream) LA(offset int) int {
if offset == 0 {
return 0 // nil
}
if offset < 0 {
offset++ // e.g., translate LA(-1) to use offset=0
}
pos := is.index + offset - 1
if pos < 0 || pos >= is.size { // invalid
return TokenEOF
}
return int(is.data[pos])
}
// LT returns the character at the given offset from the start of the input stream
func (is *InputStream) LT(offset int) int {
return is.LA(offset)
}
// Index returns the current offset in to the input stream
func (is *InputStream) Index() int {
return is.index
}
// Size returns the total number of characters in the input stream
func (is *InputStream) Size() int {
return is.size
}
// Mark does nothing here as we have entire buffer
func (is *InputStream) Mark() int {
return -1
}
// Release does nothing here as we have entire buffer
func (is *InputStream) Release(_ int) {
}
// Seek the input point to the provided index offset
func (is *InputStream) Seek(index int) {
if index <= is.index {
is.index = index // just jump don't update stream state (line,...)
return
}
// seek forward
is.index = intMin(index, is.size)
}
// GetText returns the text from the input stream from the start to the stop index
func (is *InputStream) GetText(start int, stop int) string {
if stop >= is.size {
stop = is.size - 1
}
if start >= is.size {
return ""
}
return string(is.data[start : stop+1])
}
// GetTextFromTokens returns the text from the input stream from the first character of the start token to the last
// character of the stop token
func (is *InputStream) GetTextFromTokens(start, stop Token) string {
if start != nil && stop != nil {
return is.GetTextFromInterval(NewInterval(start.GetTokenIndex(), stop.GetTokenIndex()))
}
return ""
}
func (is *InputStream) GetTextFromInterval(i Interval) string {
return is.GetText(i.Start, i.Stop)
}
func (*InputStream) GetSourceName() string {
return ""
}
// String returns the entire input stream as a string
func (is *InputStream) String() string {
return string(is.data)
}

View File

@@ -14,20 +14,21 @@ type Interval struct {
Stop int
}
/* stop is not included! */
func NewInterval(start, stop int) *Interval {
i := new(Interval)
i.Start = start
i.Stop = stop
return i
// NewInterval creates a new interval with the given start and stop values.
func NewInterval(start, stop int) Interval {
return Interval{
Start: start,
Stop: stop,
}
}
func (i *Interval) Contains(item int) bool {
// Contains returns true if the given item is contained within the interval.
func (i Interval) Contains(item int) bool {
return item >= i.Start && item < i.Stop
}
func (i *Interval) String() string {
// String generates a string representation of the interval.
func (i Interval) String() string {
if i.Start == i.Stop-1 {
return strconv.Itoa(i.Start)
}
@@ -35,15 +36,18 @@ func (i *Interval) String() string {
return strconv.Itoa(i.Start) + ".." + strconv.Itoa(i.Stop-1)
}
func (i *Interval) length() int {
// Length returns the length of the interval.
func (i Interval) Length() int {
return i.Stop - i.Start
}
// IntervalSet represents a collection of [Intervals], which may be read-only.
type IntervalSet struct {
intervals []*Interval
intervals []Interval
readOnly bool
}
// NewIntervalSet creates a new empty, writable, interval set.
func NewIntervalSet() *IntervalSet {
i := new(IntervalSet)
@@ -54,6 +58,20 @@ func NewIntervalSet() *IntervalSet {
return i
}
func (i *IntervalSet) Equals(other *IntervalSet) bool {
if len(i.intervals) != len(other.intervals) {
return false
}
for k, v := range i.intervals {
if v.Start != other.intervals[k].Start || v.Stop != other.intervals[k].Stop {
return false
}
}
return true
}
func (i *IntervalSet) first() int {
if len(i.intervals) == 0 {
return TokenInvalidType
@@ -70,16 +88,16 @@ func (i *IntervalSet) addRange(l, h int) {
i.addInterval(NewInterval(l, h+1))
}
func (i *IntervalSet) addInterval(v *Interval) {
func (i *IntervalSet) addInterval(v Interval) {
if i.intervals == nil {
i.intervals = make([]*Interval, 0)
i.intervals = make([]Interval, 0)
i.intervals = append(i.intervals, v)
} else {
// find insert pos
for k, interval := range i.intervals {
// distinct range -> insert
if v.Stop < interval.Start {
i.intervals = append(i.intervals[0:k], append([]*Interval{v}, i.intervals[k:]...)...)
i.intervals = append(i.intervals[0:k], append([]Interval{v}, i.intervals[k:]...)...)
return
} else if v.Stop == interval.Start {
i.intervals[k].Start = v.Start
@@ -139,16 +157,16 @@ func (i *IntervalSet) contains(item int) bool {
}
func (i *IntervalSet) length() int {
len := 0
iLen := 0
for _, v := range i.intervals {
len += v.length()
iLen += v.Length()
}
return len
return iLen
}
func (i *IntervalSet) removeRange(v *Interval) {
func (i *IntervalSet) removeRange(v Interval) {
if v.Start == v.Stop-1 {
i.removeOne(v.Start)
} else if i.intervals != nil {
@@ -162,7 +180,7 @@ func (i *IntervalSet) removeRange(v *Interval) {
i.intervals[k] = NewInterval(ni.Start, v.Start)
x := NewInterval(v.Stop, ni.Stop)
// i.intervals.splice(k, 0, x)
i.intervals = append(i.intervals[0:k], append([]*Interval{x}, i.intervals[k:]...)...)
i.intervals = append(i.intervals[0:k], append([]Interval{x}, i.intervals[k:]...)...)
return
} else if v.Start <= ni.Start && v.Stop >= ni.Stop {
// i.intervals.splice(k, 1)
@@ -199,7 +217,7 @@ func (i *IntervalSet) removeOne(v int) {
x := NewInterval(ki.Start, v)
ki.Start = v + 1
// i.intervals.splice(k, 0, x)
i.intervals = append(i.intervals[0:k], append([]*Interval{x}, i.intervals[k:]...)...)
i.intervals = append(i.intervals[0:k], append([]Interval{x}, i.intervals[k:]...)...)
return
}
}
@@ -223,7 +241,7 @@ func (i *IntervalSet) StringVerbose(literalNames []string, symbolicNames []strin
return i.toIndexString()
}
func (i *IntervalSet) GetIntervals() []*Interval {
func (i *IntervalSet) GetIntervals() []Interval {
return i.intervals
}

685
vendor/github.com/antlr4-go/antlr/v4/jcollect.go generated vendored Normal file
View File

@@ -0,0 +1,685 @@
package antlr
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
import (
"container/list"
"runtime/debug"
"sort"
"sync"
)
// Collectable is an interface that a struct should implement if it is to be
// usable as a key in these collections.
type Collectable[T any] interface {
Hash() int
Equals(other Collectable[T]) bool
}
type Comparator[T any] interface {
Hash1(o T) int
Equals2(T, T) bool
}
type CollectionSource int
type CollectionDescriptor struct {
SybolicName string
Description string
}
const (
UnknownCollection CollectionSource = iota
ATNConfigLookupCollection
ATNStateCollection
DFAStateCollection
ATNConfigCollection
PredictionContextCollection
SemanticContextCollection
ClosureBusyCollection
PredictionVisitedCollection
MergeCacheCollection
PredictionContextCacheCollection
AltSetCollection
ReachSetCollection
)
var CollectionDescriptors = map[CollectionSource]CollectionDescriptor{
UnknownCollection: {
SybolicName: "UnknownCollection",
Description: "Unknown collection type. Only used if the target author thought it was an unimportant collection.",
},
ATNConfigCollection: {
SybolicName: "ATNConfigCollection",
Description: "ATNConfig collection. Used to store the ATNConfigs for a particular state in the ATN." +
"For instance, it is used to store the results of the closure() operation in the ATN.",
},
ATNConfigLookupCollection: {
SybolicName: "ATNConfigLookupCollection",
Description: "ATNConfigLookup collection. Used to store the ATNConfigs for a particular state in the ATN." +
"This is used to prevent duplicating equivalent states in an ATNConfigurationSet.",
},
ATNStateCollection: {
SybolicName: "ATNStateCollection",
Description: "ATNState collection. This is used to store the states of the ATN.",
},
DFAStateCollection: {
SybolicName: "DFAStateCollection",
Description: "DFAState collection. This is used to store the states of the DFA.",
},
PredictionContextCollection: {
SybolicName: "PredictionContextCollection",
Description: "PredictionContext collection. This is used to store the prediction contexts of the ATN and cache computes.",
},
SemanticContextCollection: {
SybolicName: "SemanticContextCollection",
Description: "SemanticContext collection. This is used to store the semantic contexts of the ATN.",
},
ClosureBusyCollection: {
SybolicName: "ClosureBusyCollection",
Description: "ClosureBusy collection. This is used to check and prevent infinite recursion right recursive rules." +
"It stores ATNConfigs that are currently being processed in the closure() operation.",
},
PredictionVisitedCollection: {
SybolicName: "PredictionVisitedCollection",
Description: "A map that records whether we have visited a particular context when searching through cached entries.",
},
MergeCacheCollection: {
SybolicName: "MergeCacheCollection",
Description: "A map that records whether we have already merged two particular contexts and can save effort by not repeating it.",
},
PredictionContextCacheCollection: {
SybolicName: "PredictionContextCacheCollection",
Description: "A map that records whether we have already created a particular context and can save effort by not computing it again.",
},
AltSetCollection: {
SybolicName: "AltSetCollection",
Description: "Used to eliminate duplicate alternatives in an ATN config set.",
},
ReachSetCollection: {
SybolicName: "ReachSetCollection",
Description: "Used as merge cache to prevent us needing to compute the merge of two states if we have already done it.",
},
}
// JStore implements a container that allows the use of a struct to calculate the key
// for a collection of values akin to map. This is not meant to be a full-blown HashMap but just
// serve the needs of the ANTLR Go runtime.
//
// For ease of porting the logic of the runtime from the master target (Java), this collection
// operates in a similar way to Java, in that it can use any struct that supplies a Hash() and Equals()
// function as the key. The values are stored in a standard go map which internally is a form of hashmap
// itself, the key for the go map is the hash supplied by the key object. The collection is able to deal with
// hash conflicts by using a simple slice of values associated with the hash code indexed bucket. That isn't
// particularly efficient, but it is simple, and it works. As this is specifically for the ANTLR runtime, and
// we understand the requirements, then this is fine - this is not a general purpose collection.
type JStore[T any, C Comparator[T]] struct {
store map[int][]T
len int
comparator Comparator[T]
stats *JStatRec
}
func NewJStore[T any, C Comparator[T]](comparator Comparator[T], cType CollectionSource, desc string) *JStore[T, C] {
if comparator == nil {
panic("comparator cannot be nil")
}
s := &JStore[T, C]{
store: make(map[int][]T, 1),
comparator: comparator,
}
if collectStats {
s.stats = &JStatRec{
Source: cType,
Description: desc,
}
// Track where we created it from if we are being asked to do so
if runtimeConfig.statsTraceStacks {
s.stats.CreateStack = debug.Stack()
}
Statistics.AddJStatRec(s.stats)
}
return s
}
// Put will store given value in the collection. Note that the key for storage is generated from
// the value itself - this is specifically because that is what ANTLR needs - this would not be useful
// as any kind of general collection.
//
// If the key has a hash conflict, then the value will be added to the slice of values associated with the
// hash, unless the value is already in the slice, in which case the existing value is returned. Value equivalence is
// tested by calling the equals() method on the key.
//
// # If the given value is already present in the store, then the existing value is returned as v and exists is set to true
//
// If the given value is not present in the store, then the value is added to the store and returned as v and exists is set to false.
func (s *JStore[T, C]) Put(value T) (v T, exists bool) {
if collectStats {
s.stats.Puts++
}
kh := s.comparator.Hash1(value)
var hClash bool
for _, v1 := range s.store[kh] {
hClash = true
if s.comparator.Equals2(value, v1) {
if collectStats {
s.stats.PutHits++
s.stats.PutHashConflicts++
}
return v1, true
}
if collectStats {
s.stats.PutMisses++
}
}
if collectStats && hClash {
s.stats.PutHashConflicts++
}
s.store[kh] = append(s.store[kh], value)
if collectStats {
if len(s.store[kh]) > s.stats.MaxSlotSize {
s.stats.MaxSlotSize = len(s.store[kh])
}
}
s.len++
if collectStats {
s.stats.CurSize = s.len
if s.len > s.stats.MaxSize {
s.stats.MaxSize = s.len
}
}
return value, false
}
// Get will return the value associated with the key - the type of the key is the same type as the value
// which would not generally be useful, but this is a specific thing for ANTLR where the key is
// generated using the object we are going to store.
func (s *JStore[T, C]) Get(key T) (T, bool) {
if collectStats {
s.stats.Gets++
}
kh := s.comparator.Hash1(key)
var hClash bool
for _, v := range s.store[kh] {
hClash = true
if s.comparator.Equals2(key, v) {
if collectStats {
s.stats.GetHits++
s.stats.GetHashConflicts++
}
return v, true
}
if collectStats {
s.stats.GetMisses++
}
}
if collectStats {
if hClash {
s.stats.GetHashConflicts++
}
s.stats.GetNoEnt++
}
return key, false
}
// Contains returns true if the given key is present in the store
func (s *JStore[T, C]) Contains(key T) bool {
_, present := s.Get(key)
return present
}
func (s *JStore[T, C]) SortedSlice(less func(i, j T) bool) []T {
vs := make([]T, 0, len(s.store))
for _, v := range s.store {
vs = append(vs, v...)
}
sort.Slice(vs, func(i, j int) bool {
return less(vs[i], vs[j])
})
return vs
}
func (s *JStore[T, C]) Each(f func(T) bool) {
for _, e := range s.store {
for _, v := range e {
f(v)
}
}
}
func (s *JStore[T, C]) Len() int {
return s.len
}
func (s *JStore[T, C]) Values() []T {
vs := make([]T, 0, len(s.store))
for _, e := range s.store {
vs = append(vs, e...)
}
return vs
}
type entry[K, V any] struct {
key K
val V
}
type JMap[K, V any, C Comparator[K]] struct {
store map[int][]*entry[K, V]
len int
comparator Comparator[K]
stats *JStatRec
}
func NewJMap[K, V any, C Comparator[K]](comparator Comparator[K], cType CollectionSource, desc string) *JMap[K, V, C] {
m := &JMap[K, V, C]{
store: make(map[int][]*entry[K, V], 1),
comparator: comparator,
}
if collectStats {
m.stats = &JStatRec{
Source: cType,
Description: desc,
}
// Track where we created it from if we are being asked to do so
if runtimeConfig.statsTraceStacks {
m.stats.CreateStack = debug.Stack()
}
Statistics.AddJStatRec(m.stats)
}
return m
}
func (m *JMap[K, V, C]) Put(key K, val V) (V, bool) {
if collectStats {
m.stats.Puts++
}
kh := m.comparator.Hash1(key)
var hClash bool
for _, e := range m.store[kh] {
hClash = true
if m.comparator.Equals2(e.key, key) {
if collectStats {
m.stats.PutHits++
m.stats.PutHashConflicts++
}
return e.val, true
}
if collectStats {
m.stats.PutMisses++
}
}
if collectStats {
if hClash {
m.stats.PutHashConflicts++
}
}
m.store[kh] = append(m.store[kh], &entry[K, V]{key, val})
if collectStats {
if len(m.store[kh]) > m.stats.MaxSlotSize {
m.stats.MaxSlotSize = len(m.store[kh])
}
}
m.len++
if collectStats {
m.stats.CurSize = m.len
if m.len > m.stats.MaxSize {
m.stats.MaxSize = m.len
}
}
return val, false
}
func (m *JMap[K, V, C]) Values() []V {
vs := make([]V, 0, len(m.store))
for _, e := range m.store {
for _, v := range e {
vs = append(vs, v.val)
}
}
return vs
}
func (m *JMap[K, V, C]) Get(key K) (V, bool) {
if collectStats {
m.stats.Gets++
}
var none V
kh := m.comparator.Hash1(key)
var hClash bool
for _, e := range m.store[kh] {
hClash = true
if m.comparator.Equals2(e.key, key) {
if collectStats {
m.stats.GetHits++
m.stats.GetHashConflicts++
}
return e.val, true
}
if collectStats {
m.stats.GetMisses++
}
}
if collectStats {
if hClash {
m.stats.GetHashConflicts++
}
m.stats.GetNoEnt++
}
return none, false
}
func (m *JMap[K, V, C]) Len() int {
return m.len
}
func (m *JMap[K, V, C]) Delete(key K) {
kh := m.comparator.Hash1(key)
for i, e := range m.store[kh] {
if m.comparator.Equals2(e.key, key) {
m.store[kh] = append(m.store[kh][:i], m.store[kh][i+1:]...)
m.len--
return
}
}
}
func (m *JMap[K, V, C]) Clear() {
m.store = make(map[int][]*entry[K, V])
}
type JPCMap struct {
store *JMap[*PredictionContext, *JMap[*PredictionContext, *PredictionContext, *ObjEqComparator[*PredictionContext]], *ObjEqComparator[*PredictionContext]]
size int
stats *JStatRec
}
func NewJPCMap(cType CollectionSource, desc string) *JPCMap {
m := &JPCMap{
store: NewJMap[*PredictionContext, *JMap[*PredictionContext, *PredictionContext, *ObjEqComparator[*PredictionContext]], *ObjEqComparator[*PredictionContext]](pContextEqInst, cType, desc),
}
if collectStats {
m.stats = &JStatRec{
Source: cType,
Description: desc,
}
// Track where we created it from if we are being asked to do so
if runtimeConfig.statsTraceStacks {
m.stats.CreateStack = debug.Stack()
}
Statistics.AddJStatRec(m.stats)
}
return m
}
func (pcm *JPCMap) Get(k1, k2 *PredictionContext) (*PredictionContext, bool) {
if collectStats {
pcm.stats.Gets++
}
// Do we have a map stored by k1?
//
m2, present := pcm.store.Get(k1)
if present {
if collectStats {
pcm.stats.GetHits++
}
// We found a map of values corresponding to k1, so now we need to look up k2 in that map
//
return m2.Get(k2)
}
if collectStats {
pcm.stats.GetMisses++
}
return nil, false
}
func (pcm *JPCMap) Put(k1, k2, v *PredictionContext) {
if collectStats {
pcm.stats.Puts++
}
// First does a map already exist for k1?
//
if m2, present := pcm.store.Get(k1); present {
if collectStats {
pcm.stats.PutHits++
}
_, present = m2.Put(k2, v)
if !present {
pcm.size++
if collectStats {
pcm.stats.CurSize = pcm.size
if pcm.size > pcm.stats.MaxSize {
pcm.stats.MaxSize = pcm.size
}
}
}
} else {
// No map found for k1, so we create it, add in our value, then store is
//
if collectStats {
pcm.stats.PutMisses++
m2 = NewJMap[*PredictionContext, *PredictionContext, *ObjEqComparator[*PredictionContext]](pContextEqInst, pcm.stats.Source, pcm.stats.Description+" map entry")
} else {
m2 = NewJMap[*PredictionContext, *PredictionContext, *ObjEqComparator[*PredictionContext]](pContextEqInst, PredictionContextCacheCollection, "map entry")
}
m2.Put(k2, v)
pcm.store.Put(k1, m2)
pcm.size++
}
}
type JPCMap2 struct {
store map[int][]JPCEntry
size int
stats *JStatRec
}
type JPCEntry struct {
k1, k2, v *PredictionContext
}
func NewJPCMap2(cType CollectionSource, desc string) *JPCMap2 {
m := &JPCMap2{
store: make(map[int][]JPCEntry, 1000),
}
if collectStats {
m.stats = &JStatRec{
Source: cType,
Description: desc,
}
// Track where we created it from if we are being asked to do so
if runtimeConfig.statsTraceStacks {
m.stats.CreateStack = debug.Stack()
}
Statistics.AddJStatRec(m.stats)
}
return m
}
func dHash(k1, k2 *PredictionContext) int {
return k1.cachedHash*31 + k2.cachedHash
}
func (pcm *JPCMap2) Get(k1, k2 *PredictionContext) (*PredictionContext, bool) {
if collectStats {
pcm.stats.Gets++
}
h := dHash(k1, k2)
var hClash bool
for _, e := range pcm.store[h] {
hClash = true
if e.k1.Equals(k1) && e.k2.Equals(k2) {
if collectStats {
pcm.stats.GetHits++
pcm.stats.GetHashConflicts++
}
return e.v, true
}
if collectStats {
pcm.stats.GetMisses++
}
}
if collectStats {
if hClash {
pcm.stats.GetHashConflicts++
}
pcm.stats.GetNoEnt++
}
return nil, false
}
func (pcm *JPCMap2) Put(k1, k2, v *PredictionContext) (*PredictionContext, bool) {
if collectStats {
pcm.stats.Puts++
}
h := dHash(k1, k2)
var hClash bool
for _, e := range pcm.store[h] {
hClash = true
if e.k1.Equals(k1) && e.k2.Equals(k2) {
if collectStats {
pcm.stats.PutHits++
pcm.stats.PutHashConflicts++
}
return e.v, true
}
if collectStats {
pcm.stats.PutMisses++
}
}
if collectStats {
if hClash {
pcm.stats.PutHashConflicts++
}
}
pcm.store[h] = append(pcm.store[h], JPCEntry{k1, k2, v})
pcm.size++
if collectStats {
pcm.stats.CurSize = pcm.size
if pcm.size > pcm.stats.MaxSize {
pcm.stats.MaxSize = pcm.size
}
}
return nil, false
}
type VisitEntry struct {
k *PredictionContext
v *PredictionContext
}
type VisitRecord struct {
store map[*PredictionContext]*PredictionContext
len int
stats *JStatRec
}
type VisitList struct {
cache *list.List
lock sync.RWMutex
}
var visitListPool = VisitList{
cache: list.New(),
lock: sync.RWMutex{},
}
// NewVisitRecord returns a new VisitRecord instance from the pool if available.
// Note that this "map" uses a pointer as a key because we are emulating the behavior of
// IdentityHashMap in Java, which uses the `==` operator to compare whether the keys are equal,
// which means is the key the same reference to an object rather than is it .equals() to another
// object.
func NewVisitRecord() *VisitRecord {
visitListPool.lock.Lock()
el := visitListPool.cache.Front()
defer visitListPool.lock.Unlock()
var vr *VisitRecord
if el == nil {
vr = &VisitRecord{
store: make(map[*PredictionContext]*PredictionContext),
}
if collectStats {
vr.stats = &JStatRec{
Source: PredictionContextCacheCollection,
Description: "VisitRecord",
}
// Track where we created it from if we are being asked to do so
if runtimeConfig.statsTraceStacks {
vr.stats.CreateStack = debug.Stack()
}
}
} else {
vr = el.Value.(*VisitRecord)
visitListPool.cache.Remove(el)
vr.store = make(map[*PredictionContext]*PredictionContext)
}
if collectStats {
Statistics.AddJStatRec(vr.stats)
}
return vr
}
func (vr *VisitRecord) Release() {
vr.len = 0
vr.store = nil
if collectStats {
vr.stats.MaxSize = 0
vr.stats.CurSize = 0
vr.stats.Gets = 0
vr.stats.GetHits = 0
vr.stats.GetMisses = 0
vr.stats.GetHashConflicts = 0
vr.stats.GetNoEnt = 0
vr.stats.Puts = 0
vr.stats.PutHits = 0
vr.stats.PutMisses = 0
vr.stats.PutHashConflicts = 0
vr.stats.MaxSlotSize = 0
}
visitListPool.lock.Lock()
visitListPool.cache.PushBack(vr)
visitListPool.lock.Unlock()
}
func (vr *VisitRecord) Get(k *PredictionContext) (*PredictionContext, bool) {
if collectStats {
vr.stats.Gets++
}
v := vr.store[k]
if v != nil {
if collectStats {
vr.stats.GetHits++
}
return v, true
}
if collectStats {
vr.stats.GetNoEnt++
}
return nil, false
}
func (vr *VisitRecord) Put(k, v *PredictionContext) (*PredictionContext, bool) {
if collectStats {
vr.stats.Puts++
}
vr.store[k] = v
vr.len++
if collectStats {
vr.stats.CurSize = vr.len
if vr.len > vr.stats.MaxSize {
vr.stats.MaxSize = vr.len
}
}
return v, false
}

View File

@@ -69,7 +69,7 @@ func NewBaseLexer(input CharStream) *BaseLexer {
// create a single token. NextToken will return l object after
// Matching lexer rule(s). If you subclass to allow multiple token
// emissions, then set l to the last token to be Matched or
// something nonnil so that the auto token emit mechanism will not
// something non nil so that the auto token emit mechanism will not
// emit another token.
lexer.token = nil
@@ -111,6 +111,7 @@ const (
LexerSkip = -3
)
//goland:noinspection GoUnusedConst
const (
LexerDefaultTokenChannel = TokenDefaultChannel
LexerHidden = TokenHiddenChannel
@@ -118,7 +119,7 @@ const (
LexerMaxCharValue = 0x10FFFF
)
func (b *BaseLexer) reset() {
func (b *BaseLexer) Reset() {
// wack Lexer state variables
if b.input != nil {
b.input.Seek(0) // rewind the input
@@ -176,7 +177,7 @@ func (b *BaseLexer) safeMatch() (ret int) {
return b.Interpreter.Match(b.input, b.mode)
}
// Return a token from l source i.e., Match a token on the char stream.
// NextToken returns a token from the lexer input source i.e., Match a token on the source char stream.
func (b *BaseLexer) NextToken() Token {
if b.input == nil {
panic("NextToken requires a non-nil input stream.")
@@ -205,9 +206,8 @@ func (b *BaseLexer) NextToken() Token {
continueOuter := false
for {
b.thetype = TokenInvalidType
ttype := LexerSkip
ttype = b.safeMatch()
ttype := b.safeMatch()
if b.input.LA(1) == TokenEOF {
b.hitEOF = true
@@ -234,12 +234,11 @@ func (b *BaseLexer) NextToken() Token {
}
}
// Instruct the lexer to Skip creating a token for current lexer rule
// and look for another token. NextToken() knows to keep looking when
// a lexer rule finishes with token set to SKIPTOKEN. Recall that
// Skip instructs the lexer to Skip creating a token for current lexer rule
// and look for another token. [NextToken] knows to keep looking when
// a lexer rule finishes with token set to [SKIPTOKEN]. Recall that
// if token==nil at end of any token rule, it creates one for you
// and emits it.
// /
func (b *BaseLexer) Skip() {
b.thetype = LexerSkip
}
@@ -248,23 +247,29 @@ func (b *BaseLexer) More() {
b.thetype = LexerMore
}
// SetMode changes the lexer to a new mode. The lexer will use this mode from hereon in and the rules for that mode
// will be in force.
func (b *BaseLexer) SetMode(m int) {
b.mode = m
}
// PushMode saves the current lexer mode so that it can be restored later. See [PopMode], then sets the
// current lexer mode to the supplied mode m.
func (b *BaseLexer) PushMode(m int) {
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("pushMode " + strconv.Itoa(m))
}
b.modeStack.Push(b.mode)
b.mode = m
}
// PopMode restores the lexer mode saved by a call to [PushMode]. It is a panic error if there is no saved mode to
// return to.
func (b *BaseLexer) PopMode() int {
if len(b.modeStack) == 0 {
panic("Empty Stack")
}
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("popMode back to " + fmt.Sprint(b.modeStack[0:len(b.modeStack)-1]))
}
i, _ := b.modeStack.Pop()
@@ -280,7 +285,7 @@ func (b *BaseLexer) inputStream() CharStream {
func (b *BaseLexer) SetInputStream(input CharStream) {
b.input = nil
b.tokenFactorySourcePair = &TokenSourceCharStreamPair{b, b.input}
b.reset()
b.Reset()
b.input = input
b.tokenFactorySourcePair = &TokenSourceCharStreamPair{b, b.input}
}
@@ -289,20 +294,19 @@ func (b *BaseLexer) GetTokenSourceCharStreamPair() *TokenSourceCharStreamPair {
return b.tokenFactorySourcePair
}
// By default does not support multiple emits per NextToken invocation
// for efficiency reasons. Subclass and override l method, NextToken,
// and GetToken (to push tokens into a list and pull from that list
// rather than a single variable as l implementation does).
// /
// EmitToken by default does not support multiple emits per [NextToken] invocation
// for efficiency reasons. Subclass and override this func, [NextToken],
// and [GetToken] (to push tokens into a list and pull from that list
// rather than a single variable as this implementation does).
func (b *BaseLexer) EmitToken(token Token) {
b.token = token
}
// The standard method called to automatically emit a token at the
// Emit is the standard method called to automatically emit a token at the
// outermost lexical rule. The token object should point into the
// char buffer start..stop. If there is a text override in 'text',
// use that to set the token's text. Override l method to emit
// custom Token objects or provide a Newfactory.
// use that to set the token's text. Override this method to emit
// custom [Token] objects or provide a new factory.
// /
func (b *BaseLexer) Emit() Token {
t := b.factory.Create(b.tokenFactorySourcePair, b.thetype, b.text, b.channel, b.TokenStartCharIndex, b.GetCharIndex()-1, b.TokenStartLine, b.TokenStartColumn)
@@ -310,6 +314,7 @@ func (b *BaseLexer) Emit() Token {
return t
}
// EmitEOF emits an EOF token. By default, this is the last token emitted
func (b *BaseLexer) EmitEOF() Token {
cpos := b.GetCharPositionInLine()
lpos := b.GetLine()
@@ -318,6 +323,7 @@ func (b *BaseLexer) EmitEOF() Token {
return eof
}
// GetCharPositionInLine returns the current position in the current line as far as the lexer is concerned.
func (b *BaseLexer) GetCharPositionInLine() int {
return b.Interpreter.GetCharPositionInLine()
}
@@ -334,13 +340,12 @@ func (b *BaseLexer) SetType(t int) {
b.thetype = t
}
// What is the index of the current character of lookahead?///
// GetCharIndex returns the index of the current character of lookahead
func (b *BaseLexer) GetCharIndex() int {
return b.input.Index()
}
// Return the text Matched so far for the current token or any text override.
// Set the complete text of l token it wipes any previous changes to the text.
// GetText returns the text Matched so far for the current token or any text override.
func (b *BaseLexer) GetText() string {
if b.text != "" {
return b.text
@@ -349,17 +354,20 @@ func (b *BaseLexer) GetText() string {
return b.Interpreter.GetText(b.input)
}
// SetText sets the complete text of this token; it wipes any previous changes to the text.
func (b *BaseLexer) SetText(text string) {
b.text = text
}
// GetATN returns the ATN used by the lexer.
func (b *BaseLexer) GetATN() *ATN {
return b.Interpreter.ATN()
}
// Return a list of all Token objects in input char stream.
// Forces load of all tokens. Does not include EOF token.
// /
// GetAllTokens returns a list of all [Token] objects in input char stream.
// Forces a load of all tokens that can be made from the input char stream.
//
// Does not include EOF token.
func (b *BaseLexer) GetAllTokens() []Token {
vl := b.Virt
tokens := make([]Token, 0)
@@ -398,11 +406,13 @@ func (b *BaseLexer) getCharErrorDisplay(c rune) string {
return "'" + b.getErrorDisplayForChar(c) + "'"
}
// Lexers can normally Match any char in it's vocabulary after Matching
// a token, so do the easy thing and just kill a character and hope
// Recover can normally Match any char in its vocabulary after Matching
// a token, so here we do the easy thing and just kill a character and hope
// it all works out. You can instead use the rule invocation stack
// to do sophisticated error recovery if you are in a fragment rule.
// /
//
// In general, lexers should not need to recover and should have rules that cover any eventuality, such as
// a character that makes no sense to the recognizer.
func (b *BaseLexer) Recover(re RecognitionException) {
if b.input.LA(1) != TokenEOF {
if _, ok := re.(*LexerNoViableAltException); ok {

View File

@@ -7,14 +7,29 @@ package antlr
import "strconv"
const (
LexerActionTypeChannel = 0 //The type of a {@link LexerChannelAction} action.
LexerActionTypeCustom = 1 //The type of a {@link LexerCustomAction} action.
LexerActionTypeMode = 2 //The type of a {@link LexerModeAction} action.
LexerActionTypeMore = 3 //The type of a {@link LexerMoreAction} action.
LexerActionTypePopMode = 4 //The type of a {@link LexerPopModeAction} action.
LexerActionTypePushMode = 5 //The type of a {@link LexerPushModeAction} action.
LexerActionTypeSkip = 6 //The type of a {@link LexerSkipAction} action.
LexerActionTypeType = 7 //The type of a {@link LexerTypeAction} action.
// LexerActionTypeChannel represents a [LexerChannelAction] action.
LexerActionTypeChannel = 0
// LexerActionTypeCustom represents a [LexerCustomAction] action.
LexerActionTypeCustom = 1
// LexerActionTypeMode represents a [LexerModeAction] action.
LexerActionTypeMode = 2
// LexerActionTypeMore represents a [LexerMoreAction] action.
LexerActionTypeMore = 3
// LexerActionTypePopMode represents a [LexerPopModeAction] action.
LexerActionTypePopMode = 4
// LexerActionTypePushMode represents a [LexerPushModeAction] action.
LexerActionTypePushMode = 5
// LexerActionTypeSkip represents a [LexerSkipAction] action.
LexerActionTypeSkip = 6
// LexerActionTypeType represents a [LexerTypeAction] action.
LexerActionTypeType = 7
)
type LexerAction interface {
@@ -39,7 +54,7 @@ func NewBaseLexerAction(action int) *BaseLexerAction {
return la
}
func (b *BaseLexerAction) execute(lexer Lexer) {
func (b *BaseLexerAction) execute(_ Lexer) {
panic("Not implemented")
}
@@ -52,17 +67,19 @@ func (b *BaseLexerAction) getIsPositionDependent() bool {
}
func (b *BaseLexerAction) Hash() int {
return b.actionType
h := murmurInit(0)
h = murmurUpdate(h, b.actionType)
return murmurFinish(h, 1)
}
func (b *BaseLexerAction) Equals(other LexerAction) bool {
return b == other
return b.actionType == other.getActionType()
}
// Implements the {@code Skip} lexer action by calling {@link Lexer//Skip}.
// LexerSkipAction implements the [BaseLexerAction.Skip] lexer action by calling [Lexer.Skip].
//
// <p>The {@code Skip} command does not have any parameters, so l action is
// implemented as a singleton instance exposed by {@link //INSTANCE}.</p>
// The Skip command does not have any parameters, so this action is
// implemented as a singleton instance exposed by the [LexerSkipActionINSTANCE].
type LexerSkipAction struct {
*BaseLexerAction
}
@@ -73,17 +90,22 @@ func NewLexerSkipAction() *LexerSkipAction {
return la
}
// Provides a singleton instance of l parameterless lexer action.
// LexerSkipActionINSTANCE provides a singleton instance of this parameterless lexer action.
var LexerSkipActionINSTANCE = NewLexerSkipAction()
func (l *LexerSkipAction) execute(lexer Lexer) {
lexer.Skip()
}
// String returns a string representation of the current [LexerSkipAction].
func (l *LexerSkipAction) String() string {
return "skip"
}
func (b *LexerSkipAction) Equals(other LexerAction) bool {
return other.getActionType() == LexerActionTypeSkip
}
// Implements the {@code type} lexer action by calling {@link Lexer//setType}
//
// with the assigned type.
@@ -125,11 +147,10 @@ func (l *LexerTypeAction) String() string {
return "actionType(" + strconv.Itoa(l.thetype) + ")"
}
// Implements the {@code pushMode} lexer action by calling
// {@link Lexer//pushMode} with the assigned mode.
// LexerPushModeAction implements the pushMode lexer action by calling
// [Lexer.pushMode] with the assigned mode.
type LexerPushModeAction struct {
*BaseLexerAction
mode int
}
@@ -169,10 +190,10 @@ func (l *LexerPushModeAction) String() string {
return "pushMode(" + strconv.Itoa(l.mode) + ")"
}
// Implements the {@code popMode} lexer action by calling {@link Lexer//popMode}.
// LexerPopModeAction implements the popMode lexer action by calling [Lexer.popMode].
//
// <p>The {@code popMode} command does not have any parameters, so l action is
// implemented as a singleton instance exposed by {@link //INSTANCE}.</p>
// The popMode command does not have any parameters, so this action is
// implemented as a singleton instance exposed by [LexerPopModeActionINSTANCE]
type LexerPopModeAction struct {
*BaseLexerAction
}
@@ -224,11 +245,10 @@ func (l *LexerMoreAction) String() string {
return "more"
}
// Implements the {@code mode} lexer action by calling {@link Lexer//mode} with
// LexerModeAction implements the mode lexer action by calling [Lexer.mode] with
// the assigned mode.
type LexerModeAction struct {
*BaseLexerAction
mode int
}
@@ -322,16 +342,19 @@ func (l *LexerCustomAction) Equals(other LexerAction) bool {
}
}
// Implements the {@code channel} lexer action by calling
// {@link Lexer//setChannel} with the assigned channel.
// Constructs a New{@code channel} action with the specified channel value.
// @param channel The channel value to pass to {@link Lexer//setChannel}.
// LexerChannelAction implements the channel lexer action by calling
// [Lexer.setChannel] with the assigned channel.
//
// Constructs a new channel action with the specified channel value.
type LexerChannelAction struct {
*BaseLexerAction
channel int
}
// NewLexerChannelAction creates a channel lexer action by calling
// [Lexer.setChannel] with the assigned channel.
//
// Constructs a new channel action with the specified channel value.
func NewLexerChannelAction(channel int) *LexerChannelAction {
l := new(LexerChannelAction)
l.BaseLexerAction = NewBaseLexerAction(LexerActionTypeChannel)
@@ -375,25 +398,22 @@ func (l *LexerChannelAction) String() string {
// lexer actions, see {@link LexerActionExecutor//append} and
// {@link LexerActionExecutor//fixOffsetBeforeMatch}.</p>
// Constructs a Newindexed custom action by associating a character offset
// with a {@link LexerAction}.
//
// <p>Note: This class is only required for lexer actions for which
// {@link LexerAction//isPositionDependent} returns {@code true}.</p>
//
// @param offset The offset into the input {@link CharStream}, relative to
// the token start index, at which the specified lexer action should be
// executed.
// @param action The lexer action to execute at a particular offset in the
// input {@link CharStream}.
type LexerIndexedCustomAction struct {
*BaseLexerAction
offset int
lexerAction LexerAction
isPositionDependent bool
}
// NewLexerIndexedCustomAction constructs a new indexed custom action by associating a character offset
// with a [LexerAction].
//
// Note: This class is only required for lexer actions for which
// [LexerAction.isPositionDependent] returns true.
//
// The offset points into the input [CharStream], relative to
// the token start index, at which the specified lexerAction should be
// executed.
func NewLexerIndexedCustomAction(offset int, lexerAction LexerAction) *LexerIndexedCustomAction {
l := new(LexerIndexedCustomAction)

View File

@@ -29,28 +29,20 @@ func NewLexerActionExecutor(lexerActions []LexerAction) *LexerActionExecutor {
l.lexerActions = lexerActions
// Caches the result of {@link //hashCode} since the hash code is an element
// of the performance-critical {@link LexerATNConfig//hashCode} operation.
l.cachedHash = murmurInit(57)
// of the performance-critical {@link ATNConfig//hashCode} operation.
l.cachedHash = murmurInit(0)
for _, a := range lexerActions {
l.cachedHash = murmurUpdate(l.cachedHash, a.Hash())
}
l.cachedHash = murmurFinish(l.cachedHash, len(lexerActions))
return l
}
// Creates a {@link LexerActionExecutor} which executes the actions for
// the input {@code lexerActionExecutor} followed by a specified
// {@code lexerAction}.
//
// @param lexerActionExecutor The executor for actions already traversed by
// the lexer while Matching a token within a particular
// {@link LexerATNConfig}. If this is {@code nil}, the method behaves as
// though it were an empty executor.
// @param lexerAction The lexer action to execute after the actions
// specified in {@code lexerActionExecutor}.
//
// @return A {@link LexerActionExecutor} for executing the combine actions
// of {@code lexerActionExecutor} and {@code lexerAction}.
// LexerActionExecutorappend creates a [LexerActionExecutor] which executes the actions for
// the input [LexerActionExecutor] followed by a specified
// [LexerAction].
// TODO: This does not match the Java code
func LexerActionExecutorappend(lexerActionExecutor *LexerActionExecutor, lexerAction LexerAction) *LexerActionExecutor {
if lexerActionExecutor == nil {
return NewLexerActionExecutor([]LexerAction{lexerAction})
@@ -59,47 +51,42 @@ func LexerActionExecutorappend(lexerActionExecutor *LexerActionExecutor, lexerAc
return NewLexerActionExecutor(append(lexerActionExecutor.lexerActions, lexerAction))
}
// Creates a {@link LexerActionExecutor} which encodes the current offset
// fixOffsetBeforeMatch creates a [LexerActionExecutor] which encodes the current offset
// for position-dependent lexer actions.
//
// <p>Normally, when the executor encounters lexer actions where
// {@link LexerAction//isPositionDependent} returns {@code true}, it calls
// {@link IntStream//seek} on the input {@link CharStream} to set the input
// position to the <em>end</em> of the current token. This behavior provides
// for efficient DFA representation of lexer actions which appear at the end
// Normally, when the executor encounters lexer actions where
// [LexerAction.isPositionDependent] returns true, it calls
// [IntStream.Seek] on the input [CharStream] to set the input
// position to the end of the current token. This behavior provides
// for efficient [DFA] representation of lexer actions which appear at the end
// of a lexer rule, even when the lexer rule Matches a variable number of
// characters.</p>
// characters.
//
// <p>Prior to traversing a Match transition in the ATN, the current offset
// Prior to traversing a Match transition in the [ATN], the current offset
// from the token start index is assigned to all position-dependent lexer
// actions which have not already been assigned a fixed offset. By storing
// the offsets relative to the token start index, the DFA representation of
// the offsets relative to the token start index, the [DFA] representation of
// lexer actions which appear in the middle of tokens remains efficient due
// to sharing among tokens of the same length, regardless of their absolute
// position in the input stream.</p>
// to sharing among tokens of the same Length, regardless of their absolute
// position in the input stream.
//
// <p>If the current executor already has offsets assigned to all
// position-dependent lexer actions, the method returns {@code this}.</p>
// If the current executor already has offsets assigned to all
// position-dependent lexer actions, the method returns this instance.
//
// @param offset The current offset to assign to all position-dependent
// The offset is assigned to all position-dependent
// lexer actions which do not already have offsets assigned.
//
// @return A {@link LexerActionExecutor} which stores input stream offsets
// The func returns a [LexerActionExecutor] that stores input stream offsets
// for all position-dependent lexer actions.
// /
func (l *LexerActionExecutor) fixOffsetBeforeMatch(offset int) *LexerActionExecutor {
var updatedLexerActions []LexerAction
for i := 0; i < len(l.lexerActions); i++ {
_, ok := l.lexerActions[i].(*LexerIndexedCustomAction)
if l.lexerActions[i].getIsPositionDependent() && !ok {
if updatedLexerActions == nil {
updatedLexerActions = make([]LexerAction, 0)
for _, a := range l.lexerActions {
updatedLexerActions = append(updatedLexerActions, a)
}
updatedLexerActions = make([]LexerAction, 0, len(l.lexerActions))
updatedLexerActions = append(updatedLexerActions, l.lexerActions...)
}
updatedLexerActions[i] = NewLexerIndexedCustomAction(offset, l.lexerActions[i])
}
}

View File

@@ -10,10 +10,8 @@ import (
"strings"
)
//goland:noinspection GoUnusedGlobalVariable
var (
LexerATNSimulatorDebug = false
LexerATNSimulatorDFADebug = false
LexerATNSimulatorMinDFAEdge = 0
LexerATNSimulatorMaxDFAEdge = 127 // forces unicode to stay in ATN
@@ -32,11 +30,11 @@ type ILexerATNSimulator interface {
}
type LexerATNSimulator struct {
*BaseATNSimulator
BaseATNSimulator
recog Lexer
predictionMode int
mergeCache DoubleDict
mergeCache *JPCMap2
startIndex int
Line int
CharPositionInLine int
@@ -46,27 +44,35 @@ type LexerATNSimulator struct {
}
func NewLexerATNSimulator(recog Lexer, atn *ATN, decisionToDFA []*DFA, sharedContextCache *PredictionContextCache) *LexerATNSimulator {
l := new(LexerATNSimulator)
l.BaseATNSimulator = NewBaseATNSimulator(atn, sharedContextCache)
l := &LexerATNSimulator{
BaseATNSimulator: BaseATNSimulator{
atn: atn,
sharedContextCache: sharedContextCache,
},
}
l.decisionToDFA = decisionToDFA
l.recog = recog
// The current token's starting index into the character stream.
// Shared across DFA to ATN simulation in case the ATN fails and the
// DFA did not have a previous accept state. In l case, we use the
// ATN-generated exception object.
l.startIndex = -1
// line number 1..n within the input///
// line number 1..n within the input
l.Line = 1
// The index of the character relative to the beginning of the line
// 0..n-1///
// 0..n-1
l.CharPositionInLine = 0
l.mode = LexerDefaultMode
// Used during DFA/ATN exec to record the most recent accept configuration
// info
l.prevAccept = NewSimState()
// done
return l
}
@@ -114,7 +120,7 @@ func (l *LexerATNSimulator) reset() {
func (l *LexerATNSimulator) MatchATN(input CharStream) int {
startState := l.atn.modeToStartState[l.mode]
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("MatchATN mode " + strconv.Itoa(l.mode) + " start: " + startState.String())
}
oldMode := l.mode
@@ -126,7 +132,7 @@ func (l *LexerATNSimulator) MatchATN(input CharStream) int {
predict := l.execATN(input, next)
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("DFA after MatchATN: " + l.decisionToDFA[oldMode].ToLexerString())
}
return predict
@@ -134,18 +140,18 @@ func (l *LexerATNSimulator) MatchATN(input CharStream) int {
func (l *LexerATNSimulator) execATN(input CharStream, ds0 *DFAState) int {
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("start state closure=" + ds0.configs.String())
}
if ds0.isAcceptState {
// allow zero-length tokens
// allow zero-Length tokens
l.captureSimState(l.prevAccept, input, ds0)
}
t := input.LA(1)
s := ds0 // s is current/from DFA state
for { // while more work
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("execATN loop starting closure: " + s.configs.String())
}
@@ -188,7 +194,7 @@ func (l *LexerATNSimulator) execATN(input CharStream, ds0 *DFAState) int {
}
}
t = input.LA(1)
s = target // flip current DFA target becomes Newsrc/from state
s = target // flip current DFA target becomes new src/from state
}
return l.failOrAccept(l.prevAccept, input, s.configs, t)
@@ -214,43 +220,39 @@ func (l *LexerATNSimulator) getExistingTargetState(s *DFAState, t int) *DFAState
return nil
}
target := s.getIthEdge(t - LexerATNSimulatorMinDFAEdge)
if LexerATNSimulatorDebug && target != nil {
if runtimeConfig.lexerATNSimulatorDebug && target != nil {
fmt.Println("reuse state " + strconv.Itoa(s.stateNumber) + " edge to " + strconv.Itoa(target.stateNumber))
}
return target
}
// Compute a target state for an edge in the DFA, and attempt to add the
// computed state and corresponding edge to the DFA.
// computeTargetState computes a target state for an edge in the [DFA], and attempt to add the
// computed state and corresponding edge to the [DFA].
//
// @param input The input stream
// @param s The current DFA state
// @param t The next input symbol
//
// @return The computed target DFA state for the given input symbol
// {@code t}. If {@code t} does not lead to a valid DFA state, l method
// returns {@link //ERROR}.
// The func returns the computed target [DFA] state for the given input symbol t.
// If this does not lead to a valid [DFA] state, this method
// returns ATNSimulatorError.
func (l *LexerATNSimulator) computeTargetState(input CharStream, s *DFAState, t int) *DFAState {
reach := NewOrderedATNConfigSet()
// if we don't find an existing DFA state
// Fill reach starting from closure, following t transitions
l.getReachableConfigSet(input, s.configs, reach.BaseATNConfigSet, t)
l.getReachableConfigSet(input, s.configs, reach, t)
if len(reach.configs) == 0 { // we got nowhere on t from s
if !reach.hasSemanticContext {
// we got nowhere on t, don't panic out l knowledge it'd
// cause a failover from DFA later.
// cause a fail-over from DFA later.
l.addDFAEdge(s, t, ATNSimulatorError, nil)
}
// stop when we can't Match any more char
return ATNSimulatorError
}
// Add an edge from s to target DFA found/created for reach
return l.addDFAEdge(s, t, nil, reach.BaseATNConfigSet)
return l.addDFAEdge(s, t, nil, reach)
}
func (l *LexerATNSimulator) failOrAccept(prevAccept *SimState, input CharStream, reach ATNConfigSet, t int) int {
func (l *LexerATNSimulator) failOrAccept(prevAccept *SimState, input CharStream, reach *ATNConfigSet, t int) int {
if l.prevAccept.dfaState != nil {
lexerActionExecutor := prevAccept.dfaState.lexerActionExecutor
l.accept(input, lexerActionExecutor, l.startIndex, prevAccept.index, prevAccept.line, prevAccept.column)
@@ -265,34 +267,35 @@ func (l *LexerATNSimulator) failOrAccept(prevAccept *SimState, input CharStream,
panic(NewLexerNoViableAltException(l.recog, input, l.startIndex, reach))
}
// Given a starting configuration set, figure out all ATN configurations
// we can reach upon input {@code t}. Parameter {@code reach} is a return
// parameter.
func (l *LexerATNSimulator) getReachableConfigSet(input CharStream, closure ATNConfigSet, reach ATNConfigSet, t int) {
// getReachableConfigSet when given a starting configuration set, figures out all [ATN] configurations
// we can reach upon input t.
//
// Parameter reach is a return parameter.
func (l *LexerATNSimulator) getReachableConfigSet(input CharStream, closure *ATNConfigSet, reach *ATNConfigSet, t int) {
// l is used to Skip processing for configs which have a lower priority
// than a config that already reached an accept state for the same rule
// than a runtimeConfig that already reached an accept state for the same rule
SkipAlt := ATNInvalidAltNumber
for _, cfg := range closure.GetItems() {
currentAltReachedAcceptState := (cfg.GetAlt() == SkipAlt)
if currentAltReachedAcceptState && cfg.(*LexerATNConfig).passedThroughNonGreedyDecision {
for _, cfg := range closure.configs {
currentAltReachedAcceptState := cfg.GetAlt() == SkipAlt
if currentAltReachedAcceptState && cfg.passedThroughNonGreedyDecision {
continue
}
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Printf("testing %s at %s\n", l.GetTokenName(t), cfg.String()) // l.recog, true))
fmt.Printf("testing %s at %s\n", l.GetTokenName(t), cfg.String())
}
for _, trans := range cfg.GetState().GetTransitions() {
target := l.getReachableTarget(trans, t)
if target != nil {
lexerActionExecutor := cfg.(*LexerATNConfig).lexerActionExecutor
lexerActionExecutor := cfg.lexerActionExecutor
if lexerActionExecutor != nil {
lexerActionExecutor = lexerActionExecutor.fixOffsetBeforeMatch(input.Index() - l.startIndex)
}
treatEOFAsEpsilon := (t == TokenEOF)
config := NewLexerATNConfig3(cfg.(*LexerATNConfig), target, lexerActionExecutor)
treatEOFAsEpsilon := t == TokenEOF
config := NewLexerATNConfig3(cfg, target, lexerActionExecutor)
if l.closure(input, config, reach,
currentAltReachedAcceptState, true, treatEOFAsEpsilon) {
// any remaining configs for l alt have a lower priority
@@ -305,7 +308,7 @@ func (l *LexerATNSimulator) getReachableConfigSet(input CharStream, closure ATNC
}
func (l *LexerATNSimulator) accept(input CharStream, lexerActionExecutor *LexerActionExecutor, startIndex, index, line, charPos int) {
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Printf("ACTION %v\n", lexerActionExecutor)
}
// seek to after last char in token
@@ -325,7 +328,7 @@ func (l *LexerATNSimulator) getReachableTarget(trans Transition, t int) ATNState
return nil
}
func (l *LexerATNSimulator) computeStartState(input CharStream, p ATNState) *OrderedATNConfigSet {
func (l *LexerATNSimulator) computeStartState(input CharStream, p ATNState) *ATNConfigSet {
configs := NewOrderedATNConfigSet()
for i := 0; i < len(p.GetTransitions()); i++ {
target := p.GetTransitions()[i].getTarget()
@@ -336,25 +339,24 @@ func (l *LexerATNSimulator) computeStartState(input CharStream, p ATNState) *Ord
return configs
}
// Since the alternatives within any lexer decision are ordered by
// preference, l method stops pursuing the closure as soon as an accept
// closure since the alternatives within any lexer decision are ordered by
// preference, this method stops pursuing the closure as soon as an accept
// state is reached. After the first accept state is reached by depth-first
// search from {@code config}, all other (potentially reachable) states for
// l rule would have a lower priority.
// search from runtimeConfig, all other (potentially reachable) states for
// this rule would have a lower priority.
//
// @return {@code true} if an accept state is reached, otherwise
// {@code false}.
func (l *LexerATNSimulator) closure(input CharStream, config *LexerATNConfig, configs ATNConfigSet,
// The func returns true if an accept state is reached.
func (l *LexerATNSimulator) closure(input CharStream, config *ATNConfig, configs *ATNConfigSet,
currentAltReachedAcceptState, speculative, treatEOFAsEpsilon bool) bool {
if LexerATNSimulatorDebug {
fmt.Println("closure(" + config.String() + ")") // config.String(l.recog, true) + ")")
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("closure(" + config.String() + ")")
}
_, ok := config.state.(*RuleStopState)
if ok {
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
if l.recog != nil {
fmt.Printf("closure at %s rule stop %s\n", l.recog.GetRuleNames()[config.state.GetRuleIndex()], config)
} else {
@@ -401,10 +403,10 @@ func (l *LexerATNSimulator) closure(input CharStream, config *LexerATNConfig, co
}
// side-effect: can alter configs.hasSemanticContext
func (l *LexerATNSimulator) getEpsilonTarget(input CharStream, config *LexerATNConfig, trans Transition,
configs ATNConfigSet, speculative, treatEOFAsEpsilon bool) *LexerATNConfig {
func (l *LexerATNSimulator) getEpsilonTarget(input CharStream, config *ATNConfig, trans Transition,
configs *ATNConfigSet, speculative, treatEOFAsEpsilon bool) *ATNConfig {
var cfg *LexerATNConfig
var cfg *ATNConfig
if trans.getSerializationType() == TransitionRULE {
@@ -435,10 +437,10 @@ func (l *LexerATNSimulator) getEpsilonTarget(input CharStream, config *LexerATNC
pt := trans.(*PredicateTransition)
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("EVAL rule " + strconv.Itoa(trans.(*PredicateTransition).ruleIndex) + ":" + strconv.Itoa(pt.predIndex))
}
configs.SetHasSemanticContext(true)
configs.hasSemanticContext = true
if l.evaluatePredicate(input, pt.ruleIndex, pt.predIndex, speculative) {
cfg = NewLexerATNConfig4(config, trans.getTarget())
}
@@ -449,7 +451,7 @@ func (l *LexerATNSimulator) getEpsilonTarget(input CharStream, config *LexerATNC
// TODO: if the entry rule is invoked recursively, some
// actions may be executed during the recursive call. The
// problem can appear when hasEmptyPath() is true but
// isEmpty() is false. In l case, the config needs to be
// isEmpty() is false. In this case, the config needs to be
// split into two contexts - one with just the empty path
// and another with everything but the empty path.
// Unfortunately, the current algorithm does not allow
@@ -476,26 +478,18 @@ func (l *LexerATNSimulator) getEpsilonTarget(input CharStream, config *LexerATNC
return cfg
}
// Evaluate a predicate specified in the lexer.
// evaluatePredicate eEvaluates a predicate specified in the lexer.
//
// <p>If {@code speculative} is {@code true}, l method was called before
// {@link //consume} for the Matched character. This method should call
// {@link //consume} before evaluating the predicate to ensure position
// sensitive values, including {@link Lexer//GetText}, {@link Lexer//GetLine},
// and {@link Lexer//getcolumn}, properly reflect the current
// lexer state. This method should restore {@code input} and the simulator
// to the original state before returning (i.e. undo the actions made by the
// call to {@link //consume}.</p>
// If speculative is true, this method was called before
// [consume] for the Matched character. This method should call
// [consume] before evaluating the predicate to ensure position
// sensitive values, including [GetText], [GetLine],
// and [GetColumn], properly reflect the current
// lexer state. This method should restore input and the simulator
// to the original state before returning, i.e. undo the actions made by the
// call to [Consume].
//
// @param input The input stream.
// @param ruleIndex The rule containing the predicate.
// @param predIndex The index of the predicate within the rule.
// @param speculative {@code true} if the current index in {@code input} is
// one character before the predicate's location.
//
// @return {@code true} if the specified predicate evaluates to
// {@code true}.
// /
// The func returns true if the specified predicate evaluates to true.
func (l *LexerATNSimulator) evaluatePredicate(input CharStream, ruleIndex, predIndex int, speculative bool) bool {
// assume true if no recognizer was provided
if l.recog == nil {
@@ -527,7 +521,7 @@ func (l *LexerATNSimulator) captureSimState(settings *SimState, input CharStream
settings.dfaState = dfaState
}
func (l *LexerATNSimulator) addDFAEdge(from *DFAState, tk int, to *DFAState, cfgs ATNConfigSet) *DFAState {
func (l *LexerATNSimulator) addDFAEdge(from *DFAState, tk int, to *DFAState, cfgs *ATNConfigSet) *DFAState {
if to == nil && cfgs != nil {
// leading to l call, ATNConfigSet.hasSemanticContext is used as a
// marker indicating dynamic predicate evaluation makes l edge
@@ -539,10 +533,9 @@ func (l *LexerATNSimulator) addDFAEdge(from *DFAState, tk int, to *DFAState, cfg
// TJP notes: next time through the DFA, we see a pred again and eval.
// If that gets us to a previously created (but dangling) DFA
// state, we can continue in pure DFA mode from there.
// /
suppressEdge := cfgs.HasSemanticContext()
cfgs.SetHasSemanticContext(false)
//
suppressEdge := cfgs.hasSemanticContext
cfgs.hasSemanticContext = false
to = l.addDFAState(cfgs, true)
if suppressEdge {
@@ -554,7 +547,7 @@ func (l *LexerATNSimulator) addDFAEdge(from *DFAState, tk int, to *DFAState, cfg
// Only track edges within the DFA bounds
return to
}
if LexerATNSimulatorDebug {
if runtimeConfig.lexerATNSimulatorDebug {
fmt.Println("EDGE " + from.String() + " -> " + to.String() + " upon " + strconv.Itoa(tk))
}
l.atn.edgeMu.Lock()
@@ -572,13 +565,12 @@ func (l *LexerATNSimulator) addDFAEdge(from *DFAState, tk int, to *DFAState, cfg
// configurations already. This method also detects the first
// configuration containing an ATN rule stop state. Later, when
// traversing the DFA, we will know which rule to accept.
func (l *LexerATNSimulator) addDFAState(configs ATNConfigSet, suppressEdge bool) *DFAState {
func (l *LexerATNSimulator) addDFAState(configs *ATNConfigSet, suppressEdge bool) *DFAState {
proposed := NewDFAState(-1, configs)
var firstConfigWithRuleStopState ATNConfig
for _, cfg := range configs.GetItems() {
var firstConfigWithRuleStopState *ATNConfig
for _, cfg := range configs.configs {
_, ok := cfg.GetState().(*RuleStopState)
if ok {
@@ -588,14 +580,14 @@ func (l *LexerATNSimulator) addDFAState(configs ATNConfigSet, suppressEdge bool)
}
if firstConfigWithRuleStopState != nil {
proposed.isAcceptState = true
proposed.lexerActionExecutor = firstConfigWithRuleStopState.(*LexerATNConfig).lexerActionExecutor
proposed.lexerActionExecutor = firstConfigWithRuleStopState.lexerActionExecutor
proposed.setPrediction(l.atn.ruleToTokenType[firstConfigWithRuleStopState.GetState().GetRuleIndex()])
}
dfa := l.decisionToDFA[l.mode]
l.atn.stateMu.Lock()
defer l.atn.stateMu.Unlock()
existing, present := dfa.states.Get(proposed)
existing, present := dfa.Get(proposed)
if present {
// This state was already present, so just return it.
@@ -605,10 +597,11 @@ func (l *LexerATNSimulator) addDFAState(configs ATNConfigSet, suppressEdge bool)
// We need to add the new state
//
proposed.stateNumber = dfa.states.Len()
configs.SetReadOnly(true)
proposed.stateNumber = dfa.Len()
configs.readOnly = true
configs.configLookup = nil // Not needed now
proposed.configs = configs
dfa.states.Put(proposed)
dfa.Put(proposed)
}
if !suppressEdge {
dfa.setS0(proposed)
@@ -620,7 +613,7 @@ func (l *LexerATNSimulator) getDFA(mode int) *DFA {
return l.decisionToDFA[mode]
}
// Get the text Matched so far for the current token.
// GetText returns the text [Match]ed so far for the current token.
func (l *LexerATNSimulator) GetText(input CharStream) string {
// index is first lookahead char, don't include.
return input.GetTextFromInterval(NewInterval(l.startIndex, input.Index()-1))

View File

@@ -14,11 +14,11 @@ func NewLL1Analyzer(atn *ATN) *LL1Analyzer {
return la
}
// - Special value added to the lookahead sets to indicate that we hit
// a predicate during analysis if {@code seeThruPreds==false}.
//
// /
const (
// LL1AnalyzerHitPred is a special value added to the lookahead sets to indicate that we hit
// a predicate during analysis if
//
// seeThruPreds==false
LL1AnalyzerHitPred = TokenInvalidType
)
@@ -38,11 +38,12 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet {
count := len(s.GetTransitions())
look := make([]*IntervalSet, count)
for alt := 0; alt < count; alt++ {
look[alt] = NewIntervalSet()
lookBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst)
seeThruPreds := false // fail to get lookahead upon pred
la.look1(s.GetTransitions()[alt].getTarget(), nil, BasePredictionContextEMPTY, look[alt], lookBusy, NewBitSet(), seeThruPreds, false)
// Wipe out lookahead for la alternative if we found nothing
lookBusy := NewJStore[*ATNConfig, Comparator[*ATNConfig]](aConfEqInst, ClosureBusyCollection, "LL1Analyzer.getDecisionLookahead for lookBusy")
la.look1(s.GetTransitions()[alt].getTarget(), nil, BasePredictionContextEMPTY, look[alt], lookBusy, NewBitSet(), false, false)
// Wipe out lookahead for la alternative if we found nothing,
// or we had a predicate when we !seeThruPreds
if look[alt].length() == 0 || look[alt].contains(LL1AnalyzerHitPred) {
look[alt] = nil
@@ -51,32 +52,31 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet {
return look
}
// *
// Compute set of tokens that can follow {@code s} in the ATN in the
// specified {@code ctx}.
// Look computes the set of tokens that can follow s in the [ATN] in the
// specified ctx.
//
// <p>If {@code ctx} is {@code nil} and the end of the rule containing
// {@code s} is reached, {@link Token//EPSILON} is added to the result set.
// If {@code ctx} is not {@code nil} and the end of the outermost rule is
// reached, {@link Token//EOF} is added to the result set.</p>
// If ctx is nil and the end of the rule containing
// s is reached, [EPSILON] is added to the result set.
//
// @param s the ATN state
// @param stopState the ATN state to stop at. This can be a
// {@link BlockEndState} to detect epsilon paths through a closure.
// @param ctx the complete parser context, or {@code nil} if the context
// If ctx is not nil and the end of the outermost rule is
// reached, [EOF] is added to the result set.
//
// Parameter s the ATN state, and stopState is the ATN state to stop at. This can be a
// [BlockEndState] to detect epsilon paths through a closure.
//
// Parameter ctx is the complete parser context, or nil if the context
// should be ignored
//
// @return The set of tokens that can follow {@code s} in the ATN in the
// specified {@code ctx}.
// /
// The func returns the set of tokens that can follow s in the [ATN] in the
// specified ctx.
func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet {
r := NewIntervalSet()
seeThruPreds := true // ignore preds get all lookahead
var lookContext PredictionContext
var lookContext *PredictionContext
if ctx != nil {
lookContext = predictionContextFromRuleContext(s.GetATN(), ctx)
}
la.look1(s, stopState, lookContext, r, NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst), NewBitSet(), seeThruPreds, true)
la.look1(s, stopState, lookContext, r, NewJStore[*ATNConfig, Comparator[*ATNConfig]](aConfEqInst, ClosureBusyCollection, "LL1Analyzer.Look for la.look1()"),
NewBitSet(), true, true)
return r
}
@@ -110,16 +110,17 @@ func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet
// outermost context is reached. This parameter has no effect if {@code ctx}
// is {@code nil}.
func (la *LL1Analyzer) look2(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool, i int) {
func (la *LL1Analyzer) look2(_, stopState ATNState, ctx *PredictionContext, look *IntervalSet, lookBusy *JStore[*ATNConfig, Comparator[*ATNConfig]],
calledRuleStack *BitSet, seeThruPreds, addEOF bool, i int) {
returnState := la.atn.states[ctx.getReturnState(i)]
la.look1(returnState, stopState, ctx.GetParent(i), look, lookBusy, calledRuleStack, seeThruPreds, addEOF)
}
func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool) {
func (la *LL1Analyzer) look1(s, stopState ATNState, ctx *PredictionContext, look *IntervalSet, lookBusy *JStore[*ATNConfig, Comparator[*ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool) {
c := NewBaseATNConfig6(s, 0, ctx)
c := NewATNConfig6(s, 0, ctx)
if lookBusy.Contains(c) {
return
@@ -151,7 +152,7 @@ func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look
return
}
if ctx != BasePredictionContextEMPTY {
if ctx.pcType != PredictionContextEmpty {
removed := calledRuleStack.contains(s.GetRuleIndex())
defer func() {
if removed {
@@ -202,7 +203,8 @@ func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look
}
}
func (la *LL1Analyzer) look3(stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool, t1 *RuleTransition) {
func (la *LL1Analyzer) look3(stopState ATNState, ctx *PredictionContext, look *IntervalSet, lookBusy *JStore[*ATNConfig, Comparator[*ATNConfig]],
calledRuleStack *BitSet, seeThruPreds, addEOF bool, t1 *RuleTransition) {
newContext := SingletonBasePredictionContextCreate(ctx, t1.followState.GetStateNumber())

47
vendor/github.com/antlr4-go/antlr/v4/nostatistics.go generated vendored Normal file
View File

@@ -0,0 +1,47 @@
//go:build !antlr.stats
package antlr
// This file is compiled when the build configuration antlr.stats is not enabled.
// which then allows the compiler to optimize out all the code that is not used.
const collectStats = false
// goRunStats is a dummy struct used when build configuration antlr.stats is not enabled.
type goRunStats struct {
}
var Statistics = &goRunStats{}
func (s *goRunStats) AddJStatRec(_ *JStatRec) {
// Do nothing - compiler will optimize this out (hopefully)
}
func (s *goRunStats) CollectionAnomalies() {
// Do nothing - compiler will optimize this out (hopefully)
}
func (s *goRunStats) Reset() {
// Do nothing - compiler will optimize this out (hopefully)
}
func (s *goRunStats) Report(dir string, prefix string) error {
// Do nothing - compiler will optimize this out (hopefully)
return nil
}
func (s *goRunStats) Analyze() {
// Do nothing - compiler will optimize this out (hopefully)
}
type statsOption func(*goRunStats) error
func (s *goRunStats) Configure(options ...statsOption) error {
// Do nothing - compiler will optimize this out (hopefully)
return nil
}
func WithTopN(topN int) statsOption {
return func(s *goRunStats) error {
return nil
}
}

View File

@@ -48,8 +48,10 @@ type BaseParser struct {
_SyntaxErrors int
}
// p.is all the parsing support code essentially most of it is error
// recovery stuff.//
// NewBaseParser contains all the parsing support code to embed in parsers. Essentially most of it is error
// recovery stuff.
//
//goland:noinspection GoUnusedExportedFunction
func NewBaseParser(input TokenStream) *BaseParser {
p := new(BaseParser)
@@ -58,39 +60,46 @@ func NewBaseParser(input TokenStream) *BaseParser {
// The input stream.
p.input = nil
// The error handling strategy for the parser. The default value is a new
// instance of {@link DefaultErrorStrategy}.
p.errHandler = NewDefaultErrorStrategy()
p.precedenceStack = make([]int, 0)
p.precedenceStack.Push(0)
// The {@link ParserRuleContext} object for the currently executing rule.
// The ParserRuleContext object for the currently executing rule.
// p.is always non-nil during the parsing process.
p.ctx = nil
// Specifies whether or not the parser should construct a parse tree during
// Specifies whether the parser should construct a parse tree during
// the parsing process. The default value is {@code true}.
p.BuildParseTrees = true
// When {@link //setTrace}{@code (true)} is called, a reference to the
// {@link TraceListener} is stored here so it can be easily removed in a
// later call to {@link //setTrace}{@code (false)}. The listener itself is
// When setTrace(true) is called, a reference to the
// TraceListener is stored here, so it can be easily removed in a
// later call to setTrace(false). The listener itself is
// implemented as a parser listener so p.field is not directly used by
// other parser methods.
p.tracer = nil
// The list of {@link ParseTreeListener} listeners registered to receive
// The list of ParseTreeListener listeners registered to receive
// events during the parse.
p.parseListeners = nil
// The number of syntax errors Reported during parsing. p.value is
// incremented each time {@link //NotifyErrorListeners} is called.
// incremented each time NotifyErrorListeners is called.
p._SyntaxErrors = 0
p.SetInputStream(input)
return p
}
// p.field maps from the serialized ATN string to the deserialized {@link
// ATN} with
// This field maps from the serialized ATN string to the deserialized [ATN] with
// bypass alternatives.
//
// @see ATNDeserializationOptions//isGenerateRuleBypassTransitions()
// [ATNDeserializationOptions.isGenerateRuleBypassTransitions]
//
//goland:noinspection GoUnusedGlobalVariable
var bypassAltsAtnCache = make(map[string]int)
// reset the parser's state//
@@ -143,10 +152,13 @@ func (p *BaseParser) Match(ttype int) Token {
p.Consume()
} else {
t = p.errHandler.RecoverInline(p)
if p.HasError() {
return nil
}
if p.BuildParseTrees && t.GetTokenIndex() == -1 {
// we must have conjured up a Newtoken during single token
// insertion
// if it's not the current symbol
// we must have conjured up a new token during single token
// insertion if it's not the current symbol
p.ctx.AddErrorNode(t)
}
}
@@ -178,9 +190,8 @@ func (p *BaseParser) MatchWildcard() Token {
} else {
t = p.errHandler.RecoverInline(p)
if p.BuildParseTrees && t.GetTokenIndex() == -1 {
// we must have conjured up a Newtoken during single token
// insertion
// if it's not the current symbol
// we must have conjured up a new token during single token
// insertion if it's not the current symbol
p.ctx.AddErrorNode(t)
}
}
@@ -202,33 +213,27 @@ func (p *BaseParser) GetParseListeners() []ParseTreeListener {
return p.parseListeners
}
// Registers {@code listener} to receive events during the parsing process.
// AddParseListener registers listener to receive events during the parsing process.
//
// <p>To support output-preserving grammar transformations (including but not
// To support output-preserving grammar transformations (including but not
// limited to left-recursion removal, automated left-factoring, and
// optimized code generation), calls to listener methods during the parse
// may differ substantially from calls made by
// {@link ParseTreeWalker//DEFAULT} used after the parse is complete. In
// [ParseTreeWalker.DEFAULT] used after the parse is complete. In
// particular, rule entry and exit events may occur in a different order
// during the parse than after the parser. In addition, calls to certain
// rule entry methods may be omitted.</p>
// rule entry methods may be omitted.
//
// <p>With the following specific exceptions, calls to listener events are
// <em>deterministic</em>, i.e. for identical input the calls to listener
// methods will be the same.</p>
// With the following specific exceptions, calls to listener events are
// deterministic, i.e. for identical input the calls to listener
// methods will be the same.
//
// <ul>
// <li>Alterations to the grammar used to generate code may change the
// behavior of the listener calls.</li>
// <li>Alterations to the command line options passed to ANTLR 4 when
// generating the parser may change the behavior of the listener calls.</li>
// <li>Changing the version of the ANTLR Tool used to generate the parser
// may change the behavior of the listener calls.</li>
// </ul>
//
// @param listener the listener to add
//
// @panics nilPointerException if {@code} listener is {@code nil}
// - Alterations to the grammar used to generate code may change the
// behavior of the listener calls.
// - Alterations to the command line options passed to ANTLR 4 when
// generating the parser may change the behavior of the listener calls.
// - Changing the version of the ANTLR Tool used to generate the parser
// may change the behavior of the listener calls.
func (p *BaseParser) AddParseListener(listener ParseTreeListener) {
if listener == nil {
panic("listener")
@@ -239,11 +244,10 @@ func (p *BaseParser) AddParseListener(listener ParseTreeListener) {
p.parseListeners = append(p.parseListeners, listener)
}
// Remove {@code listener} from the list of parse listeners.
// RemoveParseListener removes listener from the list of parse listeners.
//
// <p>If {@code listener} is {@code nil} or has not been added as a parse
// listener, p.method does nothing.</p>
// @param listener the listener to remove
// If listener is nil or has not been added as a parse
// listener, this func does nothing.
func (p *BaseParser) RemoveParseListener(listener ParseTreeListener) {
if p.parseListeners != nil {
@@ -274,7 +278,7 @@ func (p *BaseParser) removeParseListeners() {
p.parseListeners = nil
}
// Notify any parse listeners of an enter rule event.
// TriggerEnterRuleEvent notifies all parse listeners of an enter rule event.
func (p *BaseParser) TriggerEnterRuleEvent() {
if p.parseListeners != nil {
ctx := p.ctx
@@ -285,9 +289,7 @@ func (p *BaseParser) TriggerEnterRuleEvent() {
}
}
// Notify any parse listeners of an exit rule event.
//
// @see //addParseListener
// TriggerExitRuleEvent notifies any parse listeners of an exit rule event.
func (p *BaseParser) TriggerExitRuleEvent() {
if p.parseListeners != nil {
// reverse order walk of listeners
@@ -314,19 +316,16 @@ func (p *BaseParser) GetTokenFactory() TokenFactory {
return p.input.GetTokenSource().GetTokenFactory()
}
// Tell our token source and error strategy about a Newway to create tokens.//
// setTokenFactory is used to tell our token source and error strategy about a new way to create tokens.
func (p *BaseParser) setTokenFactory(factory TokenFactory) {
p.input.GetTokenSource().setTokenFactory(factory)
}
// The ATN with bypass alternatives is expensive to create so we create it
// GetATNWithBypassAlts - the ATN with bypass alternatives is expensive to create, so we create it
// lazily.
//
// @panics UnsupportedOperationException if the current parser does not
// implement the {@link //getSerializedATN()} method.
func (p *BaseParser) GetATNWithBypassAlts() {
// TODO
// TODO - Implement this?
panic("Not implemented!")
// serializedAtn := p.getSerializedATN()
@@ -354,6 +353,7 @@ func (p *BaseParser) GetATNWithBypassAlts() {
// String id = m.Get("ID")
// </pre>
//goland:noinspection GoUnusedParameter
func (p *BaseParser) compileParseTreePattern(pattern, patternRuleIndex, lexer Lexer) {
panic("NewParseTreePatternMatcher not implemented!")
@@ -386,14 +386,16 @@ func (p *BaseParser) GetTokenStream() TokenStream {
return p.input
}
// Set the token stream and reset the parser.//
// SetTokenStream installs input as the token stream and resets the parser.
func (p *BaseParser) SetTokenStream(input TokenStream) {
p.input = nil
p.reset()
p.input = input
}
// Match needs to return the current input symbol, which gets put
// GetCurrentToken returns the current token at LT(1).
//
// [Match] needs to return the current input symbol, which gets put
// into the label for the associated token ref e.g., x=ID.
func (p *BaseParser) GetCurrentToken() Token {
return p.input.LT(1)
@@ -446,7 +448,7 @@ func (p *BaseParser) addContextToParseTree() {
}
}
func (p *BaseParser) EnterRule(localctx ParserRuleContext, state, ruleIndex int) {
func (p *BaseParser) EnterRule(localctx ParserRuleContext, state, _ int) {
p.SetState(state)
p.ctx = localctx
p.ctx.SetStart(p.input.LT(1))
@@ -474,7 +476,7 @@ func (p *BaseParser) ExitRule() {
func (p *BaseParser) EnterOuterAlt(localctx ParserRuleContext, altNum int) {
localctx.SetAltNumber(altNum)
// if we have Newlocalctx, make sure we replace existing ctx
// if we have a new localctx, make sure we replace existing ctx
// that is previous child of parse tree
if p.BuildParseTrees && p.ctx != localctx {
if p.ctx.GetParent() != nil {
@@ -498,7 +500,7 @@ func (p *BaseParser) GetPrecedence() int {
return p.precedenceStack[len(p.precedenceStack)-1]
}
func (p *BaseParser) EnterRecursionRule(localctx ParserRuleContext, state, ruleIndex, precedence int) {
func (p *BaseParser) EnterRecursionRule(localctx ParserRuleContext, state, _, precedence int) {
p.SetState(state)
p.precedenceStack.Push(precedence)
p.ctx = localctx
@@ -512,7 +514,7 @@ func (p *BaseParser) EnterRecursionRule(localctx ParserRuleContext, state, ruleI
//
// Like {@link //EnterRule} but for recursive rules.
func (p *BaseParser) PushNewRecursionContext(localctx ParserRuleContext, state, ruleIndex int) {
func (p *BaseParser) PushNewRecursionContext(localctx ParserRuleContext, state, _ int) {
previous := p.ctx
previous.SetParent(localctx)
previous.SetInvokingState(state)
@@ -530,7 +532,7 @@ func (p *BaseParser) PushNewRecursionContext(localctx ParserRuleContext, state,
}
func (p *BaseParser) UnrollRecursionContexts(parentCtx ParserRuleContext) {
p.precedenceStack.Pop()
_, _ = p.precedenceStack.Pop()
p.ctx.SetStop(p.input.LT(-1))
retCtx := p.ctx // save current ctx (return value)
// unroll so ctx is as it was before call to recursive method
@@ -561,29 +563,22 @@ func (p *BaseParser) GetInvokingContext(ruleIndex int) ParserRuleContext {
return nil
}
func (p *BaseParser) Precpred(localctx RuleContext, precedence int) bool {
func (p *BaseParser) Precpred(_ RuleContext, precedence int) bool {
return precedence >= p.precedenceStack[len(p.precedenceStack)-1]
}
//goland:noinspection GoUnusedParameter
func (p *BaseParser) inContext(context ParserRuleContext) bool {
// TODO: useful in parser?
return false
}
//
// Checks whether or not {@code symbol} can follow the current state in the
// ATN. The behavior of p.method is equivalent to the following, but is
// IsExpectedToken checks whether symbol can follow the current state in the
// {ATN}. The behavior of p.method is equivalent to the following, but is
// implemented such that the complete context-sensitive follow set does not
// need to be explicitly constructed.
//
// <pre>
// return getExpectedTokens().contains(symbol)
// </pre>
//
// @param symbol the symbol type to check
// @return {@code true} if {@code symbol} can follow the current state in
// the ATN, otherwise {@code false}.
// return getExpectedTokens().contains(symbol)
func (p *BaseParser) IsExpectedToken(symbol int) bool {
atn := p.Interpreter.atn
ctx := p.ctx
@@ -611,11 +606,9 @@ func (p *BaseParser) IsExpectedToken(symbol int) bool {
return false
}
// Computes the set of input symbols which could follow the current parser
// state and context, as given by {@link //GetState} and {@link //GetContext},
// GetExpectedTokens and returns the set of input symbols which could follow the current parser
// state and context, as given by [GetState] and [GetContext],
// respectively.
//
// @see ATN//getExpectedTokens(int, RuleContext)
func (p *BaseParser) GetExpectedTokens() *IntervalSet {
return p.Interpreter.atn.getExpectedTokens(p.state, p.ctx)
}
@@ -626,7 +619,7 @@ func (p *BaseParser) GetExpectedTokensWithinCurrentRule() *IntervalSet {
return atn.NextTokens(s, nil)
}
// Get a rule's index (i.e., {@code RULE_ruleName} field) or -1 if not found.//
// GetRuleIndex get a rule's index (i.e., RULE_ruleName field) or -1 if not found.
func (p *BaseParser) GetRuleIndex(ruleName string) int {
var ruleIndex, ok = p.GetRuleIndexMap()[ruleName]
if ok {
@@ -636,13 +629,10 @@ func (p *BaseParser) GetRuleIndex(ruleName string) int {
return -1
}
// Return List&ltString&gt of the rule names in your parser instance
// GetRuleInvocationStack returns a list of the rule names in your parser instance
// leading up to a call to the current rule. You could override if
// you want more details such as the file/line info of where
// in the ATN a rule is invoked.
//
// this very useful for error messages.
func (p *BaseParser) GetRuleInvocationStack(c ParserRuleContext) []string {
if c == nil {
c = p.ctx
@@ -668,16 +658,16 @@ func (p *BaseParser) GetRuleInvocationStack(c ParserRuleContext) []string {
return stack
}
// For debugging and other purposes.//
// GetDFAStrings returns a list of all DFA states used for debugging purposes
func (p *BaseParser) GetDFAStrings() string {
return fmt.Sprint(p.Interpreter.decisionToDFA)
}
// For debugging and other purposes.//
// DumpDFA prints the whole of the DFA for debugging
func (p *BaseParser) DumpDFA() {
seenOne := false
for _, dfa := range p.Interpreter.decisionToDFA {
if dfa.states.Len() > 0 {
if dfa.Len() > 0 {
if seenOne {
fmt.Println()
}
@@ -692,8 +682,10 @@ func (p *BaseParser) GetSourceName() string {
return p.GrammarFileName
}
// During a parse is sometimes useful to listen in on the rule entry and exit
// events as well as token Matches. p.is for quick and dirty debugging.
// SetTrace installs a trace listener for the parse.
//
// During a parse it is sometimes useful to listen in on the rule entry and exit
// events as well as token Matches. This is for quick and dirty debugging.
func (p *BaseParser) SetTrace(trace *TraceListener) {
if trace == nil {
p.RemoveParseListener(p.tracer)

View File

@@ -31,7 +31,9 @@ type ParserRuleContext interface {
}
type BaseParserRuleContext struct {
*BaseRuleContext
parentCtx RuleContext
invokingState int
RuleIndex int
start, stop Token
exception RecognitionException
@@ -40,8 +42,22 @@ type BaseParserRuleContext struct {
func NewBaseParserRuleContext(parent ParserRuleContext, invokingStateNumber int) *BaseParserRuleContext {
prc := new(BaseParserRuleContext)
InitBaseParserRuleContext(prc, parent, invokingStateNumber)
return prc
}
prc.BaseRuleContext = NewBaseRuleContext(parent, invokingStateNumber)
func InitBaseParserRuleContext(prc *BaseParserRuleContext, parent ParserRuleContext, invokingStateNumber int) {
// What context invoked b rule?
prc.parentCtx = parent
// What state invoked the rule associated with b context?
// The "return address" is the followState of invokingState
// If parent is nil, b should be -1.
if parent == nil {
prc.invokingState = -1
} else {
prc.invokingState = invokingStateNumber
}
prc.RuleIndex = -1
// * If we are debugging or building a parse tree for a Visitor,
@@ -56,8 +72,6 @@ func NewBaseParserRuleContext(parent ParserRuleContext, invokingStateNumber int)
// The exception that forced prc rule to return. If the rule successfully
// completed, prc is {@code nil}.
prc.exception = nil
return prc
}
func (prc *BaseParserRuleContext) SetException(e RecognitionException) {
@@ -90,14 +104,15 @@ func (prc *BaseParserRuleContext) GetText() string {
return s
}
// Double dispatch methods for listeners
func (prc *BaseParserRuleContext) EnterRule(listener ParseTreeListener) {
// EnterRule is called when any rule is entered.
func (prc *BaseParserRuleContext) EnterRule(_ ParseTreeListener) {
}
func (prc *BaseParserRuleContext) ExitRule(listener ParseTreeListener) {
// ExitRule is called when any rule is exited.
func (prc *BaseParserRuleContext) ExitRule(_ ParseTreeListener) {
}
// * Does not set parent link other add methods do that///
// * Does not set parent link other add methods do that
func (prc *BaseParserRuleContext) addTerminalNodeChild(child TerminalNode) TerminalNode {
if prc.children == nil {
prc.children = make([]Tree, 0)
@@ -120,10 +135,9 @@ func (prc *BaseParserRuleContext) AddChild(child RuleContext) RuleContext {
return child
}
// * Used by EnterOuterAlt to toss out a RuleContext previously added as
// we entered a rule. If we have // label, we will need to remove
// generic ruleContext object.
// /
// RemoveLastChild is used by [EnterOuterAlt] to toss out a [RuleContext] previously added as
// we entered a rule. If we have a label, we will need to remove
// the generic ruleContext object.
func (prc *BaseParserRuleContext) RemoveLastChild() {
if prc.children != nil && len(prc.children) > 0 {
prc.children = prc.children[0 : len(prc.children)-1]
@@ -293,7 +307,7 @@ func (prc *BaseParserRuleContext) GetChildCount() int {
return len(prc.children)
}
func (prc *BaseParserRuleContext) GetSourceInterval() *Interval {
func (prc *BaseParserRuleContext) GetSourceInterval() Interval {
if prc.start == nil || prc.stop == nil {
return TreeInvalidInterval
}
@@ -340,6 +354,50 @@ func (prc *BaseParserRuleContext) String(ruleNames []string, stop RuleContext) s
return s
}
func (prc *BaseParserRuleContext) SetParent(v Tree) {
if v == nil {
prc.parentCtx = nil
} else {
prc.parentCtx = v.(RuleContext)
}
}
func (prc *BaseParserRuleContext) GetInvokingState() int {
return prc.invokingState
}
func (prc *BaseParserRuleContext) SetInvokingState(t int) {
prc.invokingState = t
}
func (prc *BaseParserRuleContext) GetRuleIndex() int {
return prc.RuleIndex
}
func (prc *BaseParserRuleContext) GetAltNumber() int {
return ATNInvalidAltNumber
}
func (prc *BaseParserRuleContext) SetAltNumber(_ int) {}
// IsEmpty returns true if the context of b is empty.
//
// A context is empty if there is no invoking state, meaning nobody calls
// current context.
func (prc *BaseParserRuleContext) IsEmpty() bool {
return prc.invokingState == -1
}
// GetParent returns the combined text of all child nodes. This method only considers
// tokens which have been added to the parse tree.
//
// Since tokens on hidden channels (e.g. whitespace or comments) are not
// added to the parse trees, they will not appear in the output of this
// method.
func (prc *BaseParserRuleContext) GetParent() Tree {
return prc.parentCtx
}
var ParserRuleContextEmpty = NewBaseParserRuleContext(nil, -1)
type InterpreterRuleContext interface {
@@ -350,6 +408,7 @@ type BaseInterpreterRuleContext struct {
*BaseParserRuleContext
}
//goland:noinspection GoUnusedExportedFunction
func NewBaseInterpreterRuleContext(parent BaseInterpreterRuleContext, invokingStateNumber, ruleIndex int) *BaseInterpreterRuleContext {
prc := new(BaseInterpreterRuleContext)

View File

@@ -0,0 +1,727 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"fmt"
"golang.org/x/exp/slices"
"strconv"
)
var _emptyPredictionContextHash int
func init() {
_emptyPredictionContextHash = murmurInit(1)
_emptyPredictionContextHash = murmurFinish(_emptyPredictionContextHash, 0)
}
func calculateEmptyHash() int {
return _emptyPredictionContextHash
}
const (
// BasePredictionContextEmptyReturnState represents {@code $} in an array in full context mode, $
// doesn't mean wildcard:
//
// $ + x = [$,x]
//
// Here,
//
// $ = EmptyReturnState
BasePredictionContextEmptyReturnState = 0x7FFFFFFF
)
// TODO: JI These are meant to be atomics - this does not seem to match the Java runtime here
//
//goland:noinspection GoUnusedGlobalVariable
var (
BasePredictionContextglobalNodeCount = 1
BasePredictionContextid = BasePredictionContextglobalNodeCount
)
const (
PredictionContextEmpty = iota
PredictionContextSingleton
PredictionContextArray
)
// PredictionContext is a go idiomatic implementation of PredictionContext that does not rty to
// emulate inheritance from Java, and can be used without an interface definition. An interface
// is not required because no user code will ever need to implement this interface.
type PredictionContext struct {
cachedHash int
pcType int
parentCtx *PredictionContext
returnState int
parents []*PredictionContext
returnStates []int
}
func NewEmptyPredictionContext() *PredictionContext {
nep := &PredictionContext{}
nep.cachedHash = calculateEmptyHash()
nep.pcType = PredictionContextEmpty
nep.returnState = BasePredictionContextEmptyReturnState
return nep
}
func NewBaseSingletonPredictionContext(parent *PredictionContext, returnState int) *PredictionContext {
pc := &PredictionContext{}
pc.pcType = PredictionContextSingleton
pc.returnState = returnState
pc.parentCtx = parent
if parent != nil {
pc.cachedHash = calculateHash(parent, returnState)
} else {
pc.cachedHash = calculateEmptyHash()
}
return pc
}
func SingletonBasePredictionContextCreate(parent *PredictionContext, returnState int) *PredictionContext {
if returnState == BasePredictionContextEmptyReturnState && parent == nil {
// someone can pass in the bits of an array ctx that mean $
return BasePredictionContextEMPTY
}
return NewBaseSingletonPredictionContext(parent, returnState)
}
func NewArrayPredictionContext(parents []*PredictionContext, returnStates []int) *PredictionContext {
// Parent can be nil only if full ctx mode and we make an array
// from {@link //EMPTY} and non-empty. We merge {@link //EMPTY} by using
// nil parent and
// returnState == {@link //EmptyReturnState}.
hash := murmurInit(1)
for _, parent := range parents {
hash = murmurUpdate(hash, parent.Hash())
}
for _, returnState := range returnStates {
hash = murmurUpdate(hash, returnState)
}
hash = murmurFinish(hash, len(parents)<<1)
nec := &PredictionContext{}
nec.cachedHash = hash
nec.pcType = PredictionContextArray
nec.parents = parents
nec.returnStates = returnStates
return nec
}
func (p *PredictionContext) Hash() int {
return p.cachedHash
}
func (p *PredictionContext) Equals(other Collectable[*PredictionContext]) bool {
switch p.pcType {
case PredictionContextEmpty:
otherP := other.(*PredictionContext)
return other == nil || otherP == nil || otherP.isEmpty()
case PredictionContextSingleton:
return p.SingletonEquals(other)
case PredictionContextArray:
return p.ArrayEquals(other)
}
return false
}
func (p *PredictionContext) ArrayEquals(o Collectable[*PredictionContext]) bool {
if o == nil {
return false
}
other := o.(*PredictionContext)
if other == nil || other.pcType != PredictionContextArray {
return false
}
if p.cachedHash != other.Hash() {
return false // can't be same if hash is different
}
// Must compare the actual array elements and not just the array address
//
return slices.Equal(p.returnStates, other.returnStates) &&
slices.EqualFunc(p.parents, other.parents, func(x, y *PredictionContext) bool {
return x.Equals(y)
})
}
func (p *PredictionContext) SingletonEquals(other Collectable[*PredictionContext]) bool {
if other == nil {
return false
}
otherP := other.(*PredictionContext)
if otherP == nil {
return false
}
if p.cachedHash != otherP.Hash() {
return false // Can't be same if hash is different
}
if p.returnState != otherP.getReturnState(0) {
return false
}
// Both parents must be nil if one is
if p.parentCtx == nil {
return otherP.parentCtx == nil
}
return p.parentCtx.Equals(otherP.parentCtx)
}
func (p *PredictionContext) GetParent(i int) *PredictionContext {
switch p.pcType {
case PredictionContextEmpty:
return nil
case PredictionContextSingleton:
return p.parentCtx
case PredictionContextArray:
return p.parents[i]
}
return nil
}
func (p *PredictionContext) getReturnState(i int) int {
switch p.pcType {
case PredictionContextArray:
return p.returnStates[i]
default:
return p.returnState
}
}
func (p *PredictionContext) GetReturnStates() []int {
switch p.pcType {
case PredictionContextArray:
return p.returnStates
default:
return []int{p.returnState}
}
}
func (p *PredictionContext) length() int {
switch p.pcType {
case PredictionContextArray:
return len(p.returnStates)
default:
return 1
}
}
func (p *PredictionContext) hasEmptyPath() bool {
switch p.pcType {
case PredictionContextSingleton:
return p.returnState == BasePredictionContextEmptyReturnState
}
return p.getReturnState(p.length()-1) == BasePredictionContextEmptyReturnState
}
func (p *PredictionContext) String() string {
switch p.pcType {
case PredictionContextEmpty:
return "$"
case PredictionContextSingleton:
var up string
if p.parentCtx == nil {
up = ""
} else {
up = p.parentCtx.String()
}
if len(up) == 0 {
if p.returnState == BasePredictionContextEmptyReturnState {
return "$"
}
return strconv.Itoa(p.returnState)
}
return strconv.Itoa(p.returnState) + " " + up
case PredictionContextArray:
if p.isEmpty() {
return "[]"
}
s := "["
for i := 0; i < len(p.returnStates); i++ {
if i > 0 {
s = s + ", "
}
if p.returnStates[i] == BasePredictionContextEmptyReturnState {
s = s + "$"
continue
}
s = s + strconv.Itoa(p.returnStates[i])
if !p.parents[i].isEmpty() {
s = s + " " + p.parents[i].String()
} else {
s = s + "nil"
}
}
return s + "]"
default:
return "unknown"
}
}
func (p *PredictionContext) isEmpty() bool {
switch p.pcType {
case PredictionContextEmpty:
return true
case PredictionContextArray:
// since EmptyReturnState can only appear in the last position, we
// don't need to verify that size==1
return p.returnStates[0] == BasePredictionContextEmptyReturnState
default:
return false
}
}
func (p *PredictionContext) Type() int {
return p.pcType
}
func calculateHash(parent *PredictionContext, returnState int) int {
h := murmurInit(1)
h = murmurUpdate(h, parent.Hash())
h = murmurUpdate(h, returnState)
return murmurFinish(h, 2)
}
// Convert a {@link RuleContext} tree to a {@link BasePredictionContext} graph.
// Return {@link //EMPTY} if {@code outerContext} is empty or nil.
// /
func predictionContextFromRuleContext(a *ATN, outerContext RuleContext) *PredictionContext {
if outerContext == nil {
outerContext = ParserRuleContextEmpty
}
// if we are in RuleContext of start rule, s, then BasePredictionContext
// is EMPTY. Nobody called us. (if we are empty, return empty)
if outerContext.GetParent() == nil || outerContext == ParserRuleContextEmpty {
return BasePredictionContextEMPTY
}
// If we have a parent, convert it to a BasePredictionContext graph
parent := predictionContextFromRuleContext(a, outerContext.GetParent().(RuleContext))
state := a.states[outerContext.GetInvokingState()]
transition := state.GetTransitions()[0]
return SingletonBasePredictionContextCreate(parent, transition.(*RuleTransition).followState.GetStateNumber())
}
func merge(a, b *PredictionContext, rootIsWildcard bool, mergeCache *JPCMap) *PredictionContext {
// Share same graph if both same
//
if a == b || a.Equals(b) {
return a
}
if a.pcType == PredictionContextSingleton && b.pcType == PredictionContextSingleton {
return mergeSingletons(a, b, rootIsWildcard, mergeCache)
}
// At least one of a or b is array
// If one is $ and rootIsWildcard, return $ as wildcard
if rootIsWildcard {
if a.isEmpty() {
return a
}
if b.isEmpty() {
return b
}
}
// Convert either Singleton or Empty to arrays, so that we can merge them
//
ara := convertToArray(a)
arb := convertToArray(b)
return mergeArrays(ara, arb, rootIsWildcard, mergeCache)
}
func convertToArray(pc *PredictionContext) *PredictionContext {
switch pc.Type() {
case PredictionContextEmpty:
return NewArrayPredictionContext([]*PredictionContext{}, []int{})
case PredictionContextSingleton:
return NewArrayPredictionContext([]*PredictionContext{pc.GetParent(0)}, []int{pc.getReturnState(0)})
default:
// Already an array
}
return pc
}
// mergeSingletons merges two Singleton [PredictionContext] instances.
//
// Stack tops equal, parents merge is same return left graph.
// <embed src="images/SingletonMerge_SameRootSamePar.svg"
// type="image/svg+xml"/></p>
//
// <p>Same stack top, parents differ merge parents giving array node, then
// remainders of those graphs. A new root node is created to point to the
// merged parents.<br>
// <embed src="images/SingletonMerge_SameRootDiffPar.svg"
// type="image/svg+xml"/></p>
//
// <p>Different stack tops pointing to same parent. Make array node for the
// root where both element in the root point to the same (original)
// parent.<br>
// <embed src="images/SingletonMerge_DiffRootSamePar.svg"
// type="image/svg+xml"/></p>
//
// <p>Different stack tops pointing to different parents. Make array node for
// the root where each element points to the corresponding original
// parent.<br>
// <embed src="images/SingletonMerge_DiffRootDiffPar.svg"
// type="image/svg+xml"/></p>
//
// @param a the first {@link SingletonBasePredictionContext}
// @param b the second {@link SingletonBasePredictionContext}
// @param rootIsWildcard {@code true} if this is a local-context merge,
// otherwise false to indicate a full-context merge
// @param mergeCache
// /
func mergeSingletons(a, b *PredictionContext, rootIsWildcard bool, mergeCache *JPCMap) *PredictionContext {
if mergeCache != nil {
previous, present := mergeCache.Get(a, b)
if present {
return previous
}
previous, present = mergeCache.Get(b, a)
if present {
return previous
}
}
rootMerge := mergeRoot(a, b, rootIsWildcard)
if rootMerge != nil {
if mergeCache != nil {
mergeCache.Put(a, b, rootMerge)
}
return rootMerge
}
if a.returnState == b.returnState {
parent := merge(a.parentCtx, b.parentCtx, rootIsWildcard, mergeCache)
// if parent is same as existing a or b parent or reduced to a parent,
// return it
if parent.Equals(a.parentCtx) {
return a // ax + bx = ax, if a=b
}
if parent.Equals(b.parentCtx) {
return b // ax + bx = bx, if a=b
}
// else: ax + ay = a'[x,y]
// merge parents x and y, giving array node with x,y then remainders
// of those graphs. dup a, a' points at merged array.
// New joined parent so create a new singleton pointing to it, a'
spc := SingletonBasePredictionContextCreate(parent, a.returnState)
if mergeCache != nil {
mergeCache.Put(a, b, spc)
}
return spc
}
// a != b payloads differ
// see if we can collapse parents due to $+x parents if local ctx
var singleParent *PredictionContext
if a.Equals(b) || (a.parentCtx != nil && a.parentCtx.Equals(b.parentCtx)) { // ax +
// bx =
// [a,b]x
singleParent = a.parentCtx
}
if singleParent != nil { // parents are same
// sort payloads and use same parent
payloads := []int{a.returnState, b.returnState}
if a.returnState > b.returnState {
payloads[0] = b.returnState
payloads[1] = a.returnState
}
parents := []*PredictionContext{singleParent, singleParent}
apc := NewArrayPredictionContext(parents, payloads)
if mergeCache != nil {
mergeCache.Put(a, b, apc)
}
return apc
}
// parents differ and can't merge them. Just pack together
// into array can't merge.
// ax + by = [ax,by]
payloads := []int{a.returnState, b.returnState}
parents := []*PredictionContext{a.parentCtx, b.parentCtx}
if a.returnState > b.returnState { // sort by payload
payloads[0] = b.returnState
payloads[1] = a.returnState
parents = []*PredictionContext{b.parentCtx, a.parentCtx}
}
apc := NewArrayPredictionContext(parents, payloads)
if mergeCache != nil {
mergeCache.Put(a, b, apc)
}
return apc
}
// Handle case where at least one of {@code a} or {@code b} is
// {@link //EMPTY}. In the following diagrams, the symbol {@code $} is used
// to represent {@link //EMPTY}.
//
// <h2>Local-Context Merges</h2>
//
// <p>These local-context merge operations are used when {@code rootIsWildcard}
// is true.</p>
//
// <p>{@link //EMPTY} is superset of any graph return {@link //EMPTY}.<br>
// <embed src="images/LocalMerge_EmptyRoot.svg" type="image/svg+xml"/></p>
//
// <p>{@link //EMPTY} and anything is {@code //EMPTY}, so merged parent is
// {@code //EMPTY} return left graph.<br>
// <embed src="images/LocalMerge_EmptyParent.svg" type="image/svg+xml"/></p>
//
// <p>Special case of last merge if local context.<br>
// <embed src="images/LocalMerge_DiffRoots.svg" type="image/svg+xml"/></p>
//
// <h2>Full-Context Merges</h2>
//
// <p>These full-context merge operations are used when {@code rootIsWildcard}
// is false.</p>
//
// <p><embed src="images/FullMerge_EmptyRoots.svg" type="image/svg+xml"/></p>
//
// <p>Must keep all contexts {@link //EMPTY} in array is a special value (and
// nil parent).<br>
// <embed src="images/FullMerge_EmptyRoot.svg" type="image/svg+xml"/></p>
//
// <p><embed src="images/FullMerge_SameRoot.svg" type="image/svg+xml"/></p>
//
// @param a the first {@link SingletonBasePredictionContext}
// @param b the second {@link SingletonBasePredictionContext}
// @param rootIsWildcard {@code true} if this is a local-context merge,
// otherwise false to indicate a full-context merge
// /
func mergeRoot(a, b *PredictionContext, rootIsWildcard bool) *PredictionContext {
if rootIsWildcard {
if a.pcType == PredictionContextEmpty {
return BasePredictionContextEMPTY // // + b =//
}
if b.pcType == PredictionContextEmpty {
return BasePredictionContextEMPTY // a +// =//
}
} else {
if a.isEmpty() && b.isEmpty() {
return BasePredictionContextEMPTY // $ + $ = $
} else if a.isEmpty() { // $ + x = [$,x]
payloads := []int{b.getReturnState(-1), BasePredictionContextEmptyReturnState}
parents := []*PredictionContext{b.GetParent(-1), nil}
return NewArrayPredictionContext(parents, payloads)
} else if b.isEmpty() { // x + $ = [$,x] ($ is always first if present)
payloads := []int{a.getReturnState(-1), BasePredictionContextEmptyReturnState}
parents := []*PredictionContext{a.GetParent(-1), nil}
return NewArrayPredictionContext(parents, payloads)
}
}
return nil
}
// Merge two {@link ArrayBasePredictionContext} instances.
//
// <p>Different tops, different parents.<br>
// <embed src="images/ArrayMerge_DiffTopDiffPar.svg" type="image/svg+xml"/></p>
//
// <p>Shared top, same parents.<br>
// <embed src="images/ArrayMerge_ShareTopSamePar.svg" type="image/svg+xml"/></p>
//
// <p>Shared top, different parents.<br>
// <embed src="images/ArrayMerge_ShareTopDiffPar.svg" type="image/svg+xml"/></p>
//
// <p>Shared top, all shared parents.<br>
// <embed src="images/ArrayMerge_ShareTopSharePar.svg"
// type="image/svg+xml"/></p>
//
// <p>Equal tops, merge parents and reduce top to
// {@link SingletonBasePredictionContext}.<br>
// <embed src="images/ArrayMerge_EqualTop.svg" type="image/svg+xml"/></p>
//
//goland:noinspection GoBoolExpressions
func mergeArrays(a, b *PredictionContext, rootIsWildcard bool, mergeCache *JPCMap) *PredictionContext {
if mergeCache != nil {
previous, present := mergeCache.Get(a, b)
if present {
if runtimeConfig.parserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> previous")
}
return previous
}
previous, present = mergeCache.Get(b, a)
if present {
if runtimeConfig.parserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> previous")
}
return previous
}
}
// merge sorted payloads a + b => M
i := 0 // walks a
j := 0 // walks b
k := 0 // walks target M array
mergedReturnStates := make([]int, len(a.returnStates)+len(b.returnStates))
mergedParents := make([]*PredictionContext, len(a.returnStates)+len(b.returnStates))
// walk and merge to yield mergedParents, mergedReturnStates
for i < len(a.returnStates) && j < len(b.returnStates) {
aParent := a.parents[i]
bParent := b.parents[j]
if a.returnStates[i] == b.returnStates[j] {
// same payload (stack tops are equal), must yield merged singleton
payload := a.returnStates[i]
// $+$ = $
bothDollars := payload == BasePredictionContextEmptyReturnState && aParent == nil && bParent == nil
axAX := aParent != nil && bParent != nil && aParent.Equals(bParent) // ax+ax
// ->
// ax
if bothDollars || axAX {
mergedParents[k] = aParent // choose left
mergedReturnStates[k] = payload
} else { // ax+ay -> a'[x,y]
mergedParent := merge(aParent, bParent, rootIsWildcard, mergeCache)
mergedParents[k] = mergedParent
mergedReturnStates[k] = payload
}
i++ // hop over left one as usual
j++ // but also Skip one in right side since we merge
} else if a.returnStates[i] < b.returnStates[j] { // copy a[i] to M
mergedParents[k] = aParent
mergedReturnStates[k] = a.returnStates[i]
i++
} else { // b > a, copy b[j] to M
mergedParents[k] = bParent
mergedReturnStates[k] = b.returnStates[j]
j++
}
k++
}
// copy over any payloads remaining in either array
if i < len(a.returnStates) {
for p := i; p < len(a.returnStates); p++ {
mergedParents[k] = a.parents[p]
mergedReturnStates[k] = a.returnStates[p]
k++
}
} else {
for p := j; p < len(b.returnStates); p++ {
mergedParents[k] = b.parents[p]
mergedReturnStates[k] = b.returnStates[p]
k++
}
}
// trim merged if we combined a few that had same stack tops
if k < len(mergedParents) { // write index < last position trim
if k == 1 { // for just one merged element, return singleton top
pc := SingletonBasePredictionContextCreate(mergedParents[0], mergedReturnStates[0])
if mergeCache != nil {
mergeCache.Put(a, b, pc)
}
return pc
}
mergedParents = mergedParents[0:k]
mergedReturnStates = mergedReturnStates[0:k]
}
M := NewArrayPredictionContext(mergedParents, mergedReturnStates)
// if we created same array as a or b, return that instead
// TODO: JI track whether this is possible above during merge sort for speed and possibly avoid an allocation
if M.Equals(a) {
if mergeCache != nil {
mergeCache.Put(a, b, a)
}
if runtimeConfig.parserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> a")
}
return a
}
if M.Equals(b) {
if mergeCache != nil {
mergeCache.Put(a, b, b)
}
if runtimeConfig.parserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> b")
}
return b
}
combineCommonParents(&mergedParents)
if mergeCache != nil {
mergeCache.Put(a, b, M)
}
if runtimeConfig.parserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> " + M.String())
}
return M
}
// Make pass over all M parents and merge any Equals() ones.
// Note that we pass a pointer to the slice as we want to modify it in place.
//
//goland:noinspection GoUnusedFunction
func combineCommonParents(parents *[]*PredictionContext) {
uniqueParents := NewJStore[*PredictionContext, Comparator[*PredictionContext]](pContextEqInst, PredictionContextCollection, "combineCommonParents for PredictionContext")
for p := 0; p < len(*parents); p++ {
parent := (*parents)[p]
_, _ = uniqueParents.Put(parent)
}
for q := 0; q < len(*parents); q++ {
pc, _ := uniqueParents.Get((*parents)[q])
(*parents)[q] = pc
}
}
func getCachedBasePredictionContext(context *PredictionContext, contextCache *PredictionContextCache, visited *VisitRecord) *PredictionContext {
if context.isEmpty() {
return context
}
existing, present := visited.Get(context)
if present {
return existing
}
existing, present = contextCache.Get(context)
if present {
visited.Put(context, existing)
return existing
}
changed := false
parents := make([]*PredictionContext, context.length())
for i := 0; i < len(parents); i++ {
parent := getCachedBasePredictionContext(context.GetParent(i), contextCache, visited)
if changed || !parent.Equals(context.GetParent(i)) {
if !changed {
parents = make([]*PredictionContext, context.length())
for j := 0; j < context.length(); j++ {
parents[j] = context.GetParent(j)
}
changed = true
}
parents[i] = parent
}
}
if !changed {
contextCache.add(context)
visited.Put(context, context)
return context
}
var updated *PredictionContext
if len(parents) == 0 {
updated = BasePredictionContextEMPTY
} else if len(parents) == 1 {
updated = SingletonBasePredictionContextCreate(parents[0], context.getReturnState(0))
} else {
updated = NewArrayPredictionContext(parents, context.GetReturnStates())
}
contextCache.add(updated)
visited.Put(updated, updated)
visited.Put(context, updated)
return updated
}

View File

@@ -0,0 +1,48 @@
package antlr
var BasePredictionContextEMPTY = &PredictionContext{
cachedHash: calculateEmptyHash(),
pcType: PredictionContextEmpty,
returnState: BasePredictionContextEmptyReturnState,
}
// PredictionContextCache is Used to cache [PredictionContext] objects. It is used for the shared
// context cash associated with contexts in DFA states. This cache
// can be used for both lexers and parsers.
type PredictionContextCache struct {
cache *JMap[*PredictionContext, *PredictionContext, Comparator[*PredictionContext]]
}
func NewPredictionContextCache() *PredictionContextCache {
return &PredictionContextCache{
cache: NewJMap[*PredictionContext, *PredictionContext, Comparator[*PredictionContext]](pContextEqInst, PredictionContextCacheCollection, "NewPredictionContextCache()"),
}
}
// Add a context to the cache and return it. If the context already exists,
// return that one instead and do not add a new context to the cache.
// Protect shared cache from unsafe thread access.
func (p *PredictionContextCache) add(ctx *PredictionContext) *PredictionContext {
if ctx.isEmpty() {
return BasePredictionContextEMPTY
}
// Put will return the existing entry if it is present (note this is done via Equals, not whether it is
// the same pointer), otherwise it will add the new entry and return that.
//
existing, present := p.cache.Get(ctx)
if present {
return existing
}
p.cache.Put(ctx, ctx)
return ctx
}
func (p *PredictionContextCache) Get(ctx *PredictionContext) (*PredictionContext, bool) {
pc, exists := p.cache.Get(ctx)
return pc, exists
}
func (p *PredictionContextCache) length() int {
return p.cache.Len()
}

536
vendor/github.com/antlr4-go/antlr/v4/prediction_mode.go generated vendored Normal file
View File

@@ -0,0 +1,536 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
// This enumeration defines the prediction modes available in ANTLR 4 along with
// utility methods for analyzing configuration sets for conflicts and/or
// ambiguities.
const (
// PredictionModeSLL represents the SLL(*) prediction mode.
// This prediction mode ignores the current
// parser context when making predictions. This is the fastest prediction
// mode, and provides correct results for many grammars. This prediction
// mode is more powerful than the prediction mode provided by ANTLR 3, but
// may result in syntax errors for grammar and input combinations which are
// not SLL.
//
// When using this prediction mode, the parser will either return a correct
// parse tree (i.e. the same parse tree that would be returned with the
// [PredictionModeLL] prediction mode), or it will Report a syntax error. If a
// syntax error is encountered when using the SLL prediction mode,
// it may be due to either an actual syntax error in the input or indicate
// that the particular combination of grammar and input requires the more
// powerful LL prediction abilities to complete successfully.
//
// This prediction mode does not provide any guarantees for prediction
// behavior for syntactically-incorrect inputs.
//
PredictionModeSLL = 0
// PredictionModeLL represents the LL(*) prediction mode.
// This prediction mode allows the current parser
// context to be used for resolving SLL conflicts that occur during
// prediction. This is the fastest prediction mode that guarantees correct
// parse results for all combinations of grammars with syntactically correct
// inputs.
//
// When using this prediction mode, the parser will make correct decisions
// for all syntactically-correct grammar and input combinations. However, in
// cases where the grammar is truly ambiguous this prediction mode might not
// report a precise answer for exactly which alternatives are
// ambiguous.
//
// This prediction mode does not provide any guarantees for prediction
// behavior for syntactically-incorrect inputs.
//
PredictionModeLL = 1
// PredictionModeLLExactAmbigDetection represents the LL(*) prediction mode
// with exact ambiguity detection.
//
// In addition to the correctness guarantees provided by the [PredictionModeLL] prediction mode,
// this prediction mode instructs the prediction algorithm to determine the
// complete and exact set of ambiguous alternatives for every ambiguous
// decision encountered while parsing.
//
// This prediction mode may be used for diagnosing ambiguities during
// grammar development. Due to the performance overhead of calculating sets
// of ambiguous alternatives, this prediction mode should be avoided when
// the exact results are not necessary.
//
// This prediction mode does not provide any guarantees for prediction
// behavior for syntactically-incorrect inputs.
//
PredictionModeLLExactAmbigDetection = 2
)
// PredictionModehasSLLConflictTerminatingPrediction computes the SLL prediction termination condition.
//
// This method computes the SLL prediction termination condition for both of
// the following cases:
//
// - The usual SLL+LL fallback upon SLL conflict
// - Pure SLL without LL fallback
//
// # Combined SLL+LL Parsing
//
// When LL-fallback is enabled upon SLL conflict, correct predictions are
// ensured regardless of how the termination condition is computed by this
// method. Due to the substantially higher cost of LL prediction, the
// prediction should only fall back to LL when the additional lookahead
// cannot lead to a unique SLL prediction.
//
// Assuming combined SLL+LL parsing, an SLL configuration set with only
// conflicting subsets should fall back to full LL, even if the
// configuration sets don't resolve to the same alternative, e.g.
//
// {1,2} and {3,4}
//
// If there is at least one non-conflicting
// configuration, SLL could continue with the hopes that more lookahead will
// resolve via one of those non-conflicting configurations.
//
// Here's the prediction termination rule them: SLL (for SLL+LL parsing)
// stops when it sees only conflicting configuration subsets. In contrast,
// full LL keeps going when there is uncertainty.
//
// # Heuristic
//
// As a heuristic, we stop prediction when we see any conflicting subset
// unless we see a state that only has one alternative associated with it.
// The single-alt-state thing lets prediction continue upon rules like
// (otherwise, it would admit defeat too soon):
//
// [12|1|[], 6|2|[], 12|2|[]]. s : (ID | ID ID?) ;
//
// When the [ATN] simulation reaches the state before ';', it has a
// [DFA] state that looks like:
//
// [12|1|[], 6|2|[], 12|2|[]]
//
// Naturally
//
// 12|1|[] and 12|2|[]
//
// conflict, but we cannot stop processing this node because alternative to has another way to continue,
// via
//
// [6|2|[]]
//
// It also let's us continue for this rule:
//
// [1|1|[], 1|2|[], 8|3|[]] a : A | A | A B ;
//
// After Matching input A, we reach the stop state for rule A, state 1.
// State 8 is the state immediately before B. Clearly alternatives 1 and 2
// conflict and no amount of further lookahead will separate the two.
// However, alternative 3 will be able to continue, and so we do not stop
// working on this state. In the previous example, we're concerned with
// states associated with the conflicting alternatives. Here alt 3 is not
// associated with the conflicting configs, but since we can continue
// looking for input reasonably, don't declare the state done.
//
// # Pure SLL Parsing
//
// To handle pure SLL parsing, all we have to do is make sure that we
// combine stack contexts for configurations that differ only by semantic
// predicate. From there, we can do the usual SLL termination heuristic.
//
// # Predicates in SLL+LL Parsing
//
// SLL decisions don't evaluate predicates until after they reach [DFA] stop
// states because they need to create the [DFA] cache that works in all
// semantic situations. In contrast, full LL evaluates predicates collected
// during start state computation, so it can ignore predicates thereafter.
// This means that SLL termination detection can totally ignore semantic
// predicates.
//
// Implementation-wise, [ATNConfigSet] combines stack contexts but not
// semantic predicate contexts, so we might see two configurations like the
// following:
//
// (s, 1, x, {}), (s, 1, x', {p})
//
// Before testing these configurations against others, we have to merge
// x and x' (without modifying the existing configurations).
// For example, we test (x+x')==x” when looking for conflicts in
// the following configurations:
//
// (s, 1, x, {}), (s, 1, x', {p}), (s, 2, x”, {})
//
// If the configuration set has predicates (as indicated by
// [ATNConfigSet.hasSemanticContext]), this algorithm makes a copy of
// the configurations to strip out all the predicates so that a standard
// [ATNConfigSet] will merge everything ignoring predicates.
func PredictionModehasSLLConflictTerminatingPrediction(mode int, configs *ATNConfigSet) bool {
// Configs in rule stop states indicate reaching the end of the decision
// rule (local context) or end of start rule (full context). If all
// configs meet this condition, then none of the configurations is able
// to Match additional input, so we terminate prediction.
//
if PredictionModeallConfigsInRuleStopStates(configs) {
return true
}
// pure SLL mode parsing
if mode == PredictionModeSLL {
// Don't bother with combining configs from different semantic
// contexts if we can fail over to full LL costs more time
// since we'll often fail over anyway.
if configs.hasSemanticContext {
// dup configs, tossing out semantic predicates
dup := NewATNConfigSet(false)
for _, c := range configs.configs {
// NewATNConfig({semanticContext:}, c)
c = NewATNConfig2(c, SemanticContextNone)
dup.Add(c, nil)
}
configs = dup
}
// now we have combined contexts for configs with dissimilar predicates
}
// pure SLL or combined SLL+LL mode parsing
altsets := PredictionModegetConflictingAltSubsets(configs)
return PredictionModehasConflictingAltSet(altsets) && !PredictionModehasStateAssociatedWithOneAlt(configs)
}
// PredictionModehasConfigInRuleStopState checks if any configuration in the given configs is in a
// [RuleStopState]. Configurations meeting this condition have reached
// the end of the decision rule (local context) or end of start rule (full
// context).
//
// The func returns true if any configuration in the supplied configs is in a [RuleStopState]
func PredictionModehasConfigInRuleStopState(configs *ATNConfigSet) bool {
for _, c := range configs.configs {
if _, ok := c.GetState().(*RuleStopState); ok {
return true
}
}
return false
}
// PredictionModeallConfigsInRuleStopStates checks if all configurations in configs are in a
// [RuleStopState]. Configurations meeting this condition have reached
// the end of the decision rule (local context) or end of start rule (full
// context).
//
// the func returns true if all configurations in configs are in a
// [RuleStopState]
func PredictionModeallConfigsInRuleStopStates(configs *ATNConfigSet) bool {
for _, c := range configs.configs {
if _, ok := c.GetState().(*RuleStopState); !ok {
return false
}
}
return true
}
// PredictionModeresolvesToJustOneViableAlt checks full LL prediction termination.
//
// Can we stop looking ahead during [ATN] simulation or is there some
// uncertainty as to which alternative we will ultimately pick, after
// consuming more input? Even if there are partial conflicts, we might know
// that everything is going to resolve to the same minimum alternative. That
// means we can stop since no more lookahead will change that fact. On the
// other hand, there might be multiple conflicts that resolve to different
// minimums. That means we need more look ahead to decide which of those
// alternatives we should predict.
//
// The basic idea is to split the set of configurations 'C', into
// conflicting subsets (s, _, ctx, _) and singleton subsets with
// non-conflicting configurations. Two configurations conflict if they have
// identical [ATNConfig].state and [ATNConfig].context values
// but a different [ATNConfig].alt value, e.g.
//
// (s, i, ctx, _)
//
// and
//
// (s, j, ctx, _) ; for i != j
//
// Reduce these configuration subsets to the set of possible alternatives.
// You can compute the alternative subsets in one pass as follows:
//
// A_s,ctx = {i | (s, i, ctx, _)}
//
// for each configuration in C holding s and ctx fixed.
//
// Or in pseudo-code:
//
// for each configuration c in C:
// map[c] U = c.ATNConfig.alt alt // map hash/equals uses s and x, not alt and not pred
//
// The values in map are the set of
//
// A_s,ctx
//
// sets.
//
// If
//
// |A_s,ctx| = 1
//
// then there is no conflict associated with s and ctx.
//
// Reduce the subsets to singletons by choosing a minimum of each subset. If
// the union of these alternative subsets is a singleton, then no amount of
// further lookahead will help us. We will always pick that alternative. If,
// however, there is more than one alternative, then we are uncertain which
// alternative to predict and must continue looking for resolution. We may
// or may not discover an ambiguity in the future, even if there are no
// conflicting subsets this round.
//
// The biggest sin is to terminate early because it means we've made a
// decision but were uncertain as to the eventual outcome. We haven't used
// enough lookahead. On the other hand, announcing a conflict too late is no
// big deal; you will still have the conflict. It's just inefficient. It
// might even look until the end of file.
//
// No special consideration for semantic predicates is required because
// predicates are evaluated on-the-fly for full LL prediction, ensuring that
// no configuration contains a semantic context during the termination
// check.
//
// # Conflicting Configs
//
// Two configurations:
//
// (s, i, x) and (s, j, x')
//
// conflict when i != j but x = x'. Because we merge all
// (s, i, _) configurations together, that means that there are at
// most n configurations associated with state s for
// n possible alternatives in the decision. The merged stacks
// complicate the comparison of configuration contexts x and x'.
//
// Sam checks to see if one is a subset of the other by calling
// merge and checking to see if the merged result is either x or x'.
// If the x associated with lowest alternative i
// is the superset, then i is the only possible prediction since the
// others resolve to min(i) as well. However, if x is
// associated with j > i then at least one stack configuration for
// j is not in conflict with alternative i. The algorithm
// should keep going, looking for more lookahead due to the uncertainty.
//
// For simplicity, I'm doing an equality check between x and
// x', which lets the algorithm continue to consume lookahead longer
// than necessary. The reason I like the equality is of course the
// simplicity but also because that is the test you need to detect the
// alternatives that are actually in conflict.
//
// # Continue/Stop Rule
//
// Continue if the union of resolved alternative sets from non-conflicting and
// conflicting alternative subsets has more than one alternative. We are
// uncertain about which alternative to predict.
//
// The complete set of alternatives,
//
// [i for (_, i, _)]
//
// tells us which alternatives are still in the running for the amount of input we've
// consumed at this point. The conflicting sets let us to strip away
// configurations that won't lead to more states because we resolve
// conflicts to the configuration with a minimum alternate for the
// conflicting set.
//
// Cases
//
// - no conflicts and more than 1 alternative in set => continue
// - (s, 1, x), (s, 2, x), (s, 3, z), (s', 1, y), (s', 2, y) yields non-conflicting set
// {3} conflicting sets min({1,2}) min({1,2}) = {1,3} => continue
// - (s, 1, x), (s, 2, x), (s', 1, y), (s', 2, y), (s”, 1, z) yields non-conflicting set
// {1} conflicting sets min({1,2}) min({1,2}) = {1} => stop and predict 1
// - (s, 1, x), (s, 2, x), (s', 1, y), (s', 2, y) yields conflicting, reduced sets
// {1} {1} = {1} => stop and predict 1, can announce ambiguity {1,2}
// - (s, 1, x), (s, 2, x), (s', 2, y), (s', 3, y) yields conflicting, reduced sets
// {1} {2} = {1,2} => continue
// - (s, 1, x), (s, 2, x), (s', 2, y), (s', 3, y) yields conflicting, reduced sets
// {1} {2} = {1,2} => continue
// - (s, 1, x), (s, 2, x), (s', 3, y), (s', 4, y) yields conflicting, reduced sets
// {1} {3} = {1,3} => continue
//
// # Exact Ambiguity Detection
//
// If all states report the same conflicting set of alternatives, then we
// know we have the exact ambiguity set:
//
// |A_i| > 1
//
// and
//
// A_i = A_j ; for all i, j
//
// In other words, we continue examining lookahead until all A_i
// have more than one alternative and all A_i are the same. If
//
// A={{1,2}, {1,3}}
//
// then regular LL prediction would terminate because the resolved set is {1}.
// To determine what the real ambiguity is, we have to know whether the ambiguity is between one and
// two or one and three so we keep going. We can only stop prediction when
// we need exact ambiguity detection when the sets look like:
//
// A={{1,2}}
//
// or
//
// {{1,2},{1,2}}, etc...
func PredictionModeresolvesToJustOneViableAlt(altsets []*BitSet) int {
return PredictionModegetSingleViableAlt(altsets)
}
// PredictionModeallSubsetsConflict determines if every alternative subset in altsets contains more
// than one alternative.
//
// The func returns true if every [BitSet] in altsets has
// [BitSet].cardinality cardinality > 1
func PredictionModeallSubsetsConflict(altsets []*BitSet) bool {
return !PredictionModehasNonConflictingAltSet(altsets)
}
// PredictionModehasNonConflictingAltSet determines if any single alternative subset in altsets contains
// exactly one alternative.
//
// The func returns true if altsets contains at least one [BitSet] with
// [BitSet].cardinality cardinality 1
func PredictionModehasNonConflictingAltSet(altsets []*BitSet) bool {
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
if alts.length() == 1 {
return true
}
}
return false
}
// PredictionModehasConflictingAltSet determines if any single alternative subset in altsets contains
// more than one alternative.
//
// The func returns true if altsets contains a [BitSet] with
// [BitSet].cardinality cardinality > 1, otherwise false
func PredictionModehasConflictingAltSet(altsets []*BitSet) bool {
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
if alts.length() > 1 {
return true
}
}
return false
}
// PredictionModeallSubsetsEqual determines if every alternative subset in altsets is equivalent.
//
// The func returns true if every member of altsets is equal to the others.
func PredictionModeallSubsetsEqual(altsets []*BitSet) bool {
var first *BitSet
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
if first == nil {
first = alts
} else if alts != first {
return false
}
}
return true
}
// PredictionModegetUniqueAlt returns the unique alternative predicted by all alternative subsets in
// altsets. If no such alternative exists, this method returns
// [ATNInvalidAltNumber].
//
// @param altsets a collection of alternative subsets
func PredictionModegetUniqueAlt(altsets []*BitSet) int {
all := PredictionModeGetAlts(altsets)
if all.length() == 1 {
return all.minValue()
}
return ATNInvalidAltNumber
}
// PredictionModeGetAlts returns the complete set of represented alternatives for a collection of
// alternative subsets. This method returns the union of each [BitSet]
// in altsets, being the set of represented alternatives in altsets.
func PredictionModeGetAlts(altsets []*BitSet) *BitSet {
all := NewBitSet()
for _, alts := range altsets {
all.or(alts)
}
return all
}
// PredictionModegetConflictingAltSubsets gets the conflicting alt subsets from a configuration set.
//
// for each configuration c in configs:
// map[c] U= c.ATNConfig.alt // map hash/equals uses s and x, not alt and not pred
func PredictionModegetConflictingAltSubsets(configs *ATNConfigSet) []*BitSet {
configToAlts := NewJMap[*ATNConfig, *BitSet, *ATNAltConfigComparator[*ATNConfig]](atnAltCfgEqInst, AltSetCollection, "PredictionModegetConflictingAltSubsets()")
for _, c := range configs.configs {
alts, ok := configToAlts.Get(c)
if !ok {
alts = NewBitSet()
configToAlts.Put(c, alts)
}
alts.add(c.GetAlt())
}
return configToAlts.Values()
}
// PredictionModeGetStateToAltMap gets a map from state to alt subset from a configuration set.
//
// for each configuration c in configs:
// map[c.ATNConfig.state] U= c.ATNConfig.alt}
func PredictionModeGetStateToAltMap(configs *ATNConfigSet) *AltDict {
m := NewAltDict()
for _, c := range configs.configs {
alts := m.Get(c.GetState().String())
if alts == nil {
alts = NewBitSet()
m.put(c.GetState().String(), alts)
}
alts.(*BitSet).add(c.GetAlt())
}
return m
}
func PredictionModehasStateAssociatedWithOneAlt(configs *ATNConfigSet) bool {
values := PredictionModeGetStateToAltMap(configs).values()
for i := 0; i < len(values); i++ {
if values[i].(*BitSet).length() == 1 {
return true
}
}
return false
}
// PredictionModegetSingleViableAlt gets the single alternative predicted by all alternative subsets in altsets
// if there is one.
//
// TODO: JI - Review this code - it does not seem to do the same thing as the Java code - maybe because [BitSet] is not like the Java utils BitSet
func PredictionModegetSingleViableAlt(altsets []*BitSet) int {
result := ATNInvalidAltNumber
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
minAlt := alts.minValue()
if result == ATNInvalidAltNumber {
result = minAlt
} else if result != minAlt { // more than 1 viable alt
return ATNInvalidAltNumber
}
}
return result
}

View File

@@ -26,6 +26,9 @@ type Recognizer interface {
RemoveErrorListeners()
GetATN() *ATN
GetErrorListenerDispatch() ErrorListener
HasError() bool
GetError() RecognitionException
SetError(RecognitionException)
}
type BaseRecognizer struct {
@@ -36,6 +39,7 @@ type BaseRecognizer struct {
LiteralNames []string
SymbolicNames []string
GrammarFileName string
SynErr RecognitionException
}
func NewBaseRecognizer() *BaseRecognizer {
@@ -45,7 +49,10 @@ func NewBaseRecognizer() *BaseRecognizer {
return rec
}
//goland:noinspection GoUnusedGlobalVariable
var tokenTypeMapCache = make(map[string]int)
//goland:noinspection GoUnusedGlobalVariable
var ruleIndexMapCache = make(map[string]int)
func (b *BaseRecognizer) checkVersion(toolVersion string) {
@@ -55,7 +62,19 @@ func (b *BaseRecognizer) checkVersion(toolVersion string) {
}
}
func (b *BaseRecognizer) Action(context RuleContext, ruleIndex, actionIndex int) {
func (b *BaseRecognizer) SetError(err RecognitionException) {
b.SynErr = err
}
func (b *BaseRecognizer) HasError() bool {
return b.SynErr != nil
}
func (b *BaseRecognizer) GetError() RecognitionException {
return b.SynErr
}
func (b *BaseRecognizer) Action(_ RuleContext, _, _ int) {
panic("action not implemented on Recognizer!")
}
@@ -105,9 +124,11 @@ func (b *BaseRecognizer) SetState(v int) {
// return result
//}
// Get a map from rule names to rule indexes.
// GetRuleIndexMap Get a map from rule names to rule indexes.
//
// <p>Used for XPath and tree pattern compilation.</p>
// Used for XPath and tree pattern compilation.
//
// TODO: JI This is not yet implemented in the Go runtime. Maybe not needed.
func (b *BaseRecognizer) GetRuleIndexMap() map[string]int {
panic("Method not defined!")
@@ -124,7 +145,8 @@ func (b *BaseRecognizer) GetRuleIndexMap() map[string]int {
// return result
}
func (b *BaseRecognizer) GetTokenType(tokenName string) int {
// GetTokenType get the token type based upon its name
func (b *BaseRecognizer) GetTokenType(_ string) int {
panic("Method not defined!")
// var ttype = b.GetTokenTypeMap()[tokenName]
// if (ttype !=nil) {
@@ -162,26 +184,27 @@ func (b *BaseRecognizer) GetTokenType(tokenName string) int {
// }
//}
// What is the error header, normally line/character position information?//
// GetErrorHeader returns the error header, normally line/character position information.
//
// Can be overridden in sub structs embedding BaseRecognizer.
func (b *BaseRecognizer) GetErrorHeader(e RecognitionException) string {
line := e.GetOffendingToken().GetLine()
column := e.GetOffendingToken().GetColumn()
return "line " + strconv.Itoa(line) + ":" + strconv.Itoa(column)
}
// How should a token be displayed in an error message? The default
// GetTokenErrorDisplay shows how a token should be displayed in an error message.
//
// is to display just the text, but during development you might
// want to have a lot of information spit out. Override in that case
// to use t.String() (which, for CommonToken, dumps everything about
// the token). This is better than forcing you to override a method in
// your token objects because you don't have to go modify your lexer
// so that it creates a NewJava type.
// The default is to display just the text, but during development you might
// want to have a lot of information spit out. Override in that case
// to use t.String() (which, for CommonToken, dumps everything about
// the token). This is better than forcing you to override a method in
// your token objects because you don't have to go modify your lexer
// so that it creates a NewJava type.
//
// @deprecated This method is not called by the ANTLR 4 Runtime. Specific
// implementations of {@link ANTLRErrorStrategy} may provide a similar
// feature when necessary. For example, see
// {@link DefaultErrorStrategy//GetTokenErrorDisplay}.
// Deprecated: This method is not called by the ANTLR 4 Runtime. Specific
// implementations of [ANTLRErrorStrategy] may provide a similar
// feature when necessary. For example, see [DefaultErrorStrategy].GetTokenErrorDisplay()
func (b *BaseRecognizer) GetTokenErrorDisplay(t Token) string {
if t == nil {
return "<no token>"
@@ -205,12 +228,14 @@ func (b *BaseRecognizer) GetErrorListenerDispatch() ErrorListener {
return NewProxyErrorListener(b.listeners)
}
// subclass needs to override these if there are sempreds or actions
// that the ATN interp needs to execute
func (b *BaseRecognizer) Sempred(localctx RuleContext, ruleIndex int, actionIndex int) bool {
// Sempred embedding structs need to override this if there are sempreds or actions
// that the ATN interpreter needs to execute
func (b *BaseRecognizer) Sempred(_ RuleContext, _ int, _ int) bool {
return true
}
func (b *BaseRecognizer) Precpred(localctx RuleContext, precedence int) bool {
// Precpred embedding structs need to override this if there are preceding predicates
// that the ATN interpreter needs to execute
func (b *BaseRecognizer) Precpred(_ RuleContext, _ int) bool {
return true
}

40
vendor/github.com/antlr4-go/antlr/v4/rule_context.go generated vendored Normal file
View File

@@ -0,0 +1,40 @@
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
// RuleContext is a record of a single rule invocation. It knows
// which context invoked it, if any. If there is no parent context, then
// naturally the invoking state is not valid. The parent link
// provides a chain upwards from the current rule invocation to the root
// of the invocation tree, forming a stack.
//
// We actually carry no information about the rule associated with this context (except
// when parsing). We keep only the state number of the invoking state from
// the [ATN] submachine that invoked this. Contrast this with the s
// pointer inside [ParserRuleContext] that tracks the current state
// being "executed" for the current rule.
//
// The parent contexts are useful for computing lookahead sets and
// getting error information.
//
// These objects are used during parsing and prediction.
// For the special case of parsers, we use the struct
// [ParserRuleContext], which embeds a RuleContext.
//
// @see ParserRuleContext
type RuleContext interface {
RuleNode
GetInvokingState() int
SetInvokingState(int)
GetRuleIndex() int
IsEmpty() bool
GetAltNumber() int
SetAltNumber(altNumber int)
String([]string, RuleContext) string
}

View File

@@ -9,14 +9,13 @@ import (
"strconv"
)
// A tree structure used to record the semantic context in which
// an ATN configuration is valid. It's either a single predicate,
// a conjunction {@code p1&&p2}, or a sum of products {@code p1||p2}.
// SemanticContext is a tree structure used to record the semantic context in which
//
// <p>I have scoped the {@link AND}, {@link OR}, and {@link Predicate} subclasses of
// {@link SemanticContext} within the scope of this outer class.</p>
// an ATN configuration is valid. It's either a single predicate,
// a conjunction p1 && p2, or a sum of products p1 || p2.
//
// I have scoped the AND, OR, and Predicate subclasses of
// [SemanticContext] within the scope of this outer ``class''
type SemanticContext interface {
Equals(other Collectable[SemanticContext]) bool
Hash() int
@@ -80,7 +79,7 @@ func NewPredicate(ruleIndex, predIndex int, isCtxDependent bool) *Predicate {
var SemanticContextNone = NewPredicate(-1, -1, false)
func (p *Predicate) evalPrecedence(parser Recognizer, outerContext RuleContext) SemanticContext {
func (p *Predicate) evalPrecedence(_ Recognizer, _ RuleContext) SemanticContext {
return p
}
@@ -198,7 +197,7 @@ type AND struct {
func NewAND(a, b SemanticContext) *AND {
operands := NewJStore[SemanticContext, Comparator[SemanticContext]](semctxEqInst)
operands := NewJStore[SemanticContext, Comparator[SemanticContext]](semctxEqInst, SemanticContextCollection, "NewAND() operands")
if aa, ok := a.(*AND); ok {
for _, o := range aa.opnds {
operands.Put(o)
@@ -230,9 +229,7 @@ func NewAND(a, b SemanticContext) *AND {
vs := operands.Values()
opnds := make([]SemanticContext, len(vs))
for i, v := range vs {
opnds[i] = v.(SemanticContext)
}
copy(opnds, vs)
and := new(AND)
and.opnds = opnds
@@ -316,12 +313,12 @@ func (a *AND) Hash() int {
return murmurFinish(h, len(a.opnds))
}
func (a *OR) Hash() int {
h := murmurInit(41) // Init with a value different from AND
for _, op := range a.opnds {
func (o *OR) Hash() int {
h := murmurInit(41) // Init with o value different from AND
for _, op := range o.opnds {
h = murmurUpdate(h, op.Hash())
}
return murmurFinish(h, len(a.opnds))
return murmurFinish(h, len(o.opnds))
}
func (a *AND) String() string {
@@ -349,7 +346,7 @@ type OR struct {
func NewOR(a, b SemanticContext) *OR {
operands := NewJStore[SemanticContext, Comparator[SemanticContext]](semctxEqInst)
operands := NewJStore[SemanticContext, Comparator[SemanticContext]](semctxEqInst, SemanticContextCollection, "NewOR() operands")
if aa, ok := a.(*OR); ok {
for _, o := range aa.opnds {
operands.Put(o)
@@ -382,9 +379,7 @@ func NewOR(a, b SemanticContext) *OR {
vs := operands.Values()
opnds := make([]SemanticContext, len(vs))
for i, v := range vs {
opnds[i] = v.(SemanticContext)
}
copy(opnds, vs)
o := new(OR)
o.opnds = opnds

281
vendor/github.com/antlr4-go/antlr/v4/statistics.go generated vendored Normal file
View File

@@ -0,0 +1,281 @@
//go:build antlr.stats
package antlr
import (
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strconv"
"sync"
)
// This file allows the user to collect statistics about the runtime of the ANTLR runtime. It is not enabled by default
// and so incurs no time penalty. To enable it, you must build the runtime with the antlr.stats build tag.
//
// Tells various components to collect statistics - because it is only true when this file is included, it will
// allow the compiler to completely eliminate all the code that is only used when collecting statistics.
const collectStats = true
// goRunStats is a collection of all the various data the ANTLR runtime has collected about a particular run.
// It is exported so that it can be used by others to look for things that are not already looked for in the
// runtime statistics.
type goRunStats struct {
// jStats is a slice of all the [JStatRec] records that have been created, which is one for EVERY collection created
// during a run. It is exported so that it can be used by others to look for things that are not already looked for
// within this package.
//
jStats []*JStatRec
jStatsLock sync.RWMutex
topN int
topNByMax []*JStatRec
topNByUsed []*JStatRec
unusedCollections map[CollectionSource]int
counts map[CollectionSource]int
}
const (
collectionsFile = "collections"
)
var (
Statistics = &goRunStats{
topN: 10,
}
)
type statsOption func(*goRunStats) error
// Configure allows the statistics system to be configured as the user wants and override the defaults
func (s *goRunStats) Configure(options ...statsOption) error {
for _, option := range options {
err := option(s)
if err != nil {
return err
}
}
return nil
}
// WithTopN sets the number of things to list in the report when we are concerned with the top N things.
//
// For example, if you want to see the top 20 collections by size, you can do:
//
// antlr.Statistics.Configure(antlr.WithTopN(20))
func WithTopN(topN int) statsOption {
return func(s *goRunStats) error {
s.topN = topN
return nil
}
}
// Analyze looks through all the statistical records and computes all the outputs that might be useful to the user.
//
// The function gathers and analyzes a number of statistics about any particular run of
// an ANTLR generated recognizer. In the vast majority of cases, the statistics are only
// useful to maintainers of ANTLR itself, but they can be useful to users as well. They may be
// especially useful in tracking down bugs or performance problems when an ANTLR user could
// supply the output from this package, but cannot supply the grammar file(s) they are using, even
// privately to the maintainers.
//
// The statistics are gathered by the runtime itself, and are not gathered by the parser or lexer, but the user
// must call this function their selves to analyze the statistics. This is because none of the infrastructure is
// extant unless the calling program is built with the antlr.stats tag like so:
//
// go build -tags antlr.stats .
//
// When a program is built with the antlr.stats tag, the Statistics object is created and available outside
// the package. The user can then call the [Statistics.Analyze] function to analyze the statistics and then call the
// [Statistics.Report] function to report the statistics.
//
// Please forward any questions about this package to the ANTLR discussion groups on GitHub or send to them to
// me [Jim Idle] directly at jimi@idle.ws
//
// [Jim Idle]: https:://github.com/jim-idle
func (s *goRunStats) Analyze() {
// Look for anything that looks strange and record it in our local maps etc for the report to present it
//
s.CollectionAnomalies()
s.TopNCollections()
}
// TopNCollections looks through all the statistical records and gathers the top ten collections by size.
func (s *goRunStats) TopNCollections() {
// Let's sort the stat records by MaxSize
//
sort.Slice(s.jStats, func(i, j int) bool {
return s.jStats[i].MaxSize > s.jStats[j].MaxSize
})
for i := 0; i < len(s.jStats) && i < s.topN; i++ {
s.topNByMax = append(s.topNByMax, s.jStats[i])
}
// Sort by the number of times used
//
sort.Slice(s.jStats, func(i, j int) bool {
return s.jStats[i].Gets+s.jStats[i].Puts > s.jStats[j].Gets+s.jStats[j].Puts
})
for i := 0; i < len(s.jStats) && i < s.topN; i++ {
s.topNByUsed = append(s.topNByUsed, s.jStats[i])
}
}
// Report dumps a markdown formatted report of all the statistics collected during a run to the given dir output
// path, which should represent a directory. Generated files will be prefixed with the given prefix and will be
// given a type name such as `anomalies` and a time stamp such as `2021-09-01T12:34:56` and a .md suffix.
func (s *goRunStats) Report(dir string, prefix string) error {
isDir, err := isDirectory(dir)
switch {
case err != nil:
return err
case !isDir:
return fmt.Errorf("output directory `%s` is not a directory", dir)
}
s.reportCollections(dir, prefix)
// Clean out any old data in case the user forgets
//
s.Reset()
return nil
}
func (s *goRunStats) Reset() {
s.jStats = nil
s.topNByUsed = nil
s.topNByMax = nil
}
func (s *goRunStats) reportCollections(dir, prefix string) {
cname := filepath.Join(dir, ".asciidoctor")
// If the file doesn't exist, create it, or append to the file
f, err := os.OpenFile(cname, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal(err)
}
_, _ = f.WriteString(`// .asciidoctorconfig
++++
<style>
body {
font-family: "Quicksand", "Montserrat", "Helvetica";
background-color: black;
}
</style>
++++`)
_ = f.Close()
fname := filepath.Join(dir, prefix+"_"+"_"+collectionsFile+"_"+".adoc")
// If the file doesn't exist, create it, or append to the file
f, err = os.OpenFile(fname, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal(err)
}
defer func(f *os.File) {
err := f.Close()
if err != nil {
log.Fatal(err)
}
}(f)
_, _ = f.WriteString("= Collections for " + prefix + "\n\n")
_, _ = f.WriteString("== Summary\n")
if s.unusedCollections != nil {
_, _ = f.WriteString("=== Unused Collections\n")
_, _ = f.WriteString("Unused collections incur a penalty for allocation that makes them a candidate for either\n")
_, _ = f.WriteString(" removal or optimization. If you are using a collection that is not used, you should\n")
_, _ = f.WriteString(" consider removing it. If you are using a collection that is used, but not very often,\n")
_, _ = f.WriteString(" you should consider using lazy initialization to defer the allocation until it is\n")
_, _ = f.WriteString(" actually needed.\n\n")
_, _ = f.WriteString("\n.Unused collections\n")
_, _ = f.WriteString(`[cols="<3,>1"]` + "\n\n")
_, _ = f.WriteString("|===\n")
_, _ = f.WriteString("| Type | Count\n")
for k, v := range s.unusedCollections {
_, _ = f.WriteString("| " + CollectionDescriptors[k].SybolicName + " | " + strconv.Itoa(v) + "\n")
}
f.WriteString("|===\n\n")
}
_, _ = f.WriteString("\n.Summary of Collections\n")
_, _ = f.WriteString(`[cols="<3,>1"]` + "\n\n")
_, _ = f.WriteString("|===\n")
_, _ = f.WriteString("| Type | Count\n")
for k, v := range s.counts {
_, _ = f.WriteString("| " + CollectionDescriptors[k].SybolicName + " | " + strconv.Itoa(v) + "\n")
}
_, _ = f.WriteString("| Total | " + strconv.Itoa(len(s.jStats)) + "\n")
_, _ = f.WriteString("|===\n\n")
_, _ = f.WriteString("\n.Summary of Top " + strconv.Itoa(s.topN) + " Collections by MaxSize\n")
_, _ = f.WriteString(`[cols="<1,<3,>1,>1,>1,>1"]` + "\n\n")
_, _ = f.WriteString("|===\n")
_, _ = f.WriteString("| Source | Description | MaxSize | EndSize | Puts | Gets\n")
for _, c := range s.topNByMax {
_, _ = f.WriteString("| " + CollectionDescriptors[c.Source].SybolicName + "\n")
_, _ = f.WriteString("| " + c.Description + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.MaxSize) + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.CurSize) + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.Puts) + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.Gets) + "\n")
_, _ = f.WriteString("\n")
}
_, _ = f.WriteString("|===\n\n")
_, _ = f.WriteString("\n.Summary of Top " + strconv.Itoa(s.topN) + " Collections by Access\n")
_, _ = f.WriteString(`[cols="<1,<3,>1,>1,>1,>1,>1"]` + "\n\n")
_, _ = f.WriteString("|===\n")
_, _ = f.WriteString("| Source | Description | MaxSize | EndSize | Puts | Gets | P+G\n")
for _, c := range s.topNByUsed {
_, _ = f.WriteString("| " + CollectionDescriptors[c.Source].SybolicName + "\n")
_, _ = f.WriteString("| " + c.Description + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.MaxSize) + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.CurSize) + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.Puts) + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.Gets) + "\n")
_, _ = f.WriteString("| " + strconv.Itoa(c.Gets+c.Puts) + "\n")
_, _ = f.WriteString("\n")
}
_, _ = f.WriteString("|===\n\n")
}
// AddJStatRec adds a [JStatRec] record to the [goRunStats] collection when build runtimeConfig antlr.stats is enabled.
func (s *goRunStats) AddJStatRec(rec *JStatRec) {
s.jStatsLock.Lock()
defer s.jStatsLock.Unlock()
s.jStats = append(s.jStats, rec)
}
// CollectionAnomalies looks through all the statistical records and gathers any anomalies that have been found.
func (s *goRunStats) CollectionAnomalies() {
s.jStatsLock.RLock()
defer s.jStatsLock.RUnlock()
s.counts = make(map[CollectionSource]int, len(s.jStats))
for _, c := range s.jStats {
// Accumlate raw counts
//
s.counts[c.Source]++
// Look for allocated but unused collections and count them
if c.MaxSize == 0 && c.Puts == 0 {
if s.unusedCollections == nil {
s.unusedCollections = make(map[CollectionSource]int)
}
s.unusedCollections[c.Source]++
}
if c.MaxSize > 6000 {
fmt.Println("Collection ", c.Description, "accumulated a max size of ", c.MaxSize, " - this is probably too large and indicates a poorly formed grammar")
}
}
}

23
vendor/github.com/antlr4-go/antlr/v4/stats_data.go generated vendored Normal file
View File

@@ -0,0 +1,23 @@
package antlr
// A JStatRec is a record of a particular use of a [JStore], [JMap] or JPCMap] collection. Typically, it will be
// used to look for unused collections that wre allocated anyway, problems with hash bucket clashes, and anomalies
// such as huge numbers of Gets with no entries found GetNoEnt. You can refer to the CollectionAnomalies() function
// for ideas on what can be gleaned from these statistics about collections.
type JStatRec struct {
Source CollectionSource
MaxSize int
CurSize int
Gets int
GetHits int
GetMisses int
GetHashConflicts int
GetNoEnt int
Puts int
PutHits int
PutMisses int
PutHashConflicts int
MaxSlotSize int
Description string
CreateStack []byte
}

View File

@@ -35,6 +35,8 @@ type Token interface {
GetTokenSource() TokenSource
GetInputStream() CharStream
String() string
}
type BaseToken struct {
@@ -53,7 +55,7 @@ type BaseToken struct {
const (
TokenInvalidType = 0
// During lookahead operations, this "token" signifies we hit rule end ATN state
// TokenEpsilon - during lookahead operations, this "token" signifies we hit the rule end [ATN] state
// and did not follow it despite needing to.
TokenEpsilon = -2
@@ -61,15 +63,16 @@ const (
TokenEOF = -1
// All tokens go to the parser (unless Skip() is called in that rule)
// TokenDefaultChannel is the default channel upon which tokens are sent to the parser.
//
// All tokens go to the parser (unless [Skip] is called in the lexer rule)
// on a particular "channel". The parser tunes to a particular channel
// so that whitespace etc... can go to the parser on a "hidden" channel.
TokenDefaultChannel = 0
// Anything on different channel than DEFAULT_CHANNEL is not parsed
// by parser.
// TokenHiddenChannel defines the normal hidden channel - the parser wil not see tokens that are not on [TokenDefaultChannel].
//
// Anything on a different channel than TokenDefaultChannel is not parsed by parser.
TokenHiddenChannel = 1
)
@@ -118,21 +121,22 @@ func (b *BaseToken) GetInputStream() CharStream {
}
type CommonToken struct {
*BaseToken
BaseToken
}
func NewCommonToken(source *TokenSourceCharStreamPair, tokenType, channel, start, stop int) *CommonToken {
t := new(CommonToken)
t := &CommonToken{
BaseToken: BaseToken{
source: source,
tokenType: tokenType,
channel: channel,
start: start,
stop: stop,
tokenIndex: -1,
},
}
t.BaseToken = new(BaseToken)
t.source = source
t.tokenType = tokenType
t.channel = channel
t.start = start
t.stop = stop
t.tokenIndex = -1
if t.source.tokenSource != nil {
t.line = source.tokenSource.GetLine()
t.column = source.tokenSource.GetCharPositionInLine()

View File

@@ -8,13 +8,14 @@ type TokenStream interface {
IntStream
LT(k int) Token
Reset()
Get(index int) Token
GetTokenSource() TokenSource
SetTokenSource(TokenSource)
GetAllText() string
GetTextFromInterval(*Interval) string
GetTextFromInterval(Interval) string
GetTextFromRuleContext(RuleContext) string
GetTextFromTokens(Token, Token) string
}

View File

@@ -86,14 +86,15 @@ import (
// first example shows.</p>
const (
Default_Program_Name = "default"
Program_Init_Size = 100
Min_Token_Index = 0
DefaultProgramName = "default"
ProgramInitSize = 100
MinTokenIndex = 0
)
// Define the rewrite operation hierarchy
type RewriteOperation interface {
// Execute the rewrite operation by possibly adding to the buffer.
// Return the index of the next token to operate on.
Execute(buffer *bytes.Buffer) int
@@ -112,19 +113,19 @@ type RewriteOperation interface {
type BaseRewriteOperation struct {
//Current index of rewrites list
instruction_index int
instructionIndex int
//Token buffer index
index int
//Substitution text
text string
//Actual operation name
op_name string
opName string
//Pointer to token steam
tokens TokenStream
}
func (op *BaseRewriteOperation) GetInstructionIndex() int {
return op.instruction_index
return op.instructionIndex
}
func (op *BaseRewriteOperation) GetIndex() int {
@@ -136,7 +137,7 @@ func (op *BaseRewriteOperation) GetText() string {
}
func (op *BaseRewriteOperation) GetOpName() string {
return op.op_name
return op.opName
}
func (op *BaseRewriteOperation) GetTokens() TokenStream {
@@ -144,7 +145,7 @@ func (op *BaseRewriteOperation) GetTokens() TokenStream {
}
func (op *BaseRewriteOperation) SetInstructionIndex(val int) {
op.instruction_index = val
op.instructionIndex = val
}
func (op *BaseRewriteOperation) SetIndex(val int) {
@@ -156,20 +157,20 @@ func (op *BaseRewriteOperation) SetText(val string) {
}
func (op *BaseRewriteOperation) SetOpName(val string) {
op.op_name = val
op.opName = val
}
func (op *BaseRewriteOperation) SetTokens(val TokenStream) {
op.tokens = val
}
func (op *BaseRewriteOperation) Execute(buffer *bytes.Buffer) int {
func (op *BaseRewriteOperation) Execute(_ *bytes.Buffer) int {
return op.index
}
func (op *BaseRewriteOperation) String() string {
return fmt.Sprintf("<%s@%d:\"%s\">",
op.op_name,
op.opName,
op.tokens.Get(op.GetIndex()),
op.text,
)
@@ -182,10 +183,10 @@ type InsertBeforeOp struct {
func NewInsertBeforeOp(index int, text string, stream TokenStream) *InsertBeforeOp {
return &InsertBeforeOp{BaseRewriteOperation: BaseRewriteOperation{
index: index,
text: text,
op_name: "InsertBeforeOp",
tokens: stream,
index: index,
text: text,
opName: "InsertBeforeOp",
tokens: stream,
}}
}
@@ -201,20 +202,21 @@ func (op *InsertBeforeOp) String() string {
return op.BaseRewriteOperation.String()
}
// Distinguish between insert after/before to do the "insert afters"
// first and then the "insert befores" at same index. Implementation
// of "insert after" is "insert before index+1".
// InsertAfterOp distinguishes between insert after/before to do the "insert after" instructions
// first and then the "insert before" instructions at same index. Implementation
// of "insert after" is "insert before index+1".
type InsertAfterOp struct {
BaseRewriteOperation
}
func NewInsertAfterOp(index int, text string, stream TokenStream) *InsertAfterOp {
return &InsertAfterOp{BaseRewriteOperation: BaseRewriteOperation{
index: index + 1,
text: text,
tokens: stream,
}}
return &InsertAfterOp{
BaseRewriteOperation: BaseRewriteOperation{
index: index + 1,
text: text,
tokens: stream,
},
}
}
func (op *InsertAfterOp) Execute(buffer *bytes.Buffer) int {
@@ -229,7 +231,7 @@ func (op *InsertAfterOp) String() string {
return op.BaseRewriteOperation.String()
}
// I'm going to try replacing range from x..y with (y-x)+1 ReplaceOp
// ReplaceOp tries to replace range from x..y with (y-x)+1 ReplaceOp
// instructions.
type ReplaceOp struct {
BaseRewriteOperation
@@ -239,10 +241,10 @@ type ReplaceOp struct {
func NewReplaceOp(from, to int, text string, stream TokenStream) *ReplaceOp {
return &ReplaceOp{
BaseRewriteOperation: BaseRewriteOperation{
index: from,
text: text,
op_name: "ReplaceOp",
tokens: stream,
index: from,
text: text,
opName: "ReplaceOp",
tokens: stream,
},
LastIndex: to,
}
@@ -270,17 +272,17 @@ type TokenStreamRewriter struct {
// You may have multiple, named streams of rewrite operations.
// I'm calling these things "programs."
// Maps String (name) &rarr; rewrite (List)
programs map[string][]RewriteOperation
last_rewrite_token_indexes map[string]int
programs map[string][]RewriteOperation
lastRewriteTokenIndexes map[string]int
}
func NewTokenStreamRewriter(tokens TokenStream) *TokenStreamRewriter {
return &TokenStreamRewriter{
tokens: tokens,
programs: map[string][]RewriteOperation{
Default_Program_Name: make([]RewriteOperation, 0, Program_Init_Size),
DefaultProgramName: make([]RewriteOperation, 0, ProgramInitSize),
},
last_rewrite_token_indexes: map[string]int{},
lastRewriteTokenIndexes: map[string]int{},
}
}
@@ -291,110 +293,110 @@ func (tsr *TokenStreamRewriter) GetTokenStream() TokenStream {
// Rollback the instruction stream for a program so that
// the indicated instruction (via instructionIndex) is no
// longer in the stream. UNTESTED!
func (tsr *TokenStreamRewriter) Rollback(program_name string, instruction_index int) {
is, ok := tsr.programs[program_name]
func (tsr *TokenStreamRewriter) Rollback(programName string, instructionIndex int) {
is, ok := tsr.programs[programName]
if ok {
tsr.programs[program_name] = is[Min_Token_Index:instruction_index]
tsr.programs[programName] = is[MinTokenIndex:instructionIndex]
}
}
func (tsr *TokenStreamRewriter) RollbackDefault(instruction_index int) {
tsr.Rollback(Default_Program_Name, instruction_index)
func (tsr *TokenStreamRewriter) RollbackDefault(instructionIndex int) {
tsr.Rollback(DefaultProgramName, instructionIndex)
}
// Reset the program so that no instructions exist
func (tsr *TokenStreamRewriter) DeleteProgram(program_name string) {
tsr.Rollback(program_name, Min_Token_Index) //TODO: double test on that cause lower bound is not included
// DeleteProgram Reset the program so that no instructions exist
func (tsr *TokenStreamRewriter) DeleteProgram(programName string) {
tsr.Rollback(programName, MinTokenIndex) //TODO: double test on that cause lower bound is not included
}
func (tsr *TokenStreamRewriter) DeleteProgramDefault() {
tsr.DeleteProgram(Default_Program_Name)
tsr.DeleteProgram(DefaultProgramName)
}
func (tsr *TokenStreamRewriter) InsertAfter(program_name string, index int, text string) {
func (tsr *TokenStreamRewriter) InsertAfter(programName string, index int, text string) {
// to insert after, just insert before next index (even if past end)
var op RewriteOperation = NewInsertAfterOp(index, text, tsr.tokens)
rewrites := tsr.GetProgram(program_name)
rewrites := tsr.GetProgram(programName)
op.SetInstructionIndex(len(rewrites))
tsr.AddToProgram(program_name, op)
tsr.AddToProgram(programName, op)
}
func (tsr *TokenStreamRewriter) InsertAfterDefault(index int, text string) {
tsr.InsertAfter(Default_Program_Name, index, text)
tsr.InsertAfter(DefaultProgramName, index, text)
}
func (tsr *TokenStreamRewriter) InsertAfterToken(program_name string, token Token, text string) {
tsr.InsertAfter(program_name, token.GetTokenIndex(), text)
func (tsr *TokenStreamRewriter) InsertAfterToken(programName string, token Token, text string) {
tsr.InsertAfter(programName, token.GetTokenIndex(), text)
}
func (tsr *TokenStreamRewriter) InsertBefore(program_name string, index int, text string) {
func (tsr *TokenStreamRewriter) InsertBefore(programName string, index int, text string) {
var op RewriteOperation = NewInsertBeforeOp(index, text, tsr.tokens)
rewrites := tsr.GetProgram(program_name)
rewrites := tsr.GetProgram(programName)
op.SetInstructionIndex(len(rewrites))
tsr.AddToProgram(program_name, op)
tsr.AddToProgram(programName, op)
}
func (tsr *TokenStreamRewriter) InsertBeforeDefault(index int, text string) {
tsr.InsertBefore(Default_Program_Name, index, text)
tsr.InsertBefore(DefaultProgramName, index, text)
}
func (tsr *TokenStreamRewriter) InsertBeforeToken(program_name string, token Token, text string) {
tsr.InsertBefore(program_name, token.GetTokenIndex(), text)
func (tsr *TokenStreamRewriter) InsertBeforeToken(programName string, token Token, text string) {
tsr.InsertBefore(programName, token.GetTokenIndex(), text)
}
func (tsr *TokenStreamRewriter) Replace(program_name string, from, to int, text string) {
func (tsr *TokenStreamRewriter) Replace(programName string, from, to int, text string) {
if from > to || from < 0 || to < 0 || to >= tsr.tokens.Size() {
panic(fmt.Sprintf("replace: range invalid: %d..%d(size=%d)",
from, to, tsr.tokens.Size()))
}
var op RewriteOperation = NewReplaceOp(from, to, text, tsr.tokens)
rewrites := tsr.GetProgram(program_name)
rewrites := tsr.GetProgram(programName)
op.SetInstructionIndex(len(rewrites))
tsr.AddToProgram(program_name, op)
tsr.AddToProgram(programName, op)
}
func (tsr *TokenStreamRewriter) ReplaceDefault(from, to int, text string) {
tsr.Replace(Default_Program_Name, from, to, text)
tsr.Replace(DefaultProgramName, from, to, text)
}
func (tsr *TokenStreamRewriter) ReplaceDefaultPos(index int, text string) {
tsr.ReplaceDefault(index, index, text)
}
func (tsr *TokenStreamRewriter) ReplaceToken(program_name string, from, to Token, text string) {
tsr.Replace(program_name, from.GetTokenIndex(), to.GetTokenIndex(), text)
func (tsr *TokenStreamRewriter) ReplaceToken(programName string, from, to Token, text string) {
tsr.Replace(programName, from.GetTokenIndex(), to.GetTokenIndex(), text)
}
func (tsr *TokenStreamRewriter) ReplaceTokenDefault(from, to Token, text string) {
tsr.ReplaceToken(Default_Program_Name, from, to, text)
tsr.ReplaceToken(DefaultProgramName, from, to, text)
}
func (tsr *TokenStreamRewriter) ReplaceTokenDefaultPos(index Token, text string) {
tsr.ReplaceTokenDefault(index, index, text)
}
func (tsr *TokenStreamRewriter) Delete(program_name string, from, to int) {
tsr.Replace(program_name, from, to, "")
func (tsr *TokenStreamRewriter) Delete(programName string, from, to int) {
tsr.Replace(programName, from, to, "")
}
func (tsr *TokenStreamRewriter) DeleteDefault(from, to int) {
tsr.Delete(Default_Program_Name, from, to)
tsr.Delete(DefaultProgramName, from, to)
}
func (tsr *TokenStreamRewriter) DeleteDefaultPos(index int) {
tsr.DeleteDefault(index, index)
}
func (tsr *TokenStreamRewriter) DeleteToken(program_name string, from, to Token) {
tsr.ReplaceToken(program_name, from, to, "")
func (tsr *TokenStreamRewriter) DeleteToken(programName string, from, to Token) {
tsr.ReplaceToken(programName, from, to, "")
}
func (tsr *TokenStreamRewriter) DeleteTokenDefault(from, to Token) {
tsr.DeleteToken(Default_Program_Name, from, to)
tsr.DeleteToken(DefaultProgramName, from, to)
}
func (tsr *TokenStreamRewriter) GetLastRewriteTokenIndex(program_name string) int {
i, ok := tsr.last_rewrite_token_indexes[program_name]
func (tsr *TokenStreamRewriter) GetLastRewriteTokenIndex(programName string) int {
i, ok := tsr.lastRewriteTokenIndexes[programName]
if !ok {
return -1
}
@@ -402,15 +404,15 @@ func (tsr *TokenStreamRewriter) GetLastRewriteTokenIndex(program_name string) in
}
func (tsr *TokenStreamRewriter) GetLastRewriteTokenIndexDefault() int {
return tsr.GetLastRewriteTokenIndex(Default_Program_Name)
return tsr.GetLastRewriteTokenIndex(DefaultProgramName)
}
func (tsr *TokenStreamRewriter) SetLastRewriteTokenIndex(program_name string, i int) {
tsr.last_rewrite_token_indexes[program_name] = i
func (tsr *TokenStreamRewriter) SetLastRewriteTokenIndex(programName string, i int) {
tsr.lastRewriteTokenIndexes[programName] = i
}
func (tsr *TokenStreamRewriter) InitializeProgram(name string) []RewriteOperation {
is := make([]RewriteOperation, 0, Program_Init_Size)
is := make([]RewriteOperation, 0, ProgramInitSize)
tsr.programs[name] = is
return is
}
@@ -429,24 +431,24 @@ func (tsr *TokenStreamRewriter) GetProgram(name string) []RewriteOperation {
return is
}
// Return the text from the original tokens altered per the
// GetTextDefault returns the text from the original tokens altered per the
// instructions given to this rewriter.
func (tsr *TokenStreamRewriter) GetTextDefault() string {
return tsr.GetText(
Default_Program_Name,
DefaultProgramName,
NewInterval(0, tsr.tokens.Size()-1))
}
// Return the text from the original tokens altered per the
// GetText returns the text from the original tokens altered per the
// instructions given to this rewriter.
func (tsr *TokenStreamRewriter) GetText(program_name string, interval *Interval) string {
rewrites := tsr.programs[program_name]
func (tsr *TokenStreamRewriter) GetText(programName string, interval Interval) string {
rewrites := tsr.programs[programName]
start := interval.Start
stop := interval.Stop
// ensure start/end are in range
stop = min(stop, tsr.tokens.Size()-1)
start = max(start, 0)
if rewrites == nil || len(rewrites) == 0 {
if len(rewrites) == 0 {
return tsr.tokens.GetTextFromInterval(interval) // no instructions to execute
}
buf := bytes.Buffer{}
@@ -482,11 +484,13 @@ func (tsr *TokenStreamRewriter) GetText(program_name string, interval *Interval)
return buf.String()
}
// We need to combine operations and report invalid operations (like
// overlapping replaces that are not completed nested). Inserts to
// same index need to be combined etc... Here are the cases:
// reduceToSingleOperationPerIndex combines operations and report invalid operations (like
// overlapping replaces that are not completed nested). Inserts to
// same index need to be combined etc...
//
// I.i.u I.j.v leave alone, nonoverlapping
// Here are the cases:
//
// I.i.u I.j.v leave alone, non-overlapping
// I.i.u I.i.v combine: Iivu
//
// R.i-j.u R.x-y.v | i-j in x-y delete first R
@@ -498,38 +502,38 @@ func (tsr *TokenStreamRewriter) GetText(program_name string, interval *Interval)
// D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right)
//
// I.i.u R.x-y.v | i in (x+1)-y delete I (since insert before
// we're not deleting i)
// I.i.u R.x-y.v | i not in (x+1)-y leave alone, nonoverlapping
// we're not deleting i)
// I.i.u R.x-y.v | i not in (x+1)-y leave alone, non-overlapping
// R.x-y.v I.i.u | i in x-y ERROR
// R.x-y.v I.x.u R.x-y.uv (combine, delete I)
// R.x-y.v I.i.u | i not in x-y leave alone, nonoverlapping
// R.x-y.v I.i.u | i not in x-y leave alone, non-overlapping
//
// I.i.u = insert u before op @ index i
// R.x-y.u = replace x-y indexed tokens with u
//
// First we need to examine replaces. For any replace op:
// First we need to examine replaces. For any replace op:
//
// 1. wipe out any insertions before op within that range.
// 2. Drop any replace op before that is contained completely within
// that range.
// 3. Throw exception upon boundary overlap with any previous replace.
// 1. wipe out any insertions before op within that range.
// 2. Drop any replace op before that is contained completely within
// that range.
// 3. Throw exception upon boundary overlap with any previous replace.
//
// Then we can deal with inserts:
// Then we can deal with inserts:
//
// 1. for any inserts to same index, combine even if not adjacent.
// 2. for any prior replace with same left boundary, combine this
// insert with replace and delete this replace.
// 3. throw exception if index in same range as previous replace
// 1. for any inserts to same index, combine even if not adjacent.
// 2. for any prior replace with same left boundary, combine this
// insert with replace and delete this 'replace'.
// 3. throw exception if index in same range as previous replace
//
// Don't actually delete; make op null in list. Easier to walk list.
// Later we can throw as we add to index &rarr; op map.
// Don't actually delete; make op null in list. Easier to walk list.
// Later we can throw as we add to index &rarr; op map.
//
// Note that I.2 R.2-2 will wipe out I.2 even though, technically, the
// inserted stuff would be before the replace range. But, if you
// add tokens in front of a method body '{' and then delete the method
// body, I think the stuff before the '{' you added should disappear too.
// Note that I.2 R.2-2 will wipe out I.2 even though, technically, the
// inserted stuff would be before the 'replace' range. But, if you
// add tokens in front of a method body '{' and then delete the method
// body, I think the stuff before the '{' you added should disappear too.
//
// Return a map from token index to operation.
// The func returns a map from token index to operation.
func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]RewriteOperation {
// WALK REPLACES
for i := 0; i < len(rewrites); i++ {
@@ -547,7 +551,7 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
if iop.index == rop.index {
// E.g., insert before 2, delete 2..2; update replace
// text to include insert before, kill insert
rewrites[iop.instruction_index] = nil
rewrites[iop.instructionIndex] = nil
if rop.text != "" {
rop.text = iop.text + rop.text
} else {
@@ -555,7 +559,7 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
}
} else if iop.index > rop.index && iop.index <= rop.LastIndex {
// delete insert as it's a no-op.
rewrites[iop.instruction_index] = nil
rewrites[iop.instructionIndex] = nil
}
}
}
@@ -564,7 +568,7 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
if prevop, ok := rewrites[j].(*ReplaceOp); ok {
if prevop.index >= rop.index && prevop.LastIndex <= rop.LastIndex {
// delete replace as it's a no-op.
rewrites[prevop.instruction_index] = nil
rewrites[prevop.instructionIndex] = nil
continue
}
// throw exception unless disjoint or identical
@@ -572,10 +576,9 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
// Delete special case of replace (text==null):
// D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right)
if prevop.text == "" && rop.text == "" && !disjoint {
rewrites[prevop.instruction_index] = nil
rewrites[prevop.instructionIndex] = nil
rop.index = min(prevop.index, rop.index)
rop.LastIndex = max(prevop.LastIndex, rop.LastIndex)
println("new rop" + rop.String()) //TODO: remove console write, taken from Java version
} else if !disjoint {
panic("replace op boundaries of " + rop.String() + " overlap with previous " + prevop.String())
}
@@ -607,7 +610,7 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
if prevIop, ok := rewrites[j].(*InsertBeforeOp); ok {
if prevIop.index == iop.GetIndex() {
iop.SetText(iop.GetText() + prevIop.text)
rewrites[prevIop.instruction_index] = nil
rewrites[prevIop.instructionIndex] = nil
}
}
}

View File

@@ -72,7 +72,7 @@ func (t *BaseTransition) getSerializationType() int {
return t.serializationType
}
func (t *BaseTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *BaseTransition) Matches(_, _, _ int) bool {
panic("Not implemented")
}
@@ -89,6 +89,7 @@ const (
TransitionPRECEDENCE = 10
)
//goland:noinspection GoUnusedGlobalVariable
var TransitionserializationNames = []string{
"INVALID",
"EPSILON",
@@ -127,19 +128,22 @@ var TransitionserializationNames = []string{
// TransitionPRECEDENCE
//}
// AtomTransition
// TODO: make all transitions sets? no, should remove set edges
type AtomTransition struct {
*BaseTransition
BaseTransition
}
func NewAtomTransition(target ATNState, intervalSet int) *AtomTransition {
t := new(AtomTransition)
t.BaseTransition = NewBaseTransition(target)
t.label = intervalSet // The token type or character value or, signifies special intervalSet.
t := &AtomTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionATOM,
label: intervalSet,
isEpsilon: false,
},
}
t.intervalSet = t.makeLabel()
t.serializationType = TransitionATOM
return t
}
@@ -150,7 +154,7 @@ func (t *AtomTransition) makeLabel() *IntervalSet {
return s
}
func (t *AtomTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *AtomTransition) Matches(symbol, _, _ int) bool {
return t.label == symbol
}
@@ -159,48 +163,45 @@ func (t *AtomTransition) String() string {
}
type RuleTransition struct {
*BaseTransition
BaseTransition
followState ATNState
ruleIndex, precedence int
}
func NewRuleTransition(ruleStart ATNState, ruleIndex, precedence int, followState ATNState) *RuleTransition {
t := new(RuleTransition)
t.BaseTransition = NewBaseTransition(ruleStart)
t.ruleIndex = ruleIndex
t.precedence = precedence
t.followState = followState
t.serializationType = TransitionRULE
t.isEpsilon = true
return t
return &RuleTransition{
BaseTransition: BaseTransition{
target: ruleStart,
isEpsilon: true,
serializationType: TransitionRULE,
},
ruleIndex: ruleIndex,
precedence: precedence,
followState: followState,
}
}
func (t *RuleTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *RuleTransition) Matches(_, _, _ int) bool {
return false
}
type EpsilonTransition struct {
*BaseTransition
BaseTransition
outermostPrecedenceReturn int
}
func NewEpsilonTransition(target ATNState, outermostPrecedenceReturn int) *EpsilonTransition {
t := new(EpsilonTransition)
t.BaseTransition = NewBaseTransition(target)
t.serializationType = TransitionEPSILON
t.isEpsilon = true
t.outermostPrecedenceReturn = outermostPrecedenceReturn
return t
return &EpsilonTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionEPSILON,
isEpsilon: true,
},
outermostPrecedenceReturn: outermostPrecedenceReturn,
}
}
func (t *EpsilonTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *EpsilonTransition) Matches(_, _, _ int) bool {
return false
}
@@ -209,19 +210,20 @@ func (t *EpsilonTransition) String() string {
}
type RangeTransition struct {
*BaseTransition
BaseTransition
start, stop int
}
func NewRangeTransition(target ATNState, start, stop int) *RangeTransition {
t := new(RangeTransition)
t.BaseTransition = NewBaseTransition(target)
t.serializationType = TransitionRANGE
t.start = start
t.stop = stop
t := &RangeTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionRANGE,
isEpsilon: false,
},
start: start,
stop: stop,
}
t.intervalSet = t.makeLabel()
return t
}
@@ -232,7 +234,7 @@ func (t *RangeTransition) makeLabel() *IntervalSet {
return s
}
func (t *RangeTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *RangeTransition) Matches(symbol, _, _ int) bool {
return symbol >= t.start && symbol <= t.stop
}
@@ -252,40 +254,41 @@ type AbstractPredicateTransition interface {
}
type BaseAbstractPredicateTransition struct {
*BaseTransition
BaseTransition
}
func NewBasePredicateTransition(target ATNState) *BaseAbstractPredicateTransition {
t := new(BaseAbstractPredicateTransition)
t.BaseTransition = NewBaseTransition(target)
return t
return &BaseAbstractPredicateTransition{
BaseTransition: BaseTransition{
target: target,
},
}
}
func (a *BaseAbstractPredicateTransition) IAbstractPredicateTransitionFoo() {}
type PredicateTransition struct {
*BaseAbstractPredicateTransition
BaseAbstractPredicateTransition
isCtxDependent bool
ruleIndex, predIndex int
}
func NewPredicateTransition(target ATNState, ruleIndex, predIndex int, isCtxDependent bool) *PredicateTransition {
t := new(PredicateTransition)
t.BaseAbstractPredicateTransition = NewBasePredicateTransition(target)
t.serializationType = TransitionPREDICATE
t.ruleIndex = ruleIndex
t.predIndex = predIndex
t.isCtxDependent = isCtxDependent // e.g., $i ref in pred
t.isEpsilon = true
return t
return &PredicateTransition{
BaseAbstractPredicateTransition: BaseAbstractPredicateTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionPREDICATE,
isEpsilon: true,
},
},
isCtxDependent: isCtxDependent,
ruleIndex: ruleIndex,
predIndex: predIndex,
}
}
func (t *PredicateTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *PredicateTransition) Matches(_, _, _ int) bool {
return false
}
@@ -298,26 +301,25 @@ func (t *PredicateTransition) String() string {
}
type ActionTransition struct {
*BaseTransition
BaseTransition
isCtxDependent bool
ruleIndex, actionIndex, predIndex int
}
func NewActionTransition(target ATNState, ruleIndex, actionIndex int, isCtxDependent bool) *ActionTransition {
t := new(ActionTransition)
t.BaseTransition = NewBaseTransition(target)
t.serializationType = TransitionACTION
t.ruleIndex = ruleIndex
t.actionIndex = actionIndex
t.isCtxDependent = isCtxDependent // e.g., $i ref in pred
t.isEpsilon = true
return t
return &ActionTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionACTION,
isEpsilon: true,
},
isCtxDependent: isCtxDependent,
ruleIndex: ruleIndex,
actionIndex: actionIndex,
}
}
func (t *ActionTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *ActionTransition) Matches(_, _, _ int) bool {
return false
}
@@ -326,26 +328,27 @@ func (t *ActionTransition) String() string {
}
type SetTransition struct {
*BaseTransition
BaseTransition
}
func NewSetTransition(target ATNState, set *IntervalSet) *SetTransition {
t := &SetTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionSET,
},
}
t := new(SetTransition)
t.BaseTransition = NewBaseTransition(target)
t.serializationType = TransitionSET
if set != nil {
t.intervalSet = set
} else {
t.intervalSet = NewIntervalSet()
t.intervalSet.addOne(TokenInvalidType)
}
return t
}
func (t *SetTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *SetTransition) Matches(symbol, _, _ int) bool {
return t.intervalSet.contains(symbol)
}
@@ -354,16 +357,24 @@ func (t *SetTransition) String() string {
}
type NotSetTransition struct {
*SetTransition
SetTransition
}
func NewNotSetTransition(target ATNState, set *IntervalSet) *NotSetTransition {
t := new(NotSetTransition)
t.SetTransition = NewSetTransition(target, set)
t.serializationType = TransitionNOTSET
t := &NotSetTransition{
SetTransition: SetTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionNOTSET,
},
},
}
if set != nil {
t.intervalSet = set
} else {
t.intervalSet = NewIntervalSet()
t.intervalSet.addOne(TokenInvalidType)
}
return t
}
@@ -377,16 +388,16 @@ func (t *NotSetTransition) String() string {
}
type WildcardTransition struct {
*BaseTransition
BaseTransition
}
func NewWildcardTransition(target ATNState) *WildcardTransition {
t := new(WildcardTransition)
t.BaseTransition = NewBaseTransition(target)
t.serializationType = TransitionWILDCARD
return t
return &WildcardTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionWILDCARD,
},
}
}
func (t *WildcardTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
@@ -398,24 +409,24 @@ func (t *WildcardTransition) String() string {
}
type PrecedencePredicateTransition struct {
*BaseAbstractPredicateTransition
BaseAbstractPredicateTransition
precedence int
}
func NewPrecedencePredicateTransition(target ATNState, precedence int) *PrecedencePredicateTransition {
t := new(PrecedencePredicateTransition)
t.BaseAbstractPredicateTransition = NewBasePredicateTransition(target)
t.serializationType = TransitionPRECEDENCE
t.precedence = precedence
t.isEpsilon = true
return t
return &PrecedencePredicateTransition{
BaseAbstractPredicateTransition: BaseAbstractPredicateTransition{
BaseTransition: BaseTransition{
target: target,
serializationType: TransitionPRECEDENCE,
isEpsilon: true,
},
},
precedence: precedence,
}
}
func (t *PrecedencePredicateTransition) Matches(symbol, minVocabSymbol, maxVocabSymbol int) bool {
func (t *PrecedencePredicateTransition) Matches(_, _, _ int) bool {
return false
}

View File

@@ -21,29 +21,23 @@ type Tree interface {
type SyntaxTree interface {
Tree
GetSourceInterval() *Interval
GetSourceInterval() Interval
}
type ParseTree interface {
SyntaxTree
Accept(Visitor ParseTreeVisitor) interface{}
GetText() string
ToStringTree([]string, Recognizer) string
}
type RuleNode interface {
ParseTree
GetRuleContext() RuleContext
GetBaseRuleContext() *BaseRuleContext
}
type TerminalNode interface {
ParseTree
GetSymbol() Token
}
@@ -64,12 +58,12 @@ type BaseParseTreeVisitor struct{}
var _ ParseTreeVisitor = &BaseParseTreeVisitor{}
func (v *BaseParseTreeVisitor) Visit(tree ParseTree) interface{} { return tree.Accept(v) }
func (v *BaseParseTreeVisitor) VisitChildren(node RuleNode) interface{} { return nil }
func (v *BaseParseTreeVisitor) VisitTerminal(node TerminalNode) interface{} { return nil }
func (v *BaseParseTreeVisitor) VisitErrorNode(node ErrorNode) interface{} { return nil }
func (v *BaseParseTreeVisitor) Visit(tree ParseTree) interface{} { return tree.Accept(v) }
func (v *BaseParseTreeVisitor) VisitChildren(_ RuleNode) interface{} { return nil }
func (v *BaseParseTreeVisitor) VisitTerminal(_ TerminalNode) interface{} { return nil }
func (v *BaseParseTreeVisitor) VisitErrorNode(_ ErrorNode) interface{} { return nil }
// TODO
// TODO: Implement this?
//func (this ParseTreeVisitor) Visit(ctx) {
// if (Utils.isArray(ctx)) {
// self := this
@@ -101,15 +95,14 @@ type BaseParseTreeListener struct{}
var _ ParseTreeListener = &BaseParseTreeListener{}
func (l *BaseParseTreeListener) VisitTerminal(node TerminalNode) {}
func (l *BaseParseTreeListener) VisitErrorNode(node ErrorNode) {}
func (l *BaseParseTreeListener) EnterEveryRule(ctx ParserRuleContext) {}
func (l *BaseParseTreeListener) ExitEveryRule(ctx ParserRuleContext) {}
func (l *BaseParseTreeListener) VisitTerminal(_ TerminalNode) {}
func (l *BaseParseTreeListener) VisitErrorNode(_ ErrorNode) {}
func (l *BaseParseTreeListener) EnterEveryRule(_ ParserRuleContext) {}
func (l *BaseParseTreeListener) ExitEveryRule(_ ParserRuleContext) {}
type TerminalNodeImpl struct {
parentCtx RuleContext
symbol Token
symbol Token
}
var _ TerminalNode = &TerminalNodeImpl{}
@@ -123,7 +116,7 @@ func NewTerminalNodeImpl(symbol Token) *TerminalNodeImpl {
return tn
}
func (t *TerminalNodeImpl) GetChild(i int) Tree {
func (t *TerminalNodeImpl) GetChild(_ int) Tree {
return nil
}
@@ -131,7 +124,7 @@ func (t *TerminalNodeImpl) GetChildren() []Tree {
return nil
}
func (t *TerminalNodeImpl) SetChildren(tree []Tree) {
func (t *TerminalNodeImpl) SetChildren(_ []Tree) {
panic("Cannot set children on terminal node")
}
@@ -151,7 +144,7 @@ func (t *TerminalNodeImpl) GetPayload() interface{} {
return t.symbol
}
func (t *TerminalNodeImpl) GetSourceInterval() *Interval {
func (t *TerminalNodeImpl) GetSourceInterval() Interval {
if t.symbol == nil {
return TreeInvalidInterval
}
@@ -179,7 +172,7 @@ func (t *TerminalNodeImpl) String() string {
return t.symbol.GetText()
}
func (t *TerminalNodeImpl) ToStringTree(s []string, r Recognizer) string {
func (t *TerminalNodeImpl) ToStringTree(_ []string, _ Recognizer) string {
return t.String()
}
@@ -214,10 +207,9 @@ func NewParseTreeWalker() *ParseTreeWalker {
return new(ParseTreeWalker)
}
// Performs a walk on the given parse tree starting at the root and going down recursively
// with depth-first search. On each node, EnterRule is called before
// recursively walking down into child nodes, then
// ExitRule is called after the recursive call to wind up.
// Walk performs a walk on the given parse tree starting at the root and going down recursively
// with depth-first search. On each node, [EnterRule] is called before
// recursively walking down into child nodes, then [ExitRule] is called after the recursive call to wind up.
func (p *ParseTreeWalker) Walk(listener ParseTreeListener, t Tree) {
switch tt := t.(type) {
case ErrorNode:
@@ -234,7 +226,7 @@ func (p *ParseTreeWalker) Walk(listener ParseTreeListener, t Tree) {
}
}
// Enters a grammar rule by first triggering the generic event {@link ParseTreeListener//EnterEveryRule}
// EnterRule enters a grammar rule by first triggering the generic event [ParseTreeListener].[EnterEveryRule]
// then by triggering the event specific to the given parse tree node
func (p *ParseTreeWalker) EnterRule(listener ParseTreeListener, r RuleNode) {
ctx := r.GetRuleContext().(ParserRuleContext)
@@ -242,12 +234,71 @@ func (p *ParseTreeWalker) EnterRule(listener ParseTreeListener, r RuleNode) {
ctx.EnterRule(listener)
}
// Exits a grammar rule by first triggering the event specific to the given parse tree node
// then by triggering the generic event {@link ParseTreeListener//ExitEveryRule}
// ExitRule exits a grammar rule by first triggering the event specific to the given parse tree node
// then by triggering the generic event [ParseTreeListener].ExitEveryRule
func (p *ParseTreeWalker) ExitRule(listener ParseTreeListener, r RuleNode) {
ctx := r.GetRuleContext().(ParserRuleContext)
ctx.ExitRule(listener)
listener.ExitEveryRule(ctx)
}
//goland:noinspection GoUnusedGlobalVariable
var ParseTreeWalkerDefault = NewParseTreeWalker()
type IterativeParseTreeWalker struct {
*ParseTreeWalker
}
//goland:noinspection GoUnusedExportedFunction
func NewIterativeParseTreeWalker() *IterativeParseTreeWalker {
return new(IterativeParseTreeWalker)
}
func (i *IterativeParseTreeWalker) Walk(listener ParseTreeListener, t Tree) {
var stack []Tree
var indexStack []int
currentNode := t
currentIndex := 0
for currentNode != nil {
// pre-order visit
switch tt := currentNode.(type) {
case ErrorNode:
listener.VisitErrorNode(tt)
case TerminalNode:
listener.VisitTerminal(tt)
default:
i.EnterRule(listener, currentNode.(RuleNode))
}
// Move down to first child, if exists
if currentNode.GetChildCount() > 0 {
stack = append(stack, currentNode)
indexStack = append(indexStack, currentIndex)
currentIndex = 0
currentNode = currentNode.GetChild(0)
continue
}
for {
// post-order visit
if ruleNode, ok := currentNode.(RuleNode); ok {
i.ExitRule(listener, ruleNode)
}
// No parent, so no siblings
if len(stack) == 0 {
currentNode = nil
currentIndex = 0
break
}
// Move to next sibling if possible
currentIndex++
if stack[len(stack)-1].GetChildCount() > currentIndex {
currentNode = stack[len(stack)-1].GetChild(currentIndex)
break
}
// No next, sibling, so move up
currentNode, stack = stack[len(stack)-1], stack[:len(stack)-1]
currentIndex, indexStack = indexStack[len(indexStack)-1], indexStack[:len(indexStack)-1]
}
}
}

View File

@@ -8,10 +8,8 @@ import "fmt"
/** A set of utility routines useful for all kinds of ANTLR trees. */
// Print out a whole tree in LISP form. {@link //getNodeText} is used on the
//
// node payloads to get the text for the nodes. Detect
// parse trees and extract data appropriately.
// TreesStringTree prints out a whole tree in LISP form. [getNodeText] is used on the
// node payloads to get the text for the nodes. Detects parse trees and extracts data appropriately.
func TreesStringTree(tree Tree, ruleNames []string, recog Recognizer) string {
if recog != nil {
@@ -32,7 +30,7 @@ func TreesStringTree(tree Tree, ruleNames []string, recog Recognizer) string {
}
for i := 1; i < c; i++ {
s = TreesStringTree(tree.GetChild(i), ruleNames, nil)
res += (" " + s)
res += " " + s
}
res += ")"
return res
@@ -62,7 +60,7 @@ func TreesGetNodeText(t Tree, ruleNames []string, recog Parser) string {
}
}
// no recog for rule names
// no recognition for rule names
payload := t.GetPayload()
if p2, ok := payload.(Token); ok {
return p2.GetText()
@@ -71,7 +69,9 @@ func TreesGetNodeText(t Tree, ruleNames []string, recog Parser) string {
return fmt.Sprint(t.GetPayload())
}
// Return ordered list of all children of this node
// TreesGetChildren returns am ordered list of all children of this node
//
//goland:noinspection GoUnusedExportedFunction
func TreesGetChildren(t Tree) []Tree {
list := make([]Tree, 0)
for i := 0; i < t.GetChildCount(); i++ {
@@ -80,9 +80,10 @@ func TreesGetChildren(t Tree) []Tree {
return list
}
// Return a list of all ancestors of this node. The first node of
// TreesgetAncestors returns a list of all ancestors of this node. The first node of list is the root
// and the last node is the parent of this node.
//
// list is the root and the last is the parent of this node.
//goland:noinspection GoUnusedExportedFunction
func TreesgetAncestors(t Tree) []Tree {
ancestors := make([]Tree, 0)
t = t.GetParent()
@@ -94,10 +95,12 @@ func TreesgetAncestors(t Tree) []Tree {
return ancestors
}
//goland:noinspection GoUnusedExportedFunction
func TreesFindAllTokenNodes(t ParseTree, ttype int) []ParseTree {
return TreesfindAllNodes(t, ttype, true)
}
//goland:noinspection GoUnusedExportedFunction
func TreesfindAllRuleNodes(t ParseTree, ruleIndex int) []ParseTree {
return TreesfindAllNodes(t, ruleIndex, false)
}
@@ -129,6 +132,7 @@ func treesFindAllNodes(t ParseTree, index int, findTokens bool, nodes *[]ParseTr
}
}
//goland:noinspection GoUnusedExportedFunction
func TreesDescendants(t ParseTree) []ParseTree {
nodes := []ParseTree{t}
for i := 0; i < t.GetChildCount(); i++ {

View File

@@ -9,8 +9,10 @@ import (
"errors"
"fmt"
"math/bits"
"os"
"strconv"
"strings"
"syscall"
)
func intMin(a, b int) int {
@@ -31,7 +33,7 @@ func intMax(a, b int) int {
type IntStack []int
var ErrEmptyStack = errors.New("Stack is empty")
var ErrEmptyStack = errors.New("stack is empty")
func (s *IntStack) Pop() (int, error) {
l := len(*s) - 1
@@ -47,33 +49,13 @@ func (s *IntStack) Push(e int) {
*s = append(*s, e)
}
type comparable interface {
Equals(other Collectable[any]) bool
}
func standardEqualsFunction(a Collectable[any], b Collectable[any]) bool {
return a.Equals(b)
}
func standardHashFunction(a interface{}) int {
if h, ok := a.(hasher); ok {
return h.Hash()
}
panic("Not Hasher")
}
type hasher interface {
Hash() int
}
const bitsPerWord = 64
func indexForBit(bit int) int {
return bit / bitsPerWord
}
//goland:noinspection GoUnusedExportedFunction,GoUnusedFunction
func wordForBit(data []uint64, bit int) uint64 {
idx := indexForBit(bit)
if idx >= len(data) {
@@ -94,6 +76,8 @@ type BitSet struct {
data []uint64
}
// NewBitSet creates a new bitwise set
// TODO: See if we can replace with the standard library's BitSet
func NewBitSet() *BitSet {
return &BitSet{}
}
@@ -123,7 +107,7 @@ func (b *BitSet) or(set *BitSet) {
setLen := set.minLen()
maxLen := intMax(bLen, setLen)
if maxLen > len(b.data) {
// Increase the size of len(b.data) to repesent the bits in both sets.
// Increase the size of len(b.data) to represent the bits in both sets.
data := make([]uint64, maxLen)
copy(data, b.data)
b.data = data
@@ -246,37 +230,6 @@ func (a *AltDict) values() []interface{} {
return vs
}
type DoubleDict struct {
data map[int]map[int]interface{}
}
func NewDoubleDict() *DoubleDict {
dd := new(DoubleDict)
dd.data = make(map[int]map[int]interface{})
return dd
}
func (d *DoubleDict) Get(a, b int) interface{} {
data := d.data[a]
if data == nil {
return nil
}
return data[b]
}
func (d *DoubleDict) set(a, b int, o interface{}) {
data := d.data[a]
if data == nil {
data = make(map[int]interface{})
d.data[a] = data
}
data[b] = o
}
func EscapeWhitespace(s string, escapeSpaces bool) string {
s = strings.Replace(s, "\t", "\\t", -1)
@@ -288,6 +241,7 @@ func EscapeWhitespace(s string, escapeSpaces bool) string {
return s
}
//goland:noinspection GoUnusedExportedFunction
func TerminalNodeToStringArray(sa []TerminalNode) []string {
st := make([]string, len(sa))
@@ -298,6 +252,7 @@ func TerminalNodeToStringArray(sa []TerminalNode) []string {
return st
}
//goland:noinspection GoUnusedExportedFunction
func PrintArrayJavaStyle(sa []string) string {
var buffer bytes.Buffer
@@ -350,3 +305,24 @@ func murmurFinish(h int, numberOfWords int) int {
return int(hash)
}
func isDirectory(dir string) (bool, error) {
fileInfo, err := os.Stat(dir)
if err != nil {
switch {
case errors.Is(err, syscall.ENOENT):
// The given directory does not exist, so we will try to create it
//
err = os.MkdirAll(dir, 0755)
if err != nil {
return false, err
}
return true, nil
case err != nil:
return false, err
default:
}
}
return fileInfo.IsDir(), err
}

View File

@@ -10,9 +10,12 @@ go_library(
"cel.go",
"decls.go",
"env.go",
"folding.go",
"io.go",
"inlining.go",
"library.go",
"macro.go",
"optimizer.go",
"options.go",
"program.go",
"validator.go",
@@ -56,7 +59,11 @@ go_test(
"cel_test.go",
"decls_test.go",
"env_test.go",
"folding_test.go",
"io_test.go",
"inlining_test.go",
"optimizer_test.go",
"validator_test.go",
],
data = [
"//cel/testdata:gen_test_fds",
@@ -70,6 +77,7 @@ go_test(
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//ext:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",

View File

@@ -353,43 +353,3 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
return nil, fmt.Errorf("unsupported decl: %v", d)
}
}
func typeValueToKind(tv ref.Type) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}

View File

@@ -38,26 +38,42 @@ type Source = common.Source
// Ast representing the checked or unchecked expression, its source, and related metadata such as
// source position information.
type Ast struct {
expr *exprpb.Expr
info *exprpb.SourceInfo
source Source
refMap map[int64]*celast.ReferenceInfo
typeMap map[int64]*types.Type
source Source
impl *celast.AST
}
// NativeRep converts the AST to a Go-native representation.
func (ast *Ast) NativeRep() *celast.AST {
return ast.impl
}
// Expr returns the proto serializable instance of the parsed/checked expression.
//
// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr()
// the result instead.
func (ast *Ast) Expr() *exprpb.Expr {
return ast.expr
if ast == nil {
return nil
}
pbExpr, _ := celast.ExprToProto(ast.impl.Expr())
return pbExpr
}
// IsChecked returns whether the Ast value has been successfully type-checked.
func (ast *Ast) IsChecked() bool {
return ast.typeMap != nil && len(ast.typeMap) > 0
if ast == nil {
return false
}
return ast.impl.IsChecked()
}
// SourceInfo returns character offset and newline position information about expression elements.
func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
return ast.info
if ast == nil {
return nil
}
pbInfo, _ := celast.SourceInfoToProto(ast.impl.SourceInfo())
return pbInfo
}
// ResultType returns the output type of the expression if the Ast has been type-checked, else
@@ -65,9 +81,6 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
//
// Deprecated: use OutputType
func (ast *Ast) ResultType() *exprpb.Type {
if !ast.IsChecked() {
return chkdecls.Dyn
}
out := ast.OutputType()
t, err := TypeToExprType(out)
if err != nil {
@@ -79,16 +92,18 @@ func (ast *Ast) ResultType() *exprpb.Type {
// OutputType returns the output type of the expression if the Ast has been type-checked, else
// returns cel.DynType as the parse step cannot infer types.
func (ast *Ast) OutputType() *Type {
t, found := ast.typeMap[ast.expr.GetId()]
if !found {
return DynType
if ast == nil {
return types.ErrorType
}
return t
return ast.impl.GetType(ast.impl.Expr().ID())
}
// Source returns a view of the input used to create the Ast. This source may be complete or
// constructed from the SourceInfo.
func (ast *Ast) Source() Source {
if ast == nil {
return nil
}
return ast.source
}
@@ -198,29 +213,28 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
// It is possible to have both non-nil Ast and Issues values returned from this call: however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
// Note, errors aren't currently possible on the Ast to ParsedExpr conversion.
pe, _ := AstToParsedExpr(ast)
// Construct the internal checker env, erroring if there is an issue adding the declarations.
chk, err := e.initChecker()
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
}
res, errs := checker.Check(pe, ast.Source(), chk)
checked, errs := checker.Check(ast.impl, ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
ast = &Ast{
source: ast.Source(),
expr: res.Expr,
info: res.SourceInfo,
refMap: res.ReferenceMap,
typeMap: res.TypeMap}
source: ast.Source(),
impl: checked}
// Avoid creating a validator config if it's not needed.
if len(e.validators) == 0 {
return ast, nil
}
// Generate a validator configuration from the set of configured validators.
vConfig := newValidatorConfig()
@@ -230,9 +244,9 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
}
}
// Apply additional validators on the type-checked result.
iss := NewIssuesWithSourceInfo(errs, ast.SourceInfo())
iss := NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
for _, v := range e.validators {
v.Validate(e, vConfig, res, iss)
v.Validate(e, vConfig, checked, iss)
}
if iss.Err() != nil {
return nil, iss
@@ -429,16 +443,11 @@ func (e *Env) Parse(txt string) (*Ast, *Issues) {
// It is possible to have both non-nil Ast and Issues values returned from this call; however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) ParseSource(src Source) (*Ast, *Issues) {
res, errs := e.prsr.Parse(src)
parsed, errs := e.prsr.Parse(src)
if len(errs.GetErrors()) > 0 {
return nil, &Issues{errs: errs}
}
// Manually create the Ast to ensure that the text source information is propagated on
// subsequent calls to Check.
return &Ast{
source: src,
expr: res.GetExpr(),
info: res.GetSourceInfo()}, nil
return &Ast{source: src, impl: parsed}, nil
}
// Program generates an evaluable instance of the Ast within the environment (Env).
@@ -534,8 +543,9 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State())
expr, err := AstToString(ParsedExprToAst(pruned))
pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State())
newAST := &Ast{source: a.Source(), impl: pruned}
expr, err := AstToString(newAST)
if err != nil {
return nil, err
}
@@ -556,16 +566,10 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(checked, estimator, extendedOpts...)
return checker.Cost(ast.impl, estimator, extendedOpts...)
}
// configure applies a series of EnvOptions to the current environment.
@@ -707,7 +711,7 @@ type Error = common.Error
// Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct.
type Issues struct {
errs *common.Errors
info *exprpb.SourceInfo
info *celast.SourceInfo
}
// NewIssues returns an Issues struct from a common.Errors object.
@@ -718,7 +722,7 @@ func NewIssues(errs *common.Errors) *Issues {
// NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata
// which can be used with the `ReportErrorAtID` method for additional error reports within the context
// information that's inferred from an expression id.
func NewIssuesWithSourceInfo(errs *common.Errors, info *exprpb.SourceInfo) *Issues {
func NewIssuesWithSourceInfo(errs *common.Errors, info *celast.SourceInfo) *Issues {
return &Issues{
errs: errs,
info: info,
@@ -768,30 +772,7 @@ func (i *Issues) String() string {
// The source metadata for the expression at `id`, if present, is attached to the error report.
// To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo.
func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) {
i.errs.ReportErrorAtID(id, locationByID(id, i.info), message, args...)
}
// locationByID returns a common.Location given an expression id.
//
// TODO: move this functionality into the native SourceInfo and an overhaul of the common.Source
// as this implementation relies on the abstractions present in the protobuf SourceInfo object,
// and is replicated in the checker.
func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location {
positions := sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
i.errs.ReportErrorAtID(id, i.info.GetStartLocation(id), message, args...)
}
// getStdEnv lazy initializes the CEL standard environment.
@@ -822,6 +803,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b
return nil, false
}
// FindStructFieldNames returns an empty set of field for the interop provider.
//
// To inspect the field names, migrate to a `types.Provider` implementation.
func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return []string{}, false
}
// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field
// name, if one exists.
//

559
vendor/github.com/google/cel-go/cel/folding.go generated vendored Normal file
View File

@@ -0,0 +1,559 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"fmt"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// ConstantFoldingOption defines a functional option for configuring constant folding.
type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error)
// MaxConstantFoldIterations limits the number of times literals may be folding during optimization.
//
// Defaults to 100 if not set.
func MaxConstantFoldIterations(limit int) ConstantFoldingOption {
return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) {
opt.maxFoldIterations = limit
return opt, nil
}
}
// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate
// literal values within function calls and select statements with their evaluated result.
func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) {
folder := &constantFoldingOptimizer{
maxFoldIterations: defaultMaxConstantFoldIterations,
}
var err error
for _, o := range opts {
folder, err = o(folder)
if err != nil {
return nil, err
}
}
return folder, nil
}
type constantFoldingOptimizer struct {
maxFoldIterations int
}
// Optimize queries the expression graph for scalar and aggregate literal expressions within call and
// select statements and then evaluates them and replaces the call site with the literal result.
//
// Note: only values which can be represented as literals in CEL syntax are supported.
func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
// Walk the list of foldable expression and continue to fold until there are no more folds left.
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
// a logic bug with the selection of expressions.
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
foldCount := 0
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
for _, fold := range foldableExprs {
// If the expression could be folded because it's a non-strict call, and the
// branches are pruned, continue to the next fold.
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
continue
}
// Otherwise, assume all context is needed to evaluate the expression.
err := tryFold(ctx, a, fold)
if err != nil {
ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error())
return a
}
}
foldCount++
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
}
// Once all of the constants have been folded, try to run through the remaining comprehensions
// one last time. In this case, there's no guarantee they'll run, so we only update the
// target comprehension node with the literal value if the evaluation succeeds.
for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) {
tryFold(ctx, a, compre)
}
// If the output is a list, map, or struct which contains optional entries, then prune it
// to make sure that the optionals, if resolved, do not surface in the output literal.
pruneOptionalElements(ctx, root)
// Ensure that all intermediate values in the folded expression can be represented as valid
// CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents`
// to avoid extra allocations during this final pass through the AST.
ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() != ast.LiteralKind {
return
}
val := e.AsLiteral()
adapted, err := adaptLiteral(ctx, val)
if err != nil {
ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error())
return
}
ctx.UpdateExpr(e, adapted)
}))
return a
}
// tryFold attempts to evaluate a sub-expression to a literal.
//
// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise
// the method will return an error.
func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
// Assume all context is needed to evaluate the expression.
subAST := &Ast{
impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()),
}
prg, err := ctx.Program(subAST)
if err != nil {
return err
}
out, _, err := prg.Eval(NoVars())
if err != nil {
return err
}
// Update the fold expression to be a literal.
ctx.UpdateExpr(expr, ctx.NewLiteral(out))
return nil
}
// maybePruneBranches inspects the non-strict call expression to determine whether
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
// but conditional will not be pruned cleanly, so this is one small area where the
// constant folding step reimplements a portion of the evaluator.
func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
call := expr.AsCall()
args := call.Args()
switch call.FunctionName() {
case operators.LogicalAnd, operators.LogicalOr:
return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr)
case operators.Conditional:
cond := args[0]
truthy := args[1]
falsy := args[2]
if cond.Kind() != ast.LiteralKind {
return false
}
if cond.AsLiteral() == types.True {
ctx.UpdateExpr(expr, truthy)
} else {
ctx.UpdateExpr(expr, falsy)
}
return true
case operators.In:
haystack := args[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
ctx.UpdateExpr(expr, ctx.NewLiteral(types.False))
return true
}
needle := args[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
ctx.UpdateExpr(expr, ctx.NewLiteral(types.True))
return true
}
}
}
}
return false
}
func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool {
shortcircuit := types.False
skip := types.True
if function == operators.LogicalOr {
shortcircuit = types.True
skip = types.False
}
newArgs := []ast.Expr{}
for _, arg := range args {
if arg.Kind() != ast.LiteralKind {
newArgs = append(newArgs, arg)
continue
}
if arg.AsLiteral() == skip {
continue
}
if arg.AsLiteral() == shortcircuit {
ctx.UpdateExpr(expr, arg)
return true
}
}
if len(newArgs) == 0 {
newArgs = append(newArgs, args[0])
ctx.UpdateExpr(expr, newArgs[0])
return true
}
if len(newArgs) == 1 {
ctx.UpdateExpr(expr, newArgs[0])
return true
}
ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...))
return true
}
// pruneOptionalElements works from the bottom up to resolve optional elements within
// aggregate literals.
//
// Note, many aggregate literals will be resolved as arguments to functions or select
// statements, so this method exists to handle the case where the literal could not be
// fully resolved or exists outside of a call, select, or comprehension context.
func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) {
aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher)
for _, lit := range aggregateLiterals {
switch lit.Kind() {
case ast.ListKind:
pruneOptionalListElements(ctx, lit)
case ast.MapKind:
pruneOptionalMapEntries(ctx, lit)
case ast.StructKind:
pruneOptionalStructFields(ctx, lit)
}
}
}
func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) {
l := e.AsList()
elems := l.Elements()
optIndices := l.OptionalIndices()
if len(optIndices) == 0 {
return
}
updatedElems := []ast.Expr{}
updatedIndices := []int32{}
newOptIndex := -1
for _, e := range elems {
newOptIndex++
if !l.IsOptional(int32(newOptIndex)) {
updatedElems = append(updatedElems, e)
continue
}
if e.Kind() != ast.LiteralKind {
updatedElems = append(updatedElems, e)
updatedIndices = append(updatedIndices, int32(newOptIndex))
continue
}
optElemVal, ok := e.AsLiteral().(*types.Optional)
if !ok {
updatedElems = append(updatedElems, e)
updatedIndices = append(updatedIndices, int32(newOptIndex))
continue
}
if !optElemVal.HasValue() {
newOptIndex-- // Skipping causes the list to get smaller.
continue
}
ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue()))
updatedElems = append(updatedElems, e)
}
ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices))
}
func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
m := e.AsMap()
entries := m.Entries()
updatedEntries := []ast.EntryExpr{}
modified := false
for _, e := range entries {
entry := e.AsMapEntry()
key := entry.Key()
val := entry.Value()
// If the entry is not optional, or the value-side of the optional hasn't
// been resolved to a literal, then preserve the entry as-is.
if !entry.IsOptional() || val.Kind() != ast.LiteralKind {
updatedEntries = append(updatedEntries, e)
continue
}
optElemVal, ok := val.AsLiteral().(*types.Optional)
if !ok {
updatedEntries = append(updatedEntries, e)
continue
}
// When the key is not a literal, but the value is, then it needs to be
// restored to an optional value.
if key.Kind() != ast.LiteralKind {
undoOptVal, err := adaptLiteral(ctx, optElemVal)
if err != nil {
ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err)
}
ctx.UpdateExpr(val, undoOptVal)
updatedEntries = append(updatedEntries, e)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedEntry := ctx.NewMapEntry(key, val, false)
updatedEntries = append(updatedEntries, updatedEntry)
}
if modified {
ctx.UpdateExpr(e, ctx.NewMap(updatedEntries))
}
}
func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) {
s := e.AsStruct()
fields := s.Fields()
updatedFields := []ast.EntryExpr{}
modified := false
for _, f := range fields {
field := f.AsStructField()
val := field.Value()
if !field.IsOptional() || val.Kind() != ast.LiteralKind {
updatedFields = append(updatedFields, f)
continue
}
optElemVal, ok := val.AsLiteral().(*types.Optional)
if !ok {
updatedFields = append(updatedFields, f)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedField := ctx.NewStructField(field.Name(), val, false)
updatedFields = append(updatedFields, updatedField)
}
if modified {
ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields))
}
}
// adaptLiteral converts a runtime CEL value to its equivalent literal expression.
//
// For strongly typed values, the type-provider will be used to reconstruct the fields
// which are present in the literal and their equivalent initialization values.
func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
switch t := val.Type().(type) {
case *types.Type:
switch t {
case types.BoolType, types.BytesType, types.DoubleType, types.IntType,
types.NullType, types.StringType, types.UintType:
return ctx.NewLiteral(val), nil
case types.DurationType:
return ctx.NewCall(
overloads.TypeConvertDuration,
ctx.NewLiteral(val.ConvertToType(types.StringType)),
), nil
case types.TimestampType:
return ctx.NewCall(
overloads.TypeConvertTimestamp,
ctx.NewLiteral(val.ConvertToType(types.StringType)),
), nil
case types.OptionalType:
opt := val.(*types.Optional)
if !opt.HasValue() {
return ctx.NewCall("optional.none"), nil
}
target, err := adaptLiteral(ctx, opt.GetValue())
if err != nil {
return nil, err
}
return ctx.NewCall("optional.of", target), nil
case types.TypeType:
return ctx.NewIdent(val.(*types.Type).TypeName()), nil
case types.ListType:
l, ok := val.(traits.Lister)
if !ok {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
elems := make([]ast.Expr, l.Size().(types.Int))
idx := 0
it := l.Iterator()
for it.HasNext() == types.True {
elemVal := it.Next()
elemExpr, err := adaptLiteral(ctx, elemVal)
if err != nil {
return nil, err
}
elems[idx] = elemExpr
idx++
}
return ctx.NewList(elems, []int32{}), nil
case types.MapType:
m, ok := val.(traits.Mapper)
if !ok {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
entries := make([]ast.EntryExpr, m.Size().(types.Int))
idx := 0
it := m.Iterator()
for it.HasNext() == types.True {
keyVal := it.Next()
keyExpr, err := adaptLiteral(ctx, keyVal)
if err != nil {
return nil, err
}
valVal := m.Get(keyVal)
valExpr, err := adaptLiteral(ctx, valVal)
if err != nil {
return nil, err
}
entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false)
idx++
}
return ctx.NewMap(entries), nil
default:
provider := ctx.CELTypeProvider()
fields, found := provider.FindStructFieldNames(t.TypeName())
if !found {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
tester := val.(traits.FieldTester)
indexer := val.(traits.Indexer)
fieldInits := []ast.EntryExpr{}
for _, f := range fields {
field := types.String(f)
if tester.IsSet(field) != types.True {
continue
}
fieldVal := indexer.Get(field)
fieldExpr, err := adaptLiteral(ctx, fieldVal)
if err != nil {
return nil, err
}
fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false))
}
return ctx.NewStruct(t.TypeName(), fieldInits), nil
}
}
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
// constantExprMatcher matches calls, select statements, and comprehensions whose arguments
// are all constant scalar or aggregate literal values.
//
// Only comprehensions which are not nested are included as possible constant folds, and only
// if all variables referenced in the comprehension stack exist are only iteration or
// accumulation variables.
func constantExprMatcher(e ast.NavigableExpr) bool {
switch e.Kind() {
case ast.CallKind:
return constantCallMatcher(e)
case ast.SelectKind:
sel := e.AsSelect() // guaranteed to be a navigable value
return constantMatcher(sel.Operand().(ast.NavigableExpr))
case ast.ComprehensionKind:
if isNestedComprehension(e) {
return false
}
vars := map[string]bool{}
constantExprs := true
visitor := ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() == ast.ComprehensionKind {
nested := e.AsComprehension()
vars[nested.AccuVar()] = true
vars[nested.IterVar()] = true
}
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
constantExprs = false
}
})
ast.PreOrderVisit(e, visitor)
return constantExprs
default:
return false
}
}
// constantCallMatcher identifies strict and non-strict calls which can be folded.
func constantCallMatcher(e ast.NavigableExpr) bool {
call := e.AsCall()
children := e.Children()
fnName := call.FunctionName()
if fnName == operators.LogicalAnd {
for _, child := range children {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.LogicalOr {
for _, child := range children {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.Conditional {
cond := children[0]
if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType {
return true
}
}
if fnName == operators.In {
haystack := children[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
return true
}
needle := children[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
return true
}
}
}
}
// convert all other calls with constant arguments
for _, child := range children {
if !constantMatcher(child) {
return false
}
}
return true
}
func isNestedComprehension(e ast.NavigableExpr) bool {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.ComprehensionKind {
return true
}
parent, found = parent.Parent()
}
return false
}
func aggregateLiteralMatcher(e ast.NavigableExpr) bool {
return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind
}
var (
constantMatcher = ast.ConstantValueMatcher()
)
const (
defaultMaxConstantFoldIterations = 100
)

228
vendor/github.com/google/cel-go/cel/inlining.go generated vendored Normal file
View File

@@ -0,0 +1,228 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/traits"
)
// InlineVariable holds a variable name to be matched and an AST representing
// the expression graph which should be used to replace it.
type InlineVariable struct {
name string
alias string
def *ast.AST
}
// Name returns the qualified variable or field selection to replace.
func (v *InlineVariable) Name() string {
return v.name
}
// Alias returns the alias to use when performing cel.bind() calls during inlining.
func (v *InlineVariable) Alias() string {
return v.alias
}
// Expr returns the inlined expression value.
func (v *InlineVariable) Expr() ast.Expr {
return v.def.Expr()
}
// Type indicates the inlined expression type.
func (v *InlineVariable) Type() *Type {
return v.def.GetType(v.def.Expr().ID())
}
// NewInlineVariable declares a variable name to be replaced by a checked expression.
func NewInlineVariable(name string, definition *Ast) *InlineVariable {
return NewInlineVariableWithAlias(name, name, definition)
}
// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression.
// If the variable occurs more than once, the provided alias will be used to replace the expressions
// where the variable name occurs.
func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable {
return &InlineVariable{name: name, alias: alias, def: definition.impl}
}
// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions.
//
// If a variable occurs one time, the variable is replaced by the inline definition. If the
// variable occurs more than once, the variable occurences are replaced by a cel.bind() call.
func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer {
return &inliningOptimizer{variables: inlineVars}
}
type inliningOptimizer struct {
variables []*InlineVariable
}
func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
for _, inlineVar := range opt.variables {
matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name()))
// Skip cases where the variable isn't in the expression graph
if len(matches) == 0 {
continue
}
// For a single match, do a direct replacement of the expression sub-graph.
if len(matches) == 1 || !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
for _, match := range matches {
// Copy the inlined AST expr and source info.
copyExpr := ctx.CopyASTAndMetadata(inlineVar.def)
opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type())
}
continue
}
// For multiple matches, find the least common ancestor (lca) and insert the
// variable as a cel.bind() macro.
var lca ast.NavigableExpr = root
lcaAncestorCount := 0
ancestors := map[int64]int{}
for _, match := range matches {
// Update the identifier matches with the provided alias.
parent, found := match, true
for found {
ancestorCount, hasAncestor := ancestors[parent.ID()]
if !hasAncestor {
ancestors[parent.ID()] = 1
parent, found = parent.Parent()
continue
}
if lcaAncestorCount < ancestorCount || (lcaAncestorCount == ancestorCount && lca.Depth() < parent.Depth()) {
lca = parent
lcaAncestorCount = ancestorCount
}
ancestors[parent.ID()] = ancestorCount + 1
parent, found = parent.Parent()
}
aliasExpr := ctx.NewIdent(inlineVar.Alias())
opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type())
}
// Copy the inlined AST expr and source info.
copyExpr := ctx.CopyASTAndMetadata(inlineVar.def)
// Update the least common ancestor by inserting a cel.bind() call to the alias.
inlined, bindMacro := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), copyExpr, lca)
opt.inlineExpr(ctx, lca, inlined, inlineVar.Type())
ctx.SetMacroCall(lca.ID(), bindMacro)
}
return a
}
// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining
// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is
// made to determine whether the inlined value can be presence or existence tested.
func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) {
switch prev.Kind() {
case ast.SelectKind:
sel := prev.AsSelect()
if !sel.IsTestOnly() {
ctx.UpdateExpr(prev, inlined)
return
}
opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType)
default:
ctx.UpdateExpr(prev, inlined)
}
}
// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to type-safe
// expression appropriate for the inlined type, if possible.
//
// If the rewrite is not possible an error is reported at the inline expression site.
func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
// If the input inlined expression is not a select expression it won't work with the has()
// macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error.
if inlined.Kind() == ast.SelectKind {
presenceTest, hasMacro := ctx.NewHasMacro(prev.ID(), inlined)
ctx.UpdateExpr(prev, presenceTest)
ctx.SetMacroCall(prev.ID(), hasMacro)
return
}
ctx.ClearMacroCall(prev.ID())
if inlinedType.IsAssignableType(NullType) {
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
inlined,
ctx.NewLiteral(types.NullValue),
))
return
}
if inlinedType.HasTrait(traits.SizerType) {
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
ctx.NewMemberCall(overloads.Size, inlined),
ctx.NewLiteral(types.IntZero),
))
return
}
ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType)
}
// isBindable indicates whether the inlined type can be used within a cel.bind() if the expression
// being replaced occurs within a presence test. Value types with a size() method or field selection
// support can be bound.
//
// In future iterations, support may also be added for indexer types which can be rewritten as an `in`
// expression; however, this would imply a rewrite of the inlined expression that may not be necessary
// in most cases.
func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) bool {
if inlinedType.IsAssignableType(NullType) ||
inlinedType.HasTrait(traits.SizerType) {
return true
}
for _, m := range matches {
if m.Kind() != ast.SelectKind {
continue
}
sel := m.AsSelect()
if sel.IsTestOnly() {
return false
}
}
return true
}
// matchVariable matches simple identifiers, select expressions, and presence test expressions
// which match the (potentially) qualified variable name provided as input.
//
// Note, this function does not support inlining against select expressions which includes optional
// field selection. This may be a future refinement.
func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() == ast.IdentKind && e.AsIdent() == varName {
return true
}
if e.Kind() == ast.SelectKind {
sel := e.AsSelect()
// While the `ToQualifiedName` call could take the select directly, this
// would skip presence tests from possible matches, which we would like
// to include.
qualName, found := containers.ToQualifiedName(sel.Operand())
return found && qualName+"."+sel.FieldName() == varName
}
return false
}
}

View File

@@ -47,17 +47,11 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast {
//
// Prefer CheckedExprToAst if loading expressions from storage.
func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) {
checkedAST, err := ast.CheckedExprToCheckedAST(checkedExpr)
checked, err := ast.ToAST(checkedExpr)
if err != nil {
return nil, err
}
return &Ast{
expr: checkedAST.Expr,
info: checkedAST.SourceInfo,
source: src,
refMap: checkedAST.ReferenceMap,
typeMap: checkedAST.TypeMap,
}, nil
return &Ast{source: src, impl: checked}, nil
}
// AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value.
@@ -67,13 +61,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) {
if !a.IsChecked() {
return nil, fmt.Errorf("cannot convert unchecked ast")
}
cAst := &ast.CheckedAST{
Expr: a.expr,
SourceInfo: a.info,
ReferenceMap: a.refMap,
TypeMap: a.typeMap,
}
return ast.CheckedASTToCheckedExpr(cAst)
return ast.ToProto(a.impl)
}
// ParsedExprToAst converts a parsed expression proto message to an Ast.
@@ -89,18 +77,12 @@ func ParsedExprToAst(parsedExpr *exprpb.ParsedExpr) *Ast {
//
// Prefer ParsedExprToAst if loading expressions from storage.
func ParsedExprToAstWithSource(parsedExpr *exprpb.ParsedExpr, src Source) *Ast {
si := parsedExpr.GetSourceInfo()
if si == nil {
si = &exprpb.SourceInfo{}
}
info, _ := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo())
if src == nil {
src = common.NewInfoSource(si)
}
return &Ast{
expr: parsedExpr.GetExpr(),
info: si,
source: src,
src = common.NewInfoSource(parsedExpr.GetSourceInfo())
}
e, _ := ast.ProtoToExpr(parsedExpr.GetExpr())
return &Ast{source: src, impl: ast.NewAST(e, info)}
}
// AstToParsedExpr converts an Ast to an protobuf ParsedExpr value.
@@ -116,9 +98,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) {
// Note, the conversion may not be an exact replica of the original expression, but will produce
// a string that is semantically equivalent and whose textual representation is stable.
func AstToString(a *Ast) (string, error) {
expr := a.Expr()
info := a.SourceInfo()
return parser.Unparse(expr, info)
return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo())
}
// RefValueToValue converts between ref.Val and api.expr.Value.

View File

@@ -20,6 +20,7 @@ import (
"strings"
"time"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
@@ -28,8 +29,6 @@ import (
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
@@ -313,7 +312,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Types(types.OptionalType),
// Configure the optMap and optFlatMap macros.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
Macros(ReceiverMacro(optMapMacro, 2, optMap)),
// Global and member functions for working with optional values.
Function(optionalOfFunc,
@@ -374,7 +373,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
if lib.version >= 1 {
opts = append(opts, Macros(NewReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
}
return opts
}
@@ -386,57 +385,57 @@ func (lib *optionalLib) ProgramOptions() []ProgramOption {
}
}
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier")
return nil, meh.NewError(varIdent.ID(), "optMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
return meh.NewCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.GlobalCall(optionalOfFunc,
meh.Fold(
unusedIterVar,
meh.NewMemberCall(hasValueFunc, target),
meh.NewCall(optionalOfFunc,
meh.NewComprehension(
meh.NewList(),
unusedIterVar,
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
meh.NewMemberCall(valueFunc, target),
meh.NewLiteral(types.False),
meh.NewIdent(varName),
mapExpr,
),
),
meh.GlobalCall(optionalNoneFunc),
meh.NewCall(optionalNoneFunc),
), nil
}
func optFlatMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
func optFlatMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "optFlatMap() variable name must be a simple identifier")
return nil, meh.NewError(varIdent.ID(), "optFlatMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
return meh.NewCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.Fold(
unusedIterVar,
meh.NewMemberCall(hasValueFunc, target),
meh.NewComprehension(
meh.NewList(),
unusedIterVar,
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
meh.NewMemberCall(valueFunc, target),
meh.NewLiteral(types.False),
meh.NewIdent(varName),
mapExpr,
),
meh.GlobalCall(optionalNoneFunc),
meh.NewCall(optionalNoneFunc),
), nil
}

View File

@@ -15,6 +15,11 @@
package cel
import (
"fmt"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@@ -26,7 +31,14 @@ import (
// a Macro should be created per arg-count or as a var arg macro.
type Macro = parser.Macro
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree.
// MacroFactory defines an expansion function which converts a call and its arguments to a cel.Expr value.
type MacroFactory = parser.MacroExpander
// MacroExprFactory assists with the creation of Expr values in a manner which is consistent
// the internal semantics and id generation behaviors of the parser and checker libraries.
type MacroExprFactory = parser.ExprHelper
// MacroExpander converts a call and its associated arguments into a protobuf Expr representation.
//
// If the MacroExpander determines within the implementation that an expansion is not needed it may return
// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments
@@ -36,48 +48,197 @@ type Macro = parser.Macro
// and produces as output an Expr ast node.
//
// Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil.
type MacroExpander = parser.MacroExpander
type MacroExpander func(eh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error)
// MacroExprHelper exposes helper methods for creating new expressions within a CEL abstract syntax tree.
type MacroExprHelper = parser.ExprHelper
// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is
// consistent with the source position and expression id generation code leveraged by both
// the parser and type-checker.
type MacroExprHelper interface {
// Copy the input expression with a brand new set of identifiers.
Copy(*exprpb.Expr) *exprpb.Expr
// LiteralBool creates an Expr value for a bool literal.
LiteralBool(value bool) *exprpb.Expr
// LiteralBytes creates an Expr value for a byte literal.
LiteralBytes(value []byte) *exprpb.Expr
// LiteralDouble creates an Expr value for double literal.
LiteralDouble(value float64) *exprpb.Expr
// LiteralInt creates an Expr value for an int literal.
LiteralInt(value int64) *exprpb.Expr
// LiteralString creates am Expr value for a string literal.
LiteralString(value string) *exprpb.Expr
// LiteralUint creates an Expr value for a uint literal.
LiteralUint(value uint64) *exprpb.Expr
// NewList creates a CreateList instruction where the list is comprised of the optional set
// of elements provided as arguments.
NewList(elems ...*exprpb.Expr) *exprpb.Expr
// NewMap creates a CreateStruct instruction for a map where the map is comprised of the
// optional set of key, value entries.
NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewMapEntry creates a Map Entry for the key, value pair.
NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// NewObject creates a CreateStruct instruction for an object with a given type name and
// optional set of field initializers.
NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewObjectFieldInit creates a new Object field initializer from the field name and value.
NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// Fold creates a fold comprehension instruction.
//
// - iterVar is the iteration variable name.
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
// - condition is the expression to test to determine whether to continue folding.
// - step is the expression to evaluation at the conclusion of a single fold iteration.
// - result is the computation to evaluate at the conclusion of the fold.
//
// The accuVar should not shadow variable names that you would like to reference within the
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
Fold(iterVar string,
iterRange *exprpb.Expr,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr
// Ident creates an identifier Expr value.
Ident(name string) *exprpb.Expr
// AccuIdent returns an accumulator identifier for use with comprehension results.
AccuIdent() *exprpb.Expr
// GlobalCall creates a function call Expr value for a global (free) function.
GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr
// ReceiverCall creates a function call Expr value for a receiver-style function.
ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr
// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics.
PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr
// Select create a field traversal Expr value.
Select(operand *exprpb.Expr, field string) *exprpb.Expr
// OffsetLocation returns the Location of the expression identifier.
OffsetLocation(exprID int64) common.Location
// NewError associates an error message with a given expression id.
NewError(exprID int64, message string) *Error
}
// GlobalMacro creates a Macro for a global function with the specified arg count.
func GlobalMacro(function string, argCount int, factory MacroFactory) Macro {
return parser.NewGlobalMacro(function, argCount, factory)
}
// ReceiverMacro creates a Macro for a receiver function matching the specified arg count.
func ReceiverMacro(function string, argCount int, factory MacroFactory) Macro {
return parser.NewReceiverMacro(function, argCount, factory)
}
// GlobalVarArgMacro creates a Macro for a global function with a variable arg count.
func GlobalVarArgMacro(function string, factory MacroFactory) Macro {
return parser.NewGlobalVarArgMacro(function, factory)
}
// ReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count.
func ReceiverVarArgMacro(function string, factory MacroFactory) Macro {
return parser.NewReceiverVarArgMacro(function, factory)
}
// NewGlobalMacro creates a Macro for a global function with the specified arg count.
//
// Deprecated: use GlobalMacro
func NewGlobalMacro(function string, argCount int, expander MacroExpander) Macro {
return parser.NewGlobalMacro(function, argCount, expander)
expand := adaptingExpander{expander}
return parser.NewGlobalMacro(function, argCount, expand.Expander)
}
// NewReceiverMacro creates a Macro for a receiver function matching the specified arg count.
//
// Deprecated: use ReceiverMacro
func NewReceiverMacro(function string, argCount int, expander MacroExpander) Macro {
return parser.NewReceiverMacro(function, argCount, expander)
expand := adaptingExpander{expander}
return parser.NewReceiverMacro(function, argCount, expand.Expander)
}
// NewGlobalVarArgMacro creates a Macro for a global function with a variable arg count.
//
// Deprecated: use GlobalVarArgMacro
func NewGlobalVarArgMacro(function string, expander MacroExpander) Macro {
return parser.NewGlobalVarArgMacro(function, expander)
expand := adaptingExpander{expander}
return parser.NewGlobalVarArgMacro(function, expand.Expander)
}
// NewReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count.
//
// Deprecated: use ReceiverVarArgMacro
func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro {
return parser.NewReceiverVarArgMacro(function, expander)
expand := adaptingExpander{expander}
return parser.NewReceiverVarArgMacro(function, expand.Expander)
}
// HasMacroExpander expands the input call arguments into a presence test, e.g. has(<operand>.field)
func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeHas(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
arg, err := adaptToExpr(args[0])
if err != nil {
return nil, err
}
if arg.Kind() == ast.SelectKind {
s := arg.AsSelect()
return adaptToProto(ph.NewPresenceTest(s.Operand(), s.FieldName()))
}
return nil, ph.NewError(arg.ID(), "invalid argument to has() macro")
}
// ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the
// elements in the range match the predicate expressions:
// <iterRange>.exists(<iterVar>, <predicate>)
func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExists(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeExists(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly
// one of the elements in the range match the predicate expressions:
// <iterRange>.exists_one(<iterVar>, <predicate>)
func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExistsOne(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeExistsOne(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// MapMacroExpander expands the input call arguments into a comprehension that transforms each element in the
@@ -91,14 +252,30 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeMap(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeMap(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains
// only elements which match the provided predicate expression:
// <iterRange>.filter(<iterVar>, <predicate>)
func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeFilter(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeFilter(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
var (
@@ -142,3 +319,258 @@ var (
// NoMacros provides an alias to an empty list of macros
NoMacros = []Macro{}
)
type adaptingExpander struct {
legacyExpander MacroExpander
}
func (adapt *adaptingExpander) Expander(eh parser.ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
var legacyTarget *exprpb.Expr = nil
var err *Error = nil
if target != nil {
legacyTarget, err = adaptToProto(target)
if err != nil {
return nil, err
}
}
legacyArgs := make([]*exprpb.Expr, len(args))
for i, arg := range args {
legacyArgs[i], err = adaptToProto(arg)
if err != nil {
return nil, err
}
}
ah := &adaptingHelper{modernHelper: eh}
legacyExpr, err := adapt.legacyExpander(ah, legacyTarget, legacyArgs)
if err != nil {
return nil, err
}
ex, err := adaptToExpr(legacyExpr)
if err != nil {
return nil, err
}
return ex, nil
}
func wrapErr(id int64, message string, err error) *common.Error {
return &common.Error{
Location: common.NoLocation,
Message: fmt.Sprintf("%s: %v", message, err),
ExprID: id,
}
}
type adaptingHelper struct {
modernHelper parser.ExprHelper
}
// Copy the input expression with a brand new set of identifiers.
func (ah *adaptingHelper) Copy(e *exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.Copy(mustAdaptToExpr(e)))
}
// LiteralBool creates an Expr value for a bool literal.
func (ah *adaptingHelper) LiteralBool(value bool) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bool(value)))
}
// LiteralBytes creates an Expr value for a byte literal.
func (ah *adaptingHelper) LiteralBytes(value []byte) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bytes(value)))
}
// LiteralDouble creates an Expr value for double literal.
func (ah *adaptingHelper) LiteralDouble(value float64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Double(value)))
}
// LiteralInt creates an Expr value for an int literal.
func (ah *adaptingHelper) LiteralInt(value int64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Int(value)))
}
// LiteralString creates am Expr value for a string literal.
func (ah *adaptingHelper) LiteralString(value string) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.String(value)))
}
// LiteralUint creates an Expr value for a uint literal.
func (ah *adaptingHelper) LiteralUint(value uint64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Uint(value)))
}
// NewList creates a CreateList instruction where the list is comprised of the optional set
// of elements provided as arguments.
func (ah *adaptingHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewList(mustAdaptToExprs(elems)...))
}
// NewMap creates a CreateStruct instruction for a map where the map is comprised of the
// optional set of key, value entries.
func (ah *adaptingHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
adaptedEntries := make([]ast.EntryExpr, len(entries))
for i, e := range entries {
adaptedEntries[i] = mustAdaptToEntryExpr(e)
}
return mustAdaptToProto(ah.modernHelper.NewMap(adaptedEntries...))
}
// NewMapEntry creates a Map Entry for the key, value pair.
func (ah *adaptingHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return mustAdaptToProtoEntry(
ah.modernHelper.NewMapEntry(mustAdaptToExpr(key), mustAdaptToExpr(val), optional))
}
// NewObject creates a CreateStruct instruction for an object with a given type name and
// optional set of field initializers.
func (ah *adaptingHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
adaptedEntries := make([]ast.EntryExpr, len(fieldInits))
for i, e := range fieldInits {
adaptedEntries[i] = mustAdaptToEntryExpr(e)
}
return mustAdaptToProto(ah.modernHelper.NewStruct(typeName, adaptedEntries...))
}
// NewObjectFieldInit creates a new Object field initializer from the field name and value.
func (ah *adaptingHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return mustAdaptToProtoEntry(
ah.modernHelper.NewStructField(field, mustAdaptToExpr(init), optional))
}
// Fold creates a fold comprehension instruction.
//
// - iterVar is the iteration variable name.
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
// - condition is the expression to test to determine whether to continue folding.
// - step is the expression to evaluation at the conclusion of a single fold iteration.
// - result is the computation to evaluate at the conclusion of the fold.
//
// The accuVar should not shadow variable names that you would like to reference within the
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
func (ah *adaptingHelper) Fold(iterVar string,
iterRange *exprpb.Expr,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(
ah.modernHelper.NewComprehension(
mustAdaptToExpr(iterRange),
iterVar,
accuVar,
mustAdaptToExpr(accuInit),
mustAdaptToExpr(condition),
mustAdaptToExpr(step),
mustAdaptToExpr(result),
),
)
}
// Ident creates an identifier Expr value.
func (ah *adaptingHelper) Ident(name string) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewIdent(name))
}
// AccuIdent returns an accumulator identifier for use with comprehension results.
func (ah *adaptingHelper) AccuIdent() *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewAccuIdent())
}
// GlobalCall creates a function call Expr value for a global (free) function.
func (ah *adaptingHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewCall(function, mustAdaptToExprs(args)...))
}
// ReceiverCall creates a function call Expr value for a receiver-style function.
func (ah *adaptingHelper) ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(
ah.modernHelper.NewMemberCall(function, mustAdaptToExpr(target), mustAdaptToExprs(args)...))
}
// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics.
func (ah *adaptingHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr {
op := mustAdaptToExpr(operand)
return mustAdaptToProto(ah.modernHelper.NewPresenceTest(op, field))
}
// Select create a field traversal Expr value.
func (ah *adaptingHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr {
op := mustAdaptToExpr(operand)
return mustAdaptToProto(ah.modernHelper.NewSelect(op, field))
}
// OffsetLocation returns the Location of the expression identifier.
func (ah *adaptingHelper) OffsetLocation(exprID int64) common.Location {
return ah.modernHelper.OffsetLocation(exprID)
}
// NewError associates an error message with a given expression id.
func (ah *adaptingHelper) NewError(exprID int64, message string) *Error {
return ah.modernHelper.NewError(exprID, message)
}
func mustAdaptToExprs(exprs []*exprpb.Expr) []ast.Expr {
adapted := make([]ast.Expr, len(exprs))
for i, e := range exprs {
adapted[i] = mustAdaptToExpr(e)
}
return adapted
}
func mustAdaptToExpr(e *exprpb.Expr) ast.Expr {
out, _ := adaptToExpr(e)
return out
}
func adaptToExpr(e *exprpb.Expr) (ast.Expr, *Error) {
if e == nil {
return nil, nil
}
out, err := ast.ProtoToExpr(e)
if err != nil {
return nil, wrapErr(e.GetId(), "proto conversion failure", err)
}
return out, nil
}
func mustAdaptToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) ast.EntryExpr {
out, _ := ast.ProtoToEntryExpr(e)
return out
}
func mustAdaptToProto(e ast.Expr) *exprpb.Expr {
out, _ := adaptToProto(e)
return out
}
func adaptToProto(e ast.Expr) (*exprpb.Expr, *Error) {
if e == nil {
return nil, nil
}
out, err := ast.ExprToProto(e)
if err != nil {
return nil, wrapErr(e.ID(), "expr conversion failure", err)
}
return out, nil
}
func mustAdaptToProtoEntry(e ast.EntryExpr) *exprpb.Expr_CreateStruct_Entry {
out, _ := ast.EntryExprToProto(e)
return out
}
func toParserHelper(meh MacroExprHelper) (parser.ExprHelper, *Error) {
ah, ok := meh.(*adaptingHelper)
if !ok {
return nil, common.NewError(0,
fmt.Sprintf("unsupported macro helper: %v (%T)", meh, meh),
common.NoLocation)
}
return ah.modernHelper, nil
}

509
vendor/github.com/google/cel-go/cel/optimizer.go generated vendored Normal file
View File

@@ -0,0 +1,509 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order.
//
// The static optimizer normalizes expression ids and type-checking run between optimization
// passes to ensure that the final optimized output is a valid expression with metadata consistent
// with what would have been generated from a parsed and checked expression.
//
// Note: source position information is best-effort and likely wrong, but optimized expressions
// should be suitable for calls to parser.Unparse.
type StaticOptimizer struct {
optimizers []ASTOptimizer
}
// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied
// to a checked expression.
func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
return &StaticOptimizer{
optimizers: optimizers,
}
}
// Optimize applies a sequence of optimizations to an Ast within a given environment.
//
// If issues are encountered, the Issues.Err() return value will be non-nil.
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// Make a copy of the AST to be optimized.
optimized := ast.Copy(a.impl)
ids := newIDGenerator(ast.MaxID(a.impl))
// Create the optimizer context, could be pooled in the future.
issues := NewIssues(common.NewErrors(a.Source()))
baseFac := ast.NewExprFactory()
exprFac := &optimizerExprFactory{
idGenerator: ids,
fac: baseFac,
sourceInfo: optimized.SourceInfo(),
}
ctx := &OptimizerContext{
optimizerExprFactory: exprFac,
Env: env,
Issues: issues,
}
// Apply the optimizations sequentially.
for _, o := range opt.optimizers {
optimized = o.Optimize(ctx, optimized)
if issues.Err() != nil {
return nil, issues
}
// Normalize expression id metadata including coordination with macro call metadata.
freshIDGen := newIDGenerator(0)
info := optimized.SourceInfo()
expr := optimized.Expr()
normalizeIDs(freshIDGen.renumberStable, expr, info)
cleanupMacroRefs(expr, info)
// Recheck the updated expression for any possible type-agreement or validation errors.
parsed := &Ast{
source: a.Source(),
impl: ast.NewAST(expr, info)}
checked, iss := ctx.Check(parsed)
if iss.Err() != nil {
return nil, iss
}
optimized = checked.impl
}
// Return the optimized result.
return &Ast{
source: a.Source(),
impl: optimized,
}, nil
}
// normalizeIDs ensures that the metadata present with an AST is reset in a manner such
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
optimized.RenumberIDs(idGen)
if len(info.MacroCalls()) == 0 {
return
}
// First, update the macro call ids themselves.
callIDMap := map[int64]int64{}
for id := range info.MacroCalls() {
callIDMap[id] = idGen(id)
}
// Then update the macro call definitions which refer to these ids, but
// ensure that the updates don't collide and remove macro entries which haven't
// been visited / updated yet.
type macroUpdate struct {
id int64
call ast.Expr
}
macroUpdates := []macroUpdate{}
for oldID, newID := range callIDMap {
call, found := info.GetMacroCall(oldID)
if !found {
continue
}
call.RenumberIDs(idGen)
macroUpdates = append(macroUpdates, macroUpdate{id: newID, call: call})
info.ClearMacroCall(oldID)
}
for _, u := range macroUpdates {
info.SetMacroCall(u.id, u.call)
}
}
func cleanupMacroRefs(expr ast.Expr, info *ast.SourceInfo) {
if len(info.MacroCalls()) == 0 {
return
}
// Sanitize the macro call references once the optimized expression has been computed
// and the ids normalized between the expression and the macros.
exprRefMap := make(map[int64]struct{})
ast.PostOrderVisit(expr, ast.NewExprVisitor(func(e ast.Expr) {
if e.ID() == 0 {
return
}
exprRefMap[e.ID()] = struct{}{}
}))
// Update the macro call id references to ensure that macro pointers are
// updated consistently across macros.
for _, call := range info.MacroCalls() {
ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) {
if e.ID() == 0 {
return
}
exprRefMap[e.ID()] = struct{}{}
}))
}
for id := range info.MacroCalls() {
if _, found := exprRefMap[id]; !found {
info.ClearMacroCall(id)
}
}
}
// newIDGenerator ensures that new ids are only created the first time they are encountered.
func newIDGenerator(seed int64) *idGenerator {
return &idGenerator{
idMap: make(map[int64]int64),
seed: seed,
}
}
type idGenerator struct {
idMap map[int64]int64
seed int64
}
func (gen *idGenerator) nextID() int64 {
gen.seed++
return gen.seed
}
func (gen *idGenerator) renumberStable(id int64) int64 {
if id == 0 {
return 0
}
if newID, found := gen.idMap[id]; found {
return newID
}
nextID := gen.nextID()
gen.idMap[id] = nextID
return nextID
}
// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate
// subexpressions and report any errors encountered along the way. The context also embeds the
// optimizerExprFactory which can be used to generate new sub-expressions with expression ids
// consistent with the expectations of a parsed expression.
type OptimizerContext struct {
*Env
*optimizerExprFactory
*Issues
}
// ASTOptimizer applies an optimization over an AST and returns the optimized result.
type ASTOptimizer interface {
// Optimize optimizes a type-checked AST within an Environment and accumulates any issues.
Optimize(*OptimizerContext, *ast.AST) *ast.AST
}
type optimizerExprFactory struct {
*idGenerator
fac ast.ExprFactory
sourceInfo *ast.SourceInfo
}
// NewAST creates an AST from the current expression using the tracked source info which
// is modified and managed by the OptimizerContext.
func (opt *optimizerExprFactory) NewAST(expr ast.Expr) *ast.AST {
return ast.NewAST(expr, opt.sourceInfo)
}
// CopyAST creates a renumbered copy of `Expr` and `SourceInfo` values of the input AST, where the
// renumbering uses the same scheme as the core optimizer logic ensuring there are no collisions
// between copies.
//
// Use this method before attempting to merge the expression from AST into another.
func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo) {
idGen := newIDGenerator(opt.nextID())
defer func() { opt.seed = idGen.nextID() }()
copyExpr := opt.fac.CopyExpr(a.Expr())
copyInfo := ast.CopySourceInfo(a.SourceInfo())
normalizeIDs(idGen.renumberStable, copyExpr, copyInfo)
return copyExpr, copyInfo
}
// CopyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being
// optimized.
func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr {
copyExpr, copyInfo := opt.CopyAST(a)
for macroID, call := range copyInfo.MacroCalls() {
opt.SetMacroCall(macroID, call)
}
return copyExpr
}
// ClearMacroCall clears the macro at the given expression id.
func (opt *optimizerExprFactory) ClearMacroCall(id int64) {
opt.sourceInfo.ClearMacroCall(id)
}
// SetMacroCall sets the macro call metadata for the given macro id within the tracked source info
// metadata.
func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) {
opt.sourceInfo.SetMacroCall(id, expr)
}
// NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression
// representing the unexpanded call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) {
varID := opt.nextID()
remainingID := opt.nextID()
remaining = opt.fac.CopyExpr(remaining)
remaining.RenumberIDs(func(id int64) int64 {
if id == macroID {
return remainingID
}
return id
})
if call, exists := opt.sourceInfo.GetMacroCall(macroID); exists {
opt.SetMacroCall(remainingID, opt.fac.CopyExpr(call))
}
astExpr = opt.fac.NewComprehension(macroID,
opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}),
"#unused",
varName,
opt.fac.CopyExpr(varInit),
opt.fac.NewLiteral(opt.nextID(), types.False),
opt.fac.NewIdent(varID, varName),
remaining)
macroExpr = opt.fac.NewMemberCall(0, "bind",
opt.fac.NewIdent(opt.nextID(), "cel"),
opt.fac.NewIdent(varID, varName),
opt.fac.CopyExpr(varInit),
opt.fac.CopyExpr(remaining))
opt.sanitizeMacro(macroID, macroExpr)
return
}
// NewCall creates a global function call invocation expression.
//
// Example:
//
// countByField(list, fieldName)
// - function: countByField
// - args: [list, fieldName]
func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr {
return opt.fac.NewCall(opt.nextID(), function, args...)
}
// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call.
//
// Example:
//
// list.countByField(fieldName)
// - function: countByField
// - target: list
// - args: [fieldName]
func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr {
return opt.fac.NewMemberCall(opt.nextID(), function, target, args...)
}
// NewIdent creates a new identifier expression.
//
// Examples:
//
// - simple_var_name
// - qualified.subpackage.var_name
func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr {
return opt.fac.NewIdent(opt.nextID(), name)
}
// NewLiteral creates a new literal expression value.
//
// The range of valid values for a literal generated during optimization is different than for expressions
// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can
// be converted back to a literal-like form.
func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr {
return opt.fac.NewLiteral(opt.nextID(), value)
}
// NewList creates a list expression with a set of optional indices.
//
// Examples:
//
// [a, b]
// - elems: [a, b]
// - optIndices: []
//
// [a, ?b, ?c]
// - elems: [a, b, c]
// - optIndices: [1, 2]
func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr {
return opt.fac.NewList(opt.nextID(), elems, optIndices)
}
// NewMap creates a map from a set of entry expressions which contain a key and value expression.
func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr {
return opt.fac.NewMap(opt.nextID(), entries)
}
// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the
// entry is optional.
//
// Examples:
//
// {a: b}
// - key: a
// - value: b
// - optional: false
//
// {?a: ?b}
// - key: a
// - value: b
// - optional: true
func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr {
return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional)
}
// NewHasMacro generates a test-only select expression to be included within an AST and an unexpanded
// has() macro call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewHasMacro(macroID int64, s ast.Expr) (astExpr, macroExpr ast.Expr) {
sel := s.AsSelect()
astExpr = opt.fac.NewPresenceTest(macroID, sel.Operand(), sel.FieldName())
macroExpr = opt.fac.NewCall(0, "has",
opt.NewSelect(opt.fac.CopyExpr(sel.Operand()), sel.FieldName()))
opt.sanitizeMacro(macroID, macroExpr)
return
}
// NewSelect creates a select expression where a field value is selected from an operand.
//
// Example:
//
// msg.field_name
// - operand: msg
// - field: field_name
func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr {
return opt.fac.NewSelect(opt.nextID(), operand, field)
}
// NewStruct creates a new typed struct value with an set of field initializations.
//
// Example:
//
// pkg.TypeName{field: value}
// - typeName: pkg.TypeName
// - fields: [{field: value}]
func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr {
return opt.fac.NewStruct(opt.nextID(), typeName, fields)
}
// NewStructField creates a struct field initialization.
//
// Examples:
//
// {count: 3u}
// - field: count
// - value: 3u
// - optional: false
//
// {?count: x}
// - field: count
// - value: x
// - optional: true
func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr {
return opt.fac.NewStructField(opt.nextID(), field, value, isOptional)
}
// UpdateExpr updates the target expression with the updated content while preserving macro metadata.
//
// There are four scenarios during the update to consider:
// 1. target is not macro, updated is not macro
// 2. target is macro, updated is not macro
// 3. target is macro, updated is macro
// 4. target is not macro, updated is macro
//
// When the target is a macro already, it may either be updated to a new macro function
// body if the update is also a macro, or it may be removed altogether if the update is
// a macro.
//
// When the update is a macro, then the target references within other macros must be
// updated to point to the new updated macro. Otherwise, other macros which pointed to
// the target body must be replaced with copies of the updated expression body.
func (opt *optimizerExprFactory) UpdateExpr(target, updated ast.Expr) {
// Update the expression
target.SetKindCase(updated)
// Early return if there's no macros present sa the source info reflects the
// macro set from the target and updated expressions.
if len(opt.sourceInfo.MacroCalls()) == 0 {
return
}
// Determine whether the target expression was a macro.
_, targetIsMacro := opt.sourceInfo.GetMacroCall(target.ID())
// Determine whether the updated expression was a macro.
updatedMacro, updatedIsMacro := opt.sourceInfo.GetMacroCall(updated.ID())
if updatedIsMacro {
// If the updated call was a macro, then updated id maps to target id,
// and the updated macro moves into the target id slot.
opt.sourceInfo.ClearMacroCall(updated.ID())
opt.sourceInfo.SetMacroCall(target.ID(), updatedMacro)
} else if targetIsMacro {
// Otherwise if the target expr was a macro, but is no longer, clear
// the macro reference.
opt.sourceInfo.ClearMacroCall(target.ID())
}
// Punch holes in the updated value where macros references exist.
macroExpr := opt.fac.CopyExpr(target)
macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists {
e.SetKindCase(nil)
}
})
ast.PostOrderVisit(macroExpr, macroRefVisitor)
// Update any references to the expression within a macro
macroVisitor := ast.NewExprVisitor(func(call ast.Expr) {
// Update the target expression to point to the macro expression which
// will be empty if the updated expression was a macro.
if call.ID() == target.ID() {
call.SetKindCase(opt.fac.CopyExpr(macroExpr))
}
// Update the macro call expression if it refers to the updated expression
// id which has since been remapped to the target id.
if call.ID() == updated.ID() {
// Either ensure the expression is a macro reference or a populated with
// the relevant sub-expression if the updated expr was not a macro.
if updatedIsMacro {
call.SetKindCase(nil)
} else {
call.SetKindCase(opt.fac.CopyExpr(macroExpr))
}
// Since SetKindCase does not renumber the id, ensure the references to
// the old 'updated' id are mapped to the target id.
call.RenumberIDs(func(id int64) int64 {
if id == updated.ID() {
return target.ID()
}
return id
})
}
})
for _, call := range opt.sourceInfo.MacroCalls() {
ast.PostOrderVisit(call, macroVisitor)
}
}
func (opt *optimizerExprFactory) sanitizeMacro(macroID int64, macroExpr ast.Expr) {
macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists && e.ID() != macroID {
e.SetKindCase(nil)
}
})
ast.PostOrderVisit(macroExpr, macroRefVisitor)
}

View File

@@ -448,6 +448,8 @@ const (
OptTrackCost EvalOption = 1 << iota
// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
//
// Deprecated: use ext.StringsValidateFormatCalls() as this option is now a no-op.
OptCheckStringFormat EvalOption = 1 << iota
)

View File

@@ -19,7 +19,6 @@ import (
"fmt"
"sync"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
@@ -152,7 +151,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()
@@ -213,34 +212,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t := ast.typeMap[id]
if t.Kind() == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if t.Kind() == k {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
@@ -274,33 +245,16 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
decs = append(decs, interpreter.Observe(observers...))
}
return p.clone().initInterpretable(ast, decs)
return p.clone().initInterpretable(a, decs)
}
return newProgGen(factory)
}
return p.initInterpretable(ast, decorators)
return p.initInterpretable(a, decorators)
}
func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// Unchecked programs do not contain type and reference information and may be slower to execute.
if !ast.IsChecked() {
interpretable, err :=
p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...)
if err != nil {
return nil, err
}
p.interpretable = interpretable
return p, nil
}
// When the AST has been checked it contains metadata that can be used to speed up program execution.
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...)
if err != nil {
return nil, err
}
@@ -580,8 +534,6 @@ func (p *evalActivationPool) Put(value any) {
}
var (
emptyEvalState = interpreter.NewEvalState()
// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()

View File

@@ -21,8 +21,6 @@ import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
@@ -69,7 +67,7 @@ type ASTValidator interface {
//
// See individual validators for more information on their configuration keys and configuration
// properties.
Validate(*Env, ValidatorConfig, *ast.CheckedAST, *Issues)
Validate(*Env, ValidatorConfig, *ast.AST, *Issues)
}
// ValidatorConfig provides an accessor method for querying validator configuration state.
@@ -180,7 +178,7 @@ func ValidateComprehensionNestingLimit(limit int) ASTValidator {
return nestingLimitValidator{limit: limit}
}
type argChecker func(env *Env, call, arg ast.NavigableExpr) error
type argChecker func(env *Env, call, arg ast.Expr) error
func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
return formatValidator{
@@ -203,8 +201,8 @@ func (v formatValidator) Name() string {
// Validate searches the AST for uses of a given function name with a constant argument and performs a check
// on whether the argument is a valid literal value.
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
for _, call := range funcCalls {
callArgs := call.AsCall().Args()
@@ -221,8 +219,8 @@ func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST,
}
}
func evalCall(env *Env, call, arg ast.NavigableExpr) error {
ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()})
func evalCall(env *Env, call, arg ast.Expr) error {
ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))}
prg, err := env.Program(ast)
if err != nil {
return err
@@ -231,7 +229,7 @@ func evalCall(env *Env, call, arg ast.NavigableExpr) error {
return err
}
func compileRegex(_ *Env, _, arg ast.NavigableExpr) error {
func compileRegex(_ *Env, _, arg ast.Expr) error {
pattern := arg.AsLiteral().Value().(string)
_, err := regexp.Compile(pattern)
return err
@@ -244,25 +242,14 @@ func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}
// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard
// and exempt functions from homogeneous aggregate literal checks.
//
// TODO: Move this call into the string.format() ASTValidator once ported.
func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error {
emptyList := []string{}
exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string)
exemptFunctions = append(exemptFunctions, "format")
return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions)
}
// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
// string format calls.
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) {
var exemptedFunctions []string
exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string)
root := ast.NavigateCheckedAST(a)
root := ast.NavigateAST(a)
listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
for _, listExpr := range listExprs {
if inExemptFunction(listExpr, exemptedFunctions) {
@@ -273,7 +260,7 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
optIndices := l.OptionalIndices()
var elemType *Type
for i, e := range elements {
et := e.Type()
et := a.GetType(e.ID())
if isOptionalIndex(i, optIndices) {
et = et.Parameters()[0]
}
@@ -296,9 +283,10 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
entries := m.Entries()
var keyType, valType *Type
for _, e := range entries {
key, val := e.Key(), e.Value()
kt, vt := key.Type(), val.Type()
if e.IsOptional() {
mapEntry := e.AsMapEntry()
key, val := mapEntry.Key(), mapEntry.Value()
kt, vt := a.GetType(key.ID()), a.GetType(val.ID())
if mapEntry.IsOptional() {
vt = vt.Parameters()[0]
}
if keyType == nil && valType == nil {
@@ -316,7 +304,8 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
}
func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
if parent, found := e.Parent(); found {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.CallKind {
fnName := parent.AsCall().FunctionName()
for _, exempt := range exemptFunctions {
@@ -325,9 +314,7 @@ func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
}
}
}
if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind {
return inExemptFunction(parent, exemptFunctions)
}
parent, found = parent.Parent()
}
return false
}
@@ -353,8 +340,8 @@ func (v nestingLimitValidator) Name() string {
return "cel.lib.std.validate.comprehension_nesting_limit"
}
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))
if len(comprehensions) <= v.limit {
return

View File

@@ -60,7 +60,6 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

View File

@@ -18,6 +18,7 @@ package checker
import (
"fmt"
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
@@ -25,139 +26,98 @@ import (
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types/ref"
)
type checker struct {
*ast.AST
ast.ExprFactory
env *Env
errors *typeErrors
mappings *mapping
freeTypeVarCounter int
sourceInfo *exprpb.SourceInfo
types map[int64]*types.Type
references map[int64]*ast.ReferenceInfo
}
// Check performs type checking, giving a typed AST.
// The input is a ParsedExpr proto and an env which encapsulates
// type binding of variables, declarations of built-in functions,
// descriptions of protocol buffers, and a registry for errors.
// Returns a CheckedExpr proto, which might not be usable if
// there are errors in the error registry.
func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.CheckedAST, *common.Errors) {
//
// The input is a parsed AST and an env which encapsulates type binding of variables,
// declarations of built-in functions, descriptions of protocol buffers, and a registry for
// errors.
//
// Returns a type-checked AST, which might not be usable if there are errors in the error
// registry.
func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.Errors) {
errs := common.NewErrors(source)
typeMap := make(map[int64]*types.Type)
refMap := make(map[int64]*ast.ReferenceInfo)
c := checker{
AST: ast.NewCheckedAST(parsed, typeMap, refMap),
ExprFactory: ast.NewExprFactory(),
env: env,
errors: &typeErrors{errs: errs},
mappings: newMapping(),
freeTypeVarCounter: 0,
sourceInfo: parsedExpr.GetSourceInfo(),
types: make(map[int64]*types.Type),
references: make(map[int64]*ast.ReferenceInfo),
}
c.check(parsedExpr.GetExpr())
c.check(c.Expr())
// Walk over the final type map substituting any type parameters either by their bound value or
// by DYN.
m := make(map[int64]*types.Type)
for id, t := range c.types {
m[id] = substitute(c.mappings, t, true)
// Walk over the final type map substituting any type parameters either by their bound value
// or by DYN.
for id, t := range c.TypeMap() {
c.SetType(id, substitute(c.mappings, t, true))
}
return &ast.CheckedAST{
Expr: parsedExpr.GetExpr(),
SourceInfo: parsedExpr.GetSourceInfo(),
TypeMap: m,
ReferenceMap: c.references,
}, errs
return c.AST, errs
}
func (c *checker) check(e *exprpb.Expr) {
func (c *checker) check(e ast.Expr) {
if e == nil {
return
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
c.checkBoolLiteral(e)
case *exprpb.Constant_BytesValue:
c.checkBytesLiteral(e)
case *exprpb.Constant_DoubleValue:
c.checkDoubleLiteral(e)
case *exprpb.Constant_Int64Value:
c.checkInt64Literal(e)
case *exprpb.Constant_NullValue:
c.checkNullLiteral(e)
case *exprpb.Constant_StringValue:
c.checkStringLiteral(e)
case *exprpb.Constant_Uint64Value:
c.checkUint64Literal(e)
switch e.Kind() {
case ast.LiteralKind:
literal := ref.Val(e.AsLiteral())
switch literal.Type() {
case types.BoolType, types.BytesType, types.DoubleType, types.IntType,
types.NullType, types.StringType, types.UintType:
c.setType(e, literal.Type().(*types.Type))
default:
c.errors.unexpectedASTType(e.ID(), c.location(e), "literal", literal.Type().TypeName())
}
case *exprpb.Expr_IdentExpr:
case ast.IdentKind:
c.checkIdent(e)
case *exprpb.Expr_SelectExpr:
case ast.SelectKind:
c.checkSelect(e)
case *exprpb.Expr_CallExpr:
case ast.CallKind:
c.checkCall(e)
case *exprpb.Expr_ListExpr:
case ast.ListKind:
c.checkCreateList(e)
case *exprpb.Expr_StructExpr:
case ast.MapKind:
c.checkCreateMap(e)
case ast.StructKind:
c.checkCreateStruct(e)
case *exprpb.Expr_ComprehensionExpr:
case ast.ComprehensionKind:
c.checkComprehension(e)
default:
c.errors.unexpectedASTType(e.GetId(), c.location(e), e)
c.errors.unexpectedASTType(e.ID(), c.location(e), "unspecified", reflect.TypeOf(e).Name())
}
}
func (c *checker) checkInt64Literal(e *exprpb.Expr) {
c.setType(e, types.IntType)
}
func (c *checker) checkUint64Literal(e *exprpb.Expr) {
c.setType(e, types.UintType)
}
func (c *checker) checkStringLiteral(e *exprpb.Expr) {
c.setType(e, types.StringType)
}
func (c *checker) checkBytesLiteral(e *exprpb.Expr) {
c.setType(e, types.BytesType)
}
func (c *checker) checkDoubleLiteral(e *exprpb.Expr) {
c.setType(e, types.DoubleType)
}
func (c *checker) checkBoolLiteral(e *exprpb.Expr) {
c.setType(e, types.BoolType)
}
func (c *checker) checkNullLiteral(e *exprpb.Expr) {
c.setType(e, types.NullType)
}
func (c *checker) checkIdent(e *exprpb.Expr) {
identExpr := e.GetIdentExpr()
func (c *checker) checkIdent(e ast.Expr) {
identName := e.AsIdent()
// Check to see if the identifier is declared.
if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil {
if ident := c.env.LookupIdent(identName); ident != nil {
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
// Overwrite the identifier with its fully qualified name.
identExpr.Name = ident.Name()
e.SetKindCase(c.NewIdent(e.ID(), ident.Name()))
return
}
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), identExpr.GetName())
c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), identName)
}
func (c *checker) checkSelect(e *exprpb.Expr) {
sel := e.GetSelectExpr()
func (c *checker) checkSelect(e ast.Expr) {
sel := e.AsSelect()
// Before traversing down the tree, try to interpret as qualified name.
qname, found := containers.ToQualifiedName(e)
if found {
@@ -170,31 +130,26 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// variable name.
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
identName := ident.Name()
e.ExprKind = &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: identName,
},
}
e.SetKindCase(c.NewIdent(e.ID(), ident.Name()))
return
}
}
resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false)
if sel.TestOnly {
resultType := c.checkSelectField(e, sel.Operand(), sel.FieldName(), false)
if sel.IsTestOnly() {
resultType = types.BoolType
}
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkOptSelect(e *exprpb.Expr) {
func (c *checker) checkOptSelect(e ast.Expr) {
// Collect metadata related to the opt select call packaged by the parser.
call := e.GetCallExpr()
operand := call.GetArgs()[0]
field := call.GetArgs()[1]
call := e.AsCall()
operand := call.Args()[0]
field := call.Args()[1]
fieldName, isString := maybeUnwrapString(field)
if !isString {
c.errors.notAnOptionalFieldSelection(field.GetId(), c.location(field), field)
c.errors.notAnOptionalFieldSelection(field.ID(), c.location(field), field)
return
}
@@ -204,7 +159,7 @@ func (c *checker) checkOptSelect(e *exprpb.Expr) {
c.setReference(e, ast.NewFunctionReference("select_optional_field"))
}
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *types.Type {
func (c *checker) checkSelectField(e, operand ast.Expr, field string, optional bool) *types.Type {
// Interpret as field selection, first traversing down the operand.
c.check(operand)
operandType := substitute(c.mappings, c.getType(operand), false)
@@ -222,7 +177,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(e.GetId(), messageType.TypeName(), field); found {
if fieldType, found := c.lookupFieldType(e.ID(), messageType.TypeName(), field); found {
resultType = fieldType
}
case types.TypeParamKind:
@@ -236,7 +191,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option
// Dynamic / error values are treated as DYN type. Errors are handled this way as well
// in order to allow forward progress on the check.
if !isDynOrError(targetType) {
c.errors.typeDoesNotSupportFieldSelection(e.GetId(), c.location(e), targetType)
c.errors.typeDoesNotSupportFieldSelection(e.ID(), c.location(e), targetType)
}
resultType = types.DynType
}
@@ -248,35 +203,34 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option
return resultType
}
func (c *checker) checkCall(e *exprpb.Expr) {
func (c *checker) checkCall(e ast.Expr) {
// Note: similar logic exists within the `interpreter/planner.go`. If making changes here
// please consider the impact on planner.go and consolidate implementations or mirror code
// as appropriate.
call := e.GetCallExpr()
fnName := call.GetFunction()
call := e.AsCall()
fnName := call.FunctionName()
if fnName == operators.OptSelect {
c.checkOptSelect(e)
return
}
args := call.GetArgs()
args := call.Args()
// Traverse arguments.
for _, arg := range args {
c.check(arg)
}
target := call.GetTarget()
// Regular static call with simple name.
if target == nil {
if !call.IsMemberFunction() {
// Check for the existence of the function.
fn := c.env.LookupFunction(fnName)
if fn == nil {
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName)
c.setType(e, types.ErrorType)
return
}
// Overwrite the function name with its fully qualified resolved name.
call.Function = fn.Name()
e.SetKindCase(c.NewCall(e.ID(), fn.Name(), args...))
// Check to see whether the overload resolves.
c.resolveOverloadOrError(e, fn, nil, args)
return
@@ -287,6 +241,7 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// target a.b.
//
// Check whether the target is a namespaced function name.
target := call.Target()
qualifiedPrefix, maybeQualified := containers.ToQualifiedName(target)
if maybeQualified {
maybeQualifiedName := qualifiedPrefix + "." + fnName
@@ -295,15 +250,14 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// The function name is namespaced and so preserving the target operand would
// be an inaccurate representation of the desired evaluation behavior.
// Overwrite with fully-qualified resolved function name sans receiver target.
call.Target = nil
call.Function = fn.Name()
e.SetKindCase(c.NewCall(e.ID(), fn.Name(), args...))
c.resolveOverloadOrError(e, fn, nil, args)
return
}
}
// Regular instance call.
c.check(call.Target)
c.check(target)
fn := c.env.LookupFunction(fnName)
// Function found, attempt overload resolution.
if fn != nil {
@@ -312,11 +266,11 @@ func (c *checker) checkCall(e *exprpb.Expr) {
}
// Function name not declared, record error.
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName)
}
func (c *checker) resolveOverloadOrError(
e *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) {
e ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) {
// Attempt to resolve the overload.
resolution := c.resolveOverload(e, fn, target, args)
// No such overload, error noted in the resolveOverload call, type recorded here.
@@ -330,7 +284,7 @@ func (c *checker) resolveOverloadOrError(
}
func (c *checker) resolveOverload(
call *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution {
call ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) *overloadResolution {
var argTypes []*types.Type
if target != nil {
@@ -362,8 +316,8 @@ func (c *checker) resolveOverload(
for i, argType := range argTypes {
if !c.isAssignable(argType, types.BoolType) {
c.errors.typeMismatch(
args[i].GetId(),
c.locationByID(args[i].GetId()),
args[i].ID(),
c.locationByID(args[i].ID()),
types.BoolType,
argType)
resultType = types.ErrorType
@@ -408,29 +362,29 @@ func (c *checker) resolveOverload(
for i, argType := range argTypes {
argTypes[i] = substitute(c.mappings, argType, true)
}
c.errors.noMatchingOverload(call.GetId(), c.location(call), fn.Name(), argTypes, target != nil)
c.errors.noMatchingOverload(call.ID(), c.location(call), fn.Name(), argTypes, target != nil)
return nil
}
return newResolution(checkedRef, resultType)
}
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
func (c *checker) checkCreateList(e ast.Expr) {
create := e.AsList()
var elemsType *types.Type
optionalIndices := create.GetOptionalIndices()
optionalIndices := create.OptionalIndices()
optionals := make(map[int32]bool, len(optionalIndices))
for _, optInd := range optionalIndices {
optionals[optInd] = true
}
for i, e := range create.GetElements() {
for i, e := range create.Elements() {
c.check(e)
elemType := c.getType(e)
if optionals[int32(i)] {
var isOptional bool
elemType, isOptional = maybeUnwrapOptional(elemType)
if !isOptional && !isDyn(elemType) {
c.errors.typeMismatch(e.GetId(), c.location(e), types.NewOptionalType(elemType), elemType)
c.errors.typeMismatch(e.ID(), c.location(e), types.NewOptionalType(elemType), elemType)
}
}
elemsType = c.joinTypes(e, elemsType, elemType)
@@ -442,32 +396,24 @@ func (c *checker) checkCreateList(e *exprpb.Expr) {
c.setType(e, types.NewListType(elemsType))
}
func (c *checker) checkCreateStruct(e *exprpb.Expr) {
str := e.GetStructExpr()
if str.GetMessageName() != "" {
c.checkCreateMessage(e)
} else {
c.checkCreateMap(e)
}
}
func (c *checker) checkCreateMap(e *exprpb.Expr) {
mapVal := e.GetStructExpr()
func (c *checker) checkCreateMap(e ast.Expr) {
mapVal := e.AsMap()
var mapKeyType *types.Type
var mapValueType *types.Type
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
for _, e := range mapVal.Entries() {
entry := e.AsMapEntry()
key := entry.Key()
c.check(key)
mapKeyType = c.joinTypes(key, mapKeyType, c.getType(key))
val := ent.GetValue()
val := entry.Value()
c.check(val)
valType := c.getType(val)
if ent.GetOptionalEntry() {
if entry.IsOptional() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(val.GetId(), c.location(val), types.NewOptionalType(valType), valType)
c.errors.typeMismatch(val.ID(), c.location(val), types.NewOptionalType(valType), valType)
}
}
mapValueType = c.joinTypes(val, mapValueType, valType)
@@ -480,25 +426,28 @@ func (c *checker) checkCreateMap(e *exprpb.Expr) {
c.setType(e, types.NewMapType(mapKeyType, mapValueType))
}
func (c *checker) checkCreateMessage(e *exprpb.Expr) {
msgVal := e.GetStructExpr()
func (c *checker) checkCreateStruct(e ast.Expr) {
msgVal := e.AsStruct()
// Determine the type of the message.
resultType := types.ErrorType
ident := c.env.LookupIdent(msgVal.GetMessageName())
ident := c.env.LookupIdent(msgVal.TypeName())
if ident == nil {
c.errors.undeclaredReference(
e.GetId(), c.location(e), c.env.container.Name(), msgVal.GetMessageName())
e.ID(), c.location(e), c.env.container.Name(), msgVal.TypeName())
c.setType(e, types.ErrorType)
return
}
// Ensure the type name is fully qualified in the AST.
typeName := ident.Name()
msgVal.MessageName = typeName
c.setReference(e, ast.NewIdentReference(ident.Name(), nil))
if msgVal.TypeName() != typeName {
e.SetKindCase(c.NewStruct(e.ID(), typeName, msgVal.Fields()))
msgVal = e.AsStruct()
}
c.setReference(e, ast.NewIdentReference(typeName, nil))
identKind := ident.Type().Kind()
if identKind != types.ErrorKind {
if identKind != types.TypeKind {
c.errors.notAType(e.GetId(), c.location(e), ident.Type().DeclaredTypeName())
c.errors.notAType(e.ID(), c.location(e), ident.Type().DeclaredTypeName())
} else {
resultType = ident.Type().Parameters()[0]
// Backwards compatibility test between well-known types and message types
@@ -509,7 +458,7 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
} else if resultType.Kind() == types.StructKind {
typeName = resultType.DeclaredTypeName()
} else {
c.errors.notAMessageType(e.GetId(), c.location(e), resultType.DeclaredTypeName())
c.errors.notAMessageType(e.ID(), c.location(e), resultType.DeclaredTypeName())
resultType = types.ErrorType
}
}
@@ -517,37 +466,38 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
c.setType(e, resultType)
// Check the field initializers.
for _, ent := range msgVal.GetEntries() {
field := ent.GetFieldKey()
value := ent.GetValue()
for _, f := range msgVal.Fields() {
field := f.AsStructField()
fieldName := field.Name()
value := field.Value()
c.check(value)
fieldType := types.ErrorType
ft, found := c.lookupFieldType(ent.GetId(), typeName, field)
ft, found := c.lookupFieldType(f.ID(), typeName, fieldName)
if found {
fieldType = ft
}
valType := c.getType(value)
if ent.GetOptionalEntry() {
if field.IsOptional() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(value.GetId(), c.location(value), types.NewOptionalType(valType), valType)
c.errors.typeMismatch(value.ID(), c.location(value), types.NewOptionalType(valType), valType)
}
}
if !c.isAssignable(fieldType, valType) {
c.errors.fieldTypeMismatch(ent.GetId(), c.locationByID(ent.GetId()), field, fieldType, valType)
c.errors.fieldTypeMismatch(f.ID(), c.locationByID(f.ID()), fieldName, fieldType, valType)
}
}
}
func (c *checker) checkComprehension(e *exprpb.Expr) {
comp := e.GetComprehensionExpr()
c.check(comp.GetIterRange())
c.check(comp.GetAccuInit())
accuType := c.getType(comp.GetAccuInit())
rangeType := substitute(c.mappings, c.getType(comp.GetIterRange()), false)
func (c *checker) checkComprehension(e ast.Expr) {
comp := e.AsComprehension()
c.check(comp.IterRange())
c.check(comp.AccuInit())
accuType := c.getType(comp.AccuInit())
rangeType := substitute(c.mappings, c.getType(comp.IterRange()), false)
var varType *types.Type
switch rangeType.Kind() {
@@ -564,32 +514,32 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
// Set the range iteration variable to type DYN as well.
varType = types.DynType
default:
c.errors.notAComprehensionRange(comp.GetIterRange().GetId(), c.location(comp.GetIterRange()), rangeType)
c.errors.notAComprehensionRange(comp.IterRange().ID(), c.location(comp.IterRange()), rangeType)
varType = types.ErrorType
}
// Create a scope for the comprehension since it has a local accumulation variable.
// This scope will contain the accumulation variable used to compute the result.
c.env = c.env.enterScope()
c.env.AddIdents(decls.NewVariable(comp.GetAccuVar(), accuType))
c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType))
// Create a block scope for the loop.
c.env = c.env.enterScope()
c.env.AddIdents(decls.NewVariable(comp.GetIterVar(), varType))
c.env.AddIdents(decls.NewVariable(comp.IterVar(), varType))
// Check the variable references in the condition and step.
c.check(comp.GetLoopCondition())
c.assertType(comp.GetLoopCondition(), types.BoolType)
c.check(comp.GetLoopStep())
c.assertType(comp.GetLoopStep(), accuType)
c.check(comp.LoopCondition())
c.assertType(comp.LoopCondition(), types.BoolType)
c.check(comp.LoopStep())
c.assertType(comp.LoopStep(), accuType)
// Exit the loop's block scope before checking the result.
c.env = c.env.exitScope()
c.check(comp.GetResult())
c.check(comp.Result())
// Exit the comprehension scope.
c.env = c.env.exitScope()
c.setType(e, substitute(c.mappings, c.getType(comp.GetResult()), false))
c.setType(e, substitute(c.mappings, c.getType(comp.Result()), false))
}
// Checks compatibility of joined types, and returns the most general common type.
func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *types.Type {
func (c *checker) joinTypes(e ast.Expr, previous, current *types.Type) *types.Type {
if previous == nil {
return current
}
@@ -599,7 +549,7 @@ func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *type
if c.dynAggregateLiteralElementTypesEnabled() {
return types.DynType
}
c.errors.typeMismatch(e.GetId(), c.location(e), previous, current)
c.errors.typeMismatch(e.ID(), c.location(e), previous, current)
return types.ErrorType
}
@@ -633,41 +583,41 @@ func (c *checker) isAssignableList(l1, l2 []*types.Type) bool {
return false
}
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
func maybeUnwrapString(e ast.Expr) (string, bool) {
switch e.Kind() {
case ast.LiteralKind:
literal := e.AsLiteral()
switch v := literal.(type) {
case types.String:
return string(v), true
}
}
return "", false
}
func (c *checker) setType(e *exprpb.Expr, t *types.Type) {
if old, found := c.types[e.GetId()]; found && !old.IsExactType(t) {
c.errors.incompatibleType(e.GetId(), c.location(e), e, old, t)
func (c *checker) setType(e ast.Expr, t *types.Type) {
if old, found := c.TypeMap()[e.ID()]; found && !old.IsExactType(t) {
c.errors.incompatibleType(e.ID(), c.location(e), e, old, t)
return
}
c.types[e.GetId()] = t
c.SetType(e.ID(), t)
}
func (c *checker) getType(e *exprpb.Expr) *types.Type {
return c.types[e.GetId()]
func (c *checker) getType(e ast.Expr) *types.Type {
return c.TypeMap()[e.ID()]
}
func (c *checker) setReference(e *exprpb.Expr, r *ast.ReferenceInfo) {
if old, found := c.references[e.GetId()]; found && !old.Equals(r) {
c.errors.referenceRedefinition(e.GetId(), c.location(e), e, old, r)
func (c *checker) setReference(e ast.Expr, r *ast.ReferenceInfo) {
if old, found := c.ReferenceMap()[e.ID()]; found && !old.Equals(r) {
c.errors.referenceRedefinition(e.ID(), c.location(e), e, old, r)
return
}
c.references[e.GetId()] = r
c.SetReference(e.ID(), r)
}
func (c *checker) assertType(e *exprpb.Expr, t *types.Type) {
func (c *checker) assertType(e ast.Expr, t *types.Type) {
if !c.isAssignable(t, c.getType(e)) {
c.errors.typeMismatch(e.GetId(), c.location(e), t, c.getType(e))
c.errors.typeMismatch(e.ID(), c.location(e), t, c.getType(e))
}
}
@@ -683,26 +633,12 @@ func newResolution(r *ast.ReferenceInfo, t *types.Type) *overloadResolution {
}
}
func (c *checker) location(e *exprpb.Expr) common.Location {
return c.locationByID(e.GetId())
func (c *checker) location(e ast.Expr) common.Location {
return c.locationByID(e.ID())
}
func (c *checker) locationByID(id int64) common.Location {
positions := c.sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range c.sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
return c.SourceInfo().GetStartLocation(id)
}
func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*types.Type, bool) {

View File

@@ -22,8 +22,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// WARNING: Any changes to cost calculations in this file require a corresponding change in interpreter/runtimecost.go
@@ -58,7 +56,7 @@ type AstNode interface {
// Type returns the deduced type of the AstNode.
Type() *types.Type
// Expr returns the expression of the AstNode.
Expr() *exprpb.Expr
Expr() ast.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
// For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings
// and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no
@@ -69,7 +67,7 @@ type AstNode interface {
type astNode struct {
path []string
t *types.Type
expr *exprpb.Expr
expr ast.Expr
derivedSize *SizeEstimate
}
@@ -81,7 +79,7 @@ func (e astNode) Type() *types.Type {
return e.t
}
func (e astNode) Expr() *exprpb.Expr {
func (e astNode) Expr() ast.Expr {
return e.expr
}
@@ -90,29 +88,27 @@ func (e astNode) ComputedSize() *SizeEstimate {
return e.derivedSize
}
var v uint64
switch ek := e.expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
switch e.expr.Kind() {
case ast.LiteralKind:
switch ck := e.expr.AsLiteral().(type) {
case types.String:
// converting to runes here is an O(n) operation, but
// this is consistent with how size is computed at runtime,
// and how the language definition defines string size
v = uint64(len([]rune(ck.StringValue)))
case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
*exprpb.Constant_Int64Value, *exprpb.Constant_TimestampValue, *exprpb.Constant_Uint64Value,
*exprpb.Constant_NullValue:
v = uint64(len([]rune(ck)))
case types.Bytes:
v = uint64(len(ck))
case types.Bool, types.Double, types.Duration,
types.Int, types.Timestamp, types.Uint,
types.Null:
v = uint64(1)
default:
return nil
}
case *exprpb.Expr_ListExpr:
v = uint64(len(ek.ListExpr.GetElements()))
case *exprpb.Expr_StructExpr:
if ek.StructExpr.GetMessageName() == "" {
v = uint64(len(ek.StructExpr.GetEntries()))
}
case ast.ListKind:
v = uint64(e.expr.AsList().Size())
case ast.MapKind:
v = uint64(e.expr.AsMap().Size())
default:
return nil
}
@@ -265,7 +261,7 @@ type coster struct {
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
checkedAST *ast.AST
estimator CostEstimator
overloadEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
@@ -275,8 +271,8 @@ type coster struct {
// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
type iterRangeScopes map[string][]int64
func (vs iterRangeScopes) push(varName string, expr *exprpb.Expr) {
vs[varName] = append(vs[varName], expr.GetId())
func (vs iterRangeScopes) push(varName string, expr ast.Expr) {
vs[varName] = append(vs[varName], expr.ID())
}
func (vs iterRangeScopes) pop(varName string) {
@@ -324,9 +320,9 @@ func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) C
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedAST: checker,
checkedAST: checked,
estimator: estimator,
overloadEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
@@ -340,28 +336,30 @@ func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption)
return CostEstimate{}, err
}
}
return c.cost(checker.Expr), nil
return c.cost(checked.Expr()), nil
}
func (c *coster) cost(e *exprpb.Expr) CostEstimate {
func (c *coster) cost(e ast.Expr) CostEstimate {
if e == nil {
return CostEstimate{}
}
var cost CostEstimate
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
switch e.Kind() {
case ast.LiteralKind:
cost = constCost
case *exprpb.Expr_IdentExpr:
case ast.IdentKind:
cost = c.costIdent(e)
case *exprpb.Expr_SelectExpr:
case ast.SelectKind:
cost = c.costSelect(e)
case *exprpb.Expr_CallExpr:
case ast.CallKind:
cost = c.costCall(e)
case *exprpb.Expr_ListExpr:
case ast.ListKind:
cost = c.costCreateList(e)
case *exprpb.Expr_StructExpr:
case ast.MapKind:
cost = c.costCreateMap(e)
case ast.StructKind:
cost = c.costCreateStruct(e)
case *exprpb.Expr_ComprehensionExpr:
case ast.ComprehensionKind:
cost = c.costComprehension(e)
default:
return CostEstimate{}
@@ -369,53 +367,51 @@ func (c *coster) cost(e *exprpb.Expr) CostEstimate {
return cost
}
func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {
identExpr := e.GetIdentExpr()
func (c *coster) costIdent(e ast.Expr) CostEstimate {
identName := e.AsIdent()
// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
switch c.checkedAST.TypeMap[iterRange].Kind() {
if iterRange, ok := c.iterRanges.peek(identName); ok {
switch c.checkedAST.GetType(iterRange).Kind() {
case types.ListKind:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case types.MapKind:
c.addPath(e, append(c.exprPath[iterRange], "@keys"))
}
} else {
c.addPath(e, []string{identExpr.GetName()})
c.addPath(e, []string{identName})
}
return selectAndIdentCost
}
func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
sel := e.GetSelectExpr()
func (c *coster) costSelect(e ast.Expr) CostEstimate {
sel := e.AsSelect()
var sum CostEstimate
if sel.GetTestOnly() {
if sel.IsTestOnly() {
// recurse, but do not add any cost
// this is equivalent to how evalTestOnly increments the runtime cost counter
// but does not add any additional cost for the qualifier, except here we do
// the reverse (ident adds cost)
sum = sum.Add(c.presenceTestCost)
sum = sum.Add(c.cost(sel.GetOperand()))
sum = sum.Add(c.cost(sel.Operand()))
return sum
}
sum = sum.Add(c.cost(sel.GetOperand()))
targetType := c.getType(sel.GetOperand())
sum = sum.Add(c.cost(sel.Operand()))
targetType := c.getType(sel.Operand())
switch targetType.Kind() {
case types.MapKind, types.StructKind, types.TypeParamKind:
sum = sum.Add(selectAndIdentCost)
}
// build and track the field path
c.addPath(e, append(c.getPath(sel.GetOperand()), sel.GetField()))
c.addPath(e, append(c.getPath(sel.Operand()), sel.FieldName()))
return sum
}
func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
call := e.GetCallExpr()
target := call.GetTarget()
args := call.GetArgs()
func (c *coster) costCall(e ast.Expr) CostEstimate {
call := e.AsCall()
args := call.Args()
var sum CostEstimate
@@ -426,22 +422,20 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
argTypes[i] = c.newAstNode(arg)
}
ref := c.checkedAST.ReferenceMap[e.GetId()]
if ref == nil || len(ref.OverloadIDs) == 0 {
overloadIDs := c.checkedAST.GetOverloadIDs(e.ID())
if len(overloadIDs) == 0 {
return CostEstimate{}
}
var targetType AstNode
if target != nil {
if call.Target != nil {
sum = sum.Add(c.cost(call.GetTarget()))
targetType = c.newAstNode(call.GetTarget())
}
if call.IsMemberFunction() {
sum = sum.Add(c.cost(call.Target()))
targetType = c.newAstNode(call.Target())
}
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range ref.OverloadIDs {
overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts)
for _, overload := range overloadIDs {
overloadCost := c.functionCost(call.FunctionName(), overload, &targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
if resultSize == nil {
@@ -464,64 +458,56 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
}
}
if resultSize != nil {
c.computedSizes[e.GetId()] = *resultSize
c.computedSizes[e.ID()] = *resultSize
}
return sum.Add(fnCost)
}
func (c *coster) costCreateList(e *exprpb.Expr) CostEstimate {
create := e.GetListExpr()
func (c *coster) costCreateList(e ast.Expr) CostEstimate {
create := e.AsList()
var sum CostEstimate
for _, e := range create.GetElements() {
for _, e := range create.Elements() {
sum = sum.Add(c.cost(e))
}
return sum.Add(createListBaseCost)
}
func (c *coster) costCreateStruct(e *exprpb.Expr) CostEstimate {
str := e.GetStructExpr()
if str.MessageName != "" {
return c.costCreateMessage(e)
}
return c.costCreateMap(e)
}
func (c *coster) costCreateMap(e *exprpb.Expr) CostEstimate {
mapVal := e.GetStructExpr()
func (c *coster) costCreateMap(e ast.Expr) CostEstimate {
mapVal := e.AsMap()
var sum CostEstimate
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
sum = sum.Add(c.cost(key))
sum = sum.Add(c.cost(ent.GetValue()))
for _, ent := range mapVal.Entries() {
entry := ent.AsMapEntry()
sum = sum.Add(c.cost(entry.Key()))
sum = sum.Add(c.cost(entry.Value()))
}
return sum.Add(createMapBaseCost)
}
func (c *coster) costCreateMessage(e *exprpb.Expr) CostEstimate {
msgVal := e.GetStructExpr()
func (c *coster) costCreateStruct(e ast.Expr) CostEstimate {
msgVal := e.AsStruct()
var sum CostEstimate
for _, ent := range msgVal.GetEntries() {
sum = sum.Add(c.cost(ent.GetValue()))
for _, ent := range msgVal.Fields() {
field := ent.AsStructField()
sum = sum.Add(c.cost(field.Value()))
}
return sum.Add(createMessageBaseCost)
}
func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate {
comp := e.GetComprehensionExpr()
func (c *coster) costComprehension(e ast.Expr) CostEstimate {
comp := e.AsComprehension()
var sum CostEstimate
sum = sum.Add(c.cost(comp.GetIterRange()))
sum = sum.Add(c.cost(comp.GetAccuInit()))
sum = sum.Add(c.cost(comp.IterRange()))
sum = sum.Add(c.cost(comp.AccuInit()))
// Track the iterRange of each IterVar for field path construction
c.iterRanges.push(comp.GetIterVar(), comp.GetIterRange())
loopCost := c.cost(comp.GetLoopCondition())
stepCost := c.cost(comp.GetLoopStep())
c.iterRanges.pop(comp.GetIterVar())
sum = sum.Add(c.cost(comp.Result))
rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange()))
c.iterRanges.push(comp.IterVar(), comp.IterRange())
loopCost := c.cost(comp.LoopCondition())
stepCost := c.cost(comp.LoopStep())
c.iterRanges.pop(comp.IterVar())
sum = sum.Add(c.cost(comp.Result()))
rangeCnt := c.sizeEstimate(c.newAstNode(comp.IterRange()))
c.computedSizes[e.GetId()] = rangeCnt
c.computedSizes[e.ID()] = rangeCnt
rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost))
sum = sum.Add(rangeCost)
@@ -674,26 +660,26 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())}
}
func (c *coster) getType(e *exprpb.Expr) *types.Type {
return c.checkedAST.TypeMap[e.GetId()]
func (c *coster) getType(e ast.Expr) *types.Type {
return c.checkedAST.GetType(e.ID())
}
func (c *coster) getPath(e *exprpb.Expr) []string {
return c.exprPath[e.GetId()]
func (c *coster) getPath(e ast.Expr) []string {
return c.exprPath[e.ID()]
}
func (c *coster) addPath(e *exprpb.Expr, path []string) {
c.exprPath[e.GetId()] = path
func (c *coster) addPath(e ast.Expr, path []string) {
c.exprPath[e.ID()] = path
}
func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
func (c *coster) newAstNode(e ast.Expr) *astNode {
path := c.getPath(e)
if len(path) > 0 && path[0] == parser.AccumulatorName {
// only provide paths to root vars; omit accumulator vars
path = nil
}
var derivedSize *SizeEstimate
if size, ok := c.computedSizes[e.GetId()]; ok {
if size, ok := c.computedSizes[e.ID()]; ok {
derivedSize = &size
}
return &astNode{

View File

@@ -67,7 +67,7 @@ func NewAbstractType(name string, paramTypes ...*exprpb.Type) *exprpb.Type {
// NewOptionalType constructs an abstract type indicating that the parameterized type
// may be contained within the object.
func NewOptionalType(paramType *exprpb.Type) *exprpb.Type {
return NewAbstractType("optional", paramType)
return NewAbstractType("optional_type", paramType)
}
// NewFunctionType creates a function invocation contract, typically only used

View File

@@ -146,6 +146,14 @@ func (e *Env) LookupIdent(name string) *decls.VariableDecl {
return decl
}
if i, found := e.provider.FindIdent(candidate); found {
if t, ok := i.(*types.Type); ok {
decl := decls.NewVariable(candidate, types.NewTypeTypeWithParam(t))
e.declarations.AddIdent(decl)
return decl
}
}
// Next try to import this as an enum value by splitting the name in a type prefix and
// the enum inside.
if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType {

View File

@@ -15,13 +15,9 @@
package checker
import (
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// typeErrors is a specialization of Errors.
@@ -34,9 +30,9 @@ func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string,
name, FormatCELType(field), FormatCELType(value))
}
func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *types.Type) {
func (e *typeErrors) incompatibleType(id int64, l common.Location, ex ast.Expr, prev, next *types.Type) {
e.errs.ReportErrorAtID(id, l,
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next)
}
func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*types.Type, isInstance bool) {
@@ -49,7 +45,7 @@ func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *type
FormatCELType(t))
}
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) {
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field ast.Expr) {
e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field)
}
@@ -61,9 +57,9 @@ func (e *typeErrors) notAMessageType(id int64, l common.Location, typeName strin
e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", typeName)
}
func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *ast.ReferenceInfo) {
func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex ast.Expr, prev, next *ast.ReferenceInfo) {
e.errs.ReportErrorAtID(id, l,
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next)
}
func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *types.Type) {
@@ -87,6 +83,6 @@ func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typ
e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName)
}
func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex))
func (e *typeErrors) unexpectedASTType(id int64, l common.Location, kind, typeName string) {
e.errs.ReportErrorAtID(id, l, "unexpected %s type: %v", kind, typeName)
}

View File

@@ -17,40 +17,40 @@ package checker
import (
"sort"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/debug"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
type semanticAdorner struct {
checks *exprpb.CheckedExpr
checked *ast.AST
}
var _ debug.Adorner = &semanticAdorner{}
func (a *semanticAdorner) GetMetadata(elem any) string {
result := ""
e, isExpr := elem.(*exprpb.Expr)
e, isExpr := elem.(ast.Expr)
if !isExpr {
return result
}
t := a.checks.TypeMap[e.GetId()]
t := a.checked.TypeMap()[e.ID()]
if t != nil {
result += "~"
result += FormatCheckedType(t)
result += FormatCELType(t)
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr,
*exprpb.Expr_CallExpr,
*exprpb.Expr_StructExpr,
*exprpb.Expr_SelectExpr:
if ref, found := a.checks.ReferenceMap[e.GetId()]; found {
if len(ref.GetOverloadId()) == 0 {
switch e.Kind() {
case ast.IdentKind,
ast.CallKind,
ast.ListKind,
ast.StructKind,
ast.SelectKind:
if ref, found := a.checked.ReferenceMap()[e.ID()]; found {
if len(ref.OverloadIDs) == 0 {
result += "^" + ref.Name
} else {
sort.Strings(ref.GetOverloadId())
for i, overload := range ref.GetOverloadId() {
sort.Strings(ref.OverloadIDs)
for i, overload := range ref.OverloadIDs {
if i == 0 {
result += "^"
} else {
@@ -68,7 +68,7 @@ func (a *semanticAdorner) GetMetadata(elem any) string {
// Print returns a string representation of the Expr message,
// annotated with types from the CheckedExpr. The Expr must
// be a sub-expression embedded in the CheckedExpr.
func Print(e *exprpb.Expr, checks *exprpb.CheckedExpr) string {
a := &semanticAdorner{checks: checks}
func Print(e ast.Expr, checked *ast.AST) string {
a := &semanticAdorner{checked: checked}
return debug.ToAdornedDebugString(e, a)
}

View File

@@ -41,7 +41,7 @@ func isError(t *types.Type) bool {
func isOptional(t *types.Type) bool {
if t.Kind() == types.OpaqueKind {
return t.TypeName() == "optional"
return t.TypeName() == "optional_type"
}
return false
}
@@ -137,7 +137,11 @@ func internalIsAssignable(m *mapping, t1, t2 *types.Type) bool {
case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind,
types.AnyKind, types.DurationKind, types.TimestampKind,
types.StructKind:
return t1.IsAssignableType(t2)
// Test whether t2 is assignable from t1. The order of this check won't usually matter;
// however, there may be cases where type capabilities are expanded beyond what is supported
// in the current common/types package. For example, an interface designation for a group of
// Struct types.
return t2.IsAssignableType(t1)
case types.TypeKind:
return kind2 == types.TypeKind
case types.OpaqueKind, types.ListKind, types.MapKind:
@@ -256,7 +260,7 @@ func notReferencedIn(m *mapping, t, withinType *types.Type) bool {
return true
}
return notReferencedIn(m, t, wtSub)
case types.OpaqueKind, types.ListKind, types.MapKind:
case types.OpaqueKind, types.ListKind, types.MapKind, types.TypeKind:
for _, pt := range withinType.Parameters() {
if !notReferencedIn(m, t, pt) {
return false
@@ -288,7 +292,8 @@ func substitute(m *mapping, t *types.Type, typeParamToDyn bool) *types.Type {
substitute(m, t.Parameters()[1], typeParamToDyn))
case types.TypeKind:
if len(t.Parameters()) > 0 {
return types.NewTypeTypeWithParam(substitute(m, t.Parameters()[0], typeParamToDyn))
tParam := t.Parameters()[0]
return types.NewTypeTypeWithParam(substitute(m, tParam, typeParamToDyn))
}
return t
default:

View File

@@ -1,12 +1,7 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = [
"//cel:__subpackages__",
"//checker:__subpackages__",
"//common:__subpackages__",
"//interpreter:__subpackages__",
],
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
@@ -14,10 +9,14 @@ go_library(
name = "go_default_library",
srcs = [
"ast.go",
"conversion.go",
"expr.go",
"factory.go",
"navigable.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
"//common:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
@@ -29,7 +28,9 @@ go_test(
name = "go_default_test",
srcs = [
"ast_test.go",
"conversion_test.go",
"expr_test.go",
"navigable_test.go",
],
embed = [
":go_default_library",
@@ -48,5 +49,6 @@ go_test(
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
],
)
)

View File

@@ -16,74 +16,361 @@
package ast
import (
"fmt"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CheckedAST contains a protobuf expression and source info along with CEL-native type and reference information.
type CheckedAST struct {
Expr *exprpb.Expr
SourceInfo *exprpb.SourceInfo
TypeMap map[int64]*types.Type
ReferenceMap map[int64]*ReferenceInfo
// AST contains a protobuf expression and source info along with CEL-native type and reference information.
type AST struct {
expr Expr
sourceInfo *SourceInfo
typeMap map[int64]*types.Type
refMap map[int64]*ReferenceInfo
}
// CheckedASTToCheckedExpr converts a CheckedAST to a CheckedExpr protobouf.
func CheckedASTToCheckedExpr(ast *CheckedAST) (*exprpb.CheckedExpr, error) {
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap))
for id, ref := range ast.ReferenceMap {
r, err := ReferenceInfoToReferenceExpr(ref)
if err != nil {
return nil, err
}
refMap[id] = r
// Expr returns the root ast.Expr value in the AST.
func (a *AST) Expr() Expr {
if a == nil {
return nilExpr
}
typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap))
for id, typ := range ast.TypeMap {
t, err := types.TypeToExprType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
return &exprpb.CheckedExpr{
Expr: ast.Expr,
SourceInfo: ast.SourceInfo,
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
return a.expr
}
// CheckedExprToCheckedAST converts a CheckedExpr protobuf to a CheckedAST instance.
func CheckedExprToCheckedAST(checked *exprpb.CheckedExpr) (*CheckedAST, error) {
refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap()))
for id, ref := range checked.GetReferenceMap() {
r, err := ReferenceExprToReferenceInfo(ref)
if err != nil {
return nil, err
}
refMap[id] = r
// SourceInfo returns the source metadata associated with the parse / type-check passes.
func (a *AST) SourceInfo() *SourceInfo {
if a == nil {
return nil
}
typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap()))
for id, typ := range checked.GetTypeMap() {
t, err := types.ExprTypeToType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
return a.sourceInfo
}
// GetType returns the type for the expression at the given id, if one exists, else types.DynType.
func (a *AST) GetType(id int64) *types.Type {
if t, found := a.TypeMap()[id]; found {
return t
}
return &CheckedAST{
Expr: checked.GetExpr(),
SourceInfo: checked.GetSourceInfo(),
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
return types.DynType
}
// SetType sets the type of the expression node at the given id.
func (a *AST) SetType(id int64, t *types.Type) {
if a == nil {
return
}
a.typeMap[id] = t
}
// TypeMap returns the map of expression ids to type-checked types.
//
// If the AST is not type-checked, the map will be empty.
func (a *AST) TypeMap() map[int64]*types.Type {
if a == nil {
return map[int64]*types.Type{}
}
return a.typeMap
}
// GetOverloadIDs returns the set of overload function names for a given expression id.
//
// If the expression id is not a function call, or the AST is not type-checked, the result will be empty.
func (a *AST) GetOverloadIDs(id int64) []string {
if ref, found := a.ReferenceMap()[id]; found {
return ref.OverloadIDs
}
return []string{}
}
// ReferenceMap returns the map of expression id to identifier, constant, and function references.
func (a *AST) ReferenceMap() map[int64]*ReferenceInfo {
if a == nil {
return map[int64]*ReferenceInfo{}
}
return a.refMap
}
// SetReference adds a reference to the checked AST type map.
func (a *AST) SetReference(id int64, r *ReferenceInfo) {
if a == nil {
return
}
a.refMap[id] = r
}
// IsChecked returns whether the AST is type-checked.
func (a *AST) IsChecked() bool {
return a != nil && len(a.TypeMap()) > 0
}
// NewAST creates a base AST instance with an ast.Expr and ast.SourceInfo value.
func NewAST(e Expr, sourceInfo *SourceInfo) *AST {
if e == nil {
e = nilExpr
}
return &AST{
expr: e,
sourceInfo: sourceInfo,
typeMap: make(map[int64]*types.Type),
refMap: make(map[int64]*ReferenceInfo),
}
}
// NewCheckedAST wraps an parsed AST and augments it with type and reference metadata.
func NewCheckedAST(parsed *AST, typeMap map[int64]*types.Type, refMap map[int64]*ReferenceInfo) *AST {
return &AST{
expr: parsed.Expr(),
sourceInfo: parsed.SourceInfo(),
typeMap: typeMap,
refMap: refMap,
}
}
// Copy creates a deep copy of the Expr and SourceInfo values in the input AST.
//
// Copies of the Expr value are generated using an internal default ExprFactory.
func Copy(a *AST) *AST {
if a == nil {
return nil
}
e := defaultFactory.CopyExpr(a.expr)
if !a.IsChecked() {
return NewAST(e, CopySourceInfo(a.SourceInfo()))
}
typesCopy := make(map[int64]*types.Type, len(a.typeMap))
for id, t := range a.typeMap {
typesCopy[id] = t
}
refsCopy := make(map[int64]*ReferenceInfo, len(a.refMap))
for id, r := range a.refMap {
refsCopy[id] = r
}
return NewCheckedAST(NewAST(e, CopySourceInfo(a.SourceInfo())), typesCopy, refsCopy)
}
// MaxID returns the upper-bound, non-inclusive, of ids present within the AST's Expr value.
func MaxID(a *AST) int64 {
visitor := &maxIDVisitor{maxID: 1}
PostOrderVisit(a.Expr(), visitor)
for id, call := range a.SourceInfo().MacroCalls() {
PostOrderVisit(call, visitor)
if id > visitor.maxID {
visitor.maxID = id + 1
}
}
return visitor.maxID + 1
}
// NewSourceInfo creates a simple SourceInfo object from an input common.Source value.
func NewSourceInfo(src common.Source) *SourceInfo {
var lineOffsets []int32
var desc string
baseLine := int32(0)
baseCol := int32(0)
if src != nil {
desc = src.Description()
lineOffsets = src.LineOffsets()
// Determine whether the source metadata should be computed relative
// to a base line and column value. This can be determined by requesting
// the location for offset 0 from the source object.
if loc, found := src.OffsetLocation(0); found {
baseLine = int32(loc.Line()) - 1
baseCol = int32(loc.Column())
}
}
return &SourceInfo{
desc: desc,
lines: lineOffsets,
baseLine: baseLine,
baseCol: baseCol,
offsetRanges: make(map[int64]OffsetRange),
macroCalls: make(map[int64]Expr),
}
}
// CopySourceInfo creates a deep copy of the MacroCalls within the input SourceInfo.
//
// Copies of macro Expr values are generated using an internal default ExprFactory.
func CopySourceInfo(info *SourceInfo) *SourceInfo {
if info == nil {
return nil
}
rangesCopy := make(map[int64]OffsetRange, len(info.offsetRanges))
for id, off := range info.offsetRanges {
rangesCopy[id] = off
}
callsCopy := make(map[int64]Expr, len(info.macroCalls))
for id, call := range info.macroCalls {
callsCopy[id] = defaultFactory.CopyExpr(call)
}
return &SourceInfo{
syntax: info.syntax,
desc: info.desc,
lines: info.lines,
baseLine: info.baseLine,
baseCol: info.baseCol,
offsetRanges: rangesCopy,
macroCalls: callsCopy,
}
}
// SourceInfo records basic information about the expression as a textual input and
// as a parsed expression value.
type SourceInfo struct {
syntax string
desc string
lines []int32
baseLine int32
baseCol int32
offsetRanges map[int64]OffsetRange
macroCalls map[int64]Expr
}
// SyntaxVersion returns the syntax version associated with the text expression.
func (s *SourceInfo) SyntaxVersion() string {
if s == nil {
return ""
}
return s.syntax
}
// Description provides information about where the expression came from.
func (s *SourceInfo) Description() string {
if s == nil {
return ""
}
return s.desc
}
// LineOffsets returns a list of the 0-based character offsets in the input text where newlines appear.
func (s *SourceInfo) LineOffsets() []int32 {
if s == nil {
return []int32{}
}
return s.lines
}
// MacroCalls returns a map of expression id to ast.Expr value where the id represents the expression
// node where the macro was inserted into the AST, and the ast.Expr value represents the original call
// signature which was replaced.
func (s *SourceInfo) MacroCalls() map[int64]Expr {
if s == nil {
return map[int64]Expr{}
}
return s.macroCalls
}
// GetMacroCall returns the original ast.Expr value for the given expression if it was generated via
// a macro replacement.
//
// Note, parsing options must be enabled to track macro calls before this method will return a value.
func (s *SourceInfo) GetMacroCall(id int64) (Expr, bool) {
e, found := s.MacroCalls()[id]
return e, found
}
// SetMacroCall records a macro call at a specific location.
func (s *SourceInfo) SetMacroCall(id int64, e Expr) {
if s != nil {
s.macroCalls[id] = e
}
}
// ClearMacroCall removes the macro call at the given expression id.
func (s *SourceInfo) ClearMacroCall(id int64) {
if s != nil {
delete(s.macroCalls, id)
}
}
// OffsetRanges returns a map of expression id to OffsetRange values where the range indicates either:
// the start and end position in the input stream where the expression occurs, or the start position
// only. If the range only captures start position, the stop position of the range will be equal to
// the start.
func (s *SourceInfo) OffsetRanges() map[int64]OffsetRange {
if s == nil {
return map[int64]OffsetRange{}
}
return s.offsetRanges
}
// GetOffsetRange retrieves an OffsetRange for the given expression id if one exists.
func (s *SourceInfo) GetOffsetRange(id int64) (OffsetRange, bool) {
if s == nil {
return OffsetRange{}, false
}
o, found := s.offsetRanges[id]
return o, found
}
// SetOffsetRange sets the OffsetRange for the given expression id.
func (s *SourceInfo) SetOffsetRange(id int64, o OffsetRange) {
if s == nil {
return
}
s.offsetRanges[id] = o
}
// GetStartLocation calculates the human-readable 1-based line and 0-based column of the first character
// of the expression node at the id.
func (s *SourceInfo) GetStartLocation(id int64) common.Location {
if o, found := s.GetOffsetRange(id); found {
line := 1
col := int(o.Start)
for _, lineOffset := range s.LineOffsets() {
if lineOffset < o.Start {
line++
col = int(o.Start - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
}
// GetStopLocation calculates the human-readable 1-based line and 0-based column of the last character for
// the expression node at the given id.
//
// If the SourceInfo was generated from a serialized protobuf representation, the stop location will
// be identical to the start location for the expression.
func (s *SourceInfo) GetStopLocation(id int64) common.Location {
if o, found := s.GetOffsetRange(id); found {
line := 1
col := int(o.Stop)
for _, lineOffset := range s.LineOffsets() {
if lineOffset < o.Stop {
line++
col = int(o.Stop - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
}
// ComputeOffset calculates the 0-based character offset from a 1-based line and 0-based column.
func (s *SourceInfo) ComputeOffset(line, col int32) int32 {
if s != nil {
line = s.baseLine + line
col = s.baseCol + col
}
if line == 1 {
return col
}
if line < 1 || line > int32(len(s.LineOffsets())) {
return -1
}
offset := s.LineOffsets()[line-2]
return offset + col
}
// OffsetRange captures the start and stop positions of a section of text in the input expression.
type OffsetRange struct {
Start int32
Stop int32
}
// ReferenceInfo contains a CEL native representation of an identifier reference which may refer to
@@ -149,78 +436,21 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool {
return true
}
// ReferenceInfoToReferenceExpr converts a ReferenceInfo instance to a protobuf Reference suitable for serialization.
func ReferenceInfoToReferenceExpr(info *ReferenceInfo) (*exprpb.Reference, error) {
c, err := ValToConstant(info.Value)
if err != nil {
return nil, err
}
return &exprpb.Reference{
Name: info.Name,
OverloadId: info.OverloadIDs,
Value: c,
}, nil
type maxIDVisitor struct {
maxID int64
*baseVisitor
}
// ReferenceExprToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance.
func ReferenceExprToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) {
v, err := ConstantToVal(ref.GetValue())
if err != nil {
return nil, err
// VisitExpr updates the max identifier if the incoming expression id is greater than previously observed.
func (v *maxIDVisitor) VisitExpr(e Expr) {
if v.maxID < e.ID() {
v.maxID = e.ID()
}
return &ReferenceInfo{
Name: ref.GetName(),
OverloadIDs: ref.GetOverloadId(),
Value: v,
}, nil
}
// ValToConstant converts a CEL-native ref.Val to a protobuf Constant.
//
// Only simple scalar types are supported by this method.
func ValToConstant(v ref.Val) (*exprpb.Constant, error) {
if v == nil {
return nil, nil
// VisitEntryExpr updates the max identifier if the incoming entry id is greater than previously observed.
func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) {
if v.maxID < e.ID() {
v.maxID = e.ID()
}
switch v.Type() {
case types.BoolType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil
case types.BytesType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil
case types.DoubleType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil
case types.IntType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil
case types.NullType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil
case types.StringType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil
case types.UintType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", v.Type())
}
// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val.
func ConstantToVal(c *exprpb.Constant) (ref.Val, error) {
if c == nil {
return nil, nil
}
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return types.Bool(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
return types.Bytes(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
return types.Double(c.GetDoubleValue()), nil
case *exprpb.Constant_Int64Value:
return types.Int(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
return types.NullValue, nil
case *exprpb.Constant_StringValue:
return types.String(c.GetStringValue()), nil
case *exprpb.Constant_Uint64Value:
return types.Uint(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind())
}

View File

@@ -0,0 +1,632 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"fmt"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// ToProto converts an AST to a CheckedExpr protobouf.
func ToProto(ast *AST) (*exprpb.CheckedExpr, error) {
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap()))
for id, ref := range ast.ReferenceMap() {
r, err := ReferenceInfoToProto(ref)
if err != nil {
return nil, err
}
refMap[id] = r
}
typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap()))
for id, typ := range ast.TypeMap() {
t, err := types.TypeToExprType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
e, err := ExprToProto(ast.Expr())
if err != nil {
return nil, err
}
info, err := SourceInfoToProto(ast.SourceInfo())
if err != nil {
return nil, err
}
return &exprpb.CheckedExpr{
Expr: e,
SourceInfo: info,
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
}
// ToAST converts a CheckedExpr protobuf to an AST instance.
func ToAST(checked *exprpb.CheckedExpr) (*AST, error) {
refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap()))
for id, ref := range checked.GetReferenceMap() {
r, err := ProtoToReferenceInfo(ref)
if err != nil {
return nil, err
}
refMap[id] = r
}
typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap()))
for id, typ := range checked.GetTypeMap() {
t, err := types.ExprTypeToType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
info, err := ProtoToSourceInfo(checked.GetSourceInfo())
if err != nil {
return nil, err
}
root, err := ProtoToExpr(checked.GetExpr())
if err != nil {
return nil, err
}
ast := NewCheckedAST(NewAST(root, info), typeMap, refMap)
return ast, nil
}
// ProtoToExpr converts a protobuf Expr value to an ast.Expr value.
func ProtoToExpr(e *exprpb.Expr) (Expr, error) {
factory := NewExprFactory()
return exprInternal(factory, e)
}
// ProtoToEntryExpr converts a protobuf struct/map entry to an ast.EntryExpr
func ProtoToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
factory := NewExprFactory()
switch e.GetKeyKind().(type) {
case *exprpb.Expr_CreateStruct_Entry_FieldKey:
return exprStructField(factory, e.GetId(), e)
case *exprpb.Expr_CreateStruct_Entry_MapKey:
return exprMapEntry(factory, e.GetId(), e)
}
return nil, fmt.Errorf("unsupported expr entry kind: %v", e)
}
func exprInternal(factory ExprFactory, e *exprpb.Expr) (Expr, error) {
id := e.GetId()
switch e.GetExprKind().(type) {
case *exprpb.Expr_CallExpr:
return exprCall(factory, id, e.GetCallExpr())
case *exprpb.Expr_ComprehensionExpr:
return exprComprehension(factory, id, e.GetComprehensionExpr())
case *exprpb.Expr_ConstExpr:
return exprLiteral(factory, id, e.GetConstExpr())
case *exprpb.Expr_IdentExpr:
return exprIdent(factory, id, e.GetIdentExpr())
case *exprpb.Expr_ListExpr:
return exprList(factory, id, e.GetListExpr())
case *exprpb.Expr_SelectExpr:
return exprSelect(factory, id, e.GetSelectExpr())
case *exprpb.Expr_StructExpr:
s := e.GetStructExpr()
if s.GetMessageName() != "" {
return exprStruct(factory, id, s)
}
return exprMap(factory, id, s)
}
return factory.NewUnspecifiedExpr(id), nil
}
func exprCall(factory ExprFactory, id int64, call *exprpb.Expr_Call) (Expr, error) {
var err error
args := make([]Expr, len(call.GetArgs()))
for i, a := range call.GetArgs() {
args[i], err = exprInternal(factory, a)
if err != nil {
return nil, err
}
}
if call.GetTarget() == nil {
return factory.NewCall(id, call.GetFunction(), args...), nil
}
target, err := exprInternal(factory, call.GetTarget())
if err != nil {
return nil, err
}
return factory.NewMemberCall(id, call.GetFunction(), target, args...), nil
}
func exprComprehension(factory ExprFactory, id int64, comp *exprpb.Expr_Comprehension) (Expr, error) {
iterRange, err := exprInternal(factory, comp.GetIterRange())
if err != nil {
return nil, err
}
accuInit, err := exprInternal(factory, comp.GetAccuInit())
if err != nil {
return nil, err
}
loopCond, err := exprInternal(factory, comp.GetLoopCondition())
if err != nil {
return nil, err
}
loopStep, err := exprInternal(factory, comp.GetLoopStep())
if err != nil {
return nil, err
}
result, err := exprInternal(factory, comp.GetResult())
if err != nil {
return nil, err
}
return factory.NewComprehension(id,
iterRange,
comp.GetIterVar(),
comp.GetAccuVar(),
accuInit,
loopCond,
loopStep,
result), nil
}
func exprLiteral(factory ExprFactory, id int64, c *exprpb.Constant) (Expr, error) {
val, err := ConstantToVal(c)
if err != nil {
return nil, err
}
return factory.NewLiteral(id, val), nil
}
func exprIdent(factory ExprFactory, id int64, i *exprpb.Expr_Ident) (Expr, error) {
return factory.NewIdent(id, i.GetName()), nil
}
func exprList(factory ExprFactory, id int64, l *exprpb.Expr_CreateList) (Expr, error) {
elems := make([]Expr, len(l.GetElements()))
for i, e := range l.GetElements() {
elem, err := exprInternal(factory, e)
if err != nil {
return nil, err
}
elems[i] = elem
}
return factory.NewList(id, elems, l.GetOptionalIndices()), nil
}
func exprMap(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) {
entries := make([]EntryExpr, len(s.GetEntries()))
var err error
for i, entry := range s.GetEntries() {
entries[i], err = exprMapEntry(factory, entry.GetId(), entry)
if err != nil {
return nil, err
}
}
return factory.NewMap(id, entries), nil
}
func exprMapEntry(factory ExprFactory, id int64, e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
k, err := exprInternal(factory, e.GetMapKey())
if err != nil {
return nil, err
}
v, err := exprInternal(factory, e.GetValue())
if err != nil {
return nil, err
}
return factory.NewMapEntry(id, k, v, e.GetOptionalEntry()), nil
}
func exprSelect(factory ExprFactory, id int64, s *exprpb.Expr_Select) (Expr, error) {
op, err := exprInternal(factory, s.GetOperand())
if err != nil {
return nil, err
}
if s.GetTestOnly() {
return factory.NewPresenceTest(id, op, s.GetField()), nil
}
return factory.NewSelect(id, op, s.GetField()), nil
}
func exprStruct(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) {
fields := make([]EntryExpr, len(s.GetEntries()))
var err error
for i, field := range s.GetEntries() {
fields[i], err = exprStructField(factory, field.GetId(), field)
if err != nil {
return nil, err
}
}
return factory.NewStruct(id, s.GetMessageName(), fields), nil
}
func exprStructField(factory ExprFactory, id int64, f *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
v, err := exprInternal(factory, f.GetValue())
if err != nil {
return nil, err
}
return factory.NewStructField(id, f.GetFieldKey(), v, f.GetOptionalEntry()), nil
}
// ExprToProto serializes an ast.Expr value to a protobuf Expr representation.
func ExprToProto(e Expr) (*exprpb.Expr, error) {
if e == nil {
return &exprpb.Expr{}, nil
}
switch e.Kind() {
case CallKind:
return protoCall(e.ID(), e.AsCall())
case ComprehensionKind:
return protoComprehension(e.ID(), e.AsComprehension())
case IdentKind:
return protoIdent(e.ID(), e.AsIdent())
case ListKind:
return protoList(e.ID(), e.AsList())
case LiteralKind:
return protoLiteral(e.ID(), e.AsLiteral())
case MapKind:
return protoMap(e.ID(), e.AsMap())
case SelectKind:
return protoSelect(e.ID(), e.AsSelect())
case StructKind:
return protoStruct(e.ID(), e.AsStruct())
case UnspecifiedExprKind:
// Handle the case where a macro reference may be getting translated.
// A nested macro 'pointer' is a non-zero expression id with no kind set.
if e.ID() != 0 {
return &exprpb.Expr{Id: e.ID()}, nil
}
return &exprpb.Expr{}, nil
}
return nil, fmt.Errorf("unsupported expr kind: %v", e)
}
// EntryExprToProto converts an ast.EntryExpr to a protobuf CreateStruct entry
func EntryExprToProto(e EntryExpr) (*exprpb.Expr_CreateStruct_Entry, error) {
switch e.Kind() {
case MapEntryKind:
return protoMapEntry(e.ID(), e.AsMapEntry())
case StructFieldKind:
return protoStructField(e.ID(), e.AsStructField())
case UnspecifiedEntryExprKind:
return &exprpb.Expr_CreateStruct_Entry{}, nil
}
return nil, fmt.Errorf("unsupported expr entry kind: %v", e)
}
func protoCall(id int64, call CallExpr) (*exprpb.Expr, error) {
var err error
var target *exprpb.Expr
if call.IsMemberFunction() {
target, err = ExprToProto(call.Target())
if err != nil {
return nil, err
}
}
callArgs := call.Args()
args := make([]*exprpb.Expr, len(callArgs))
for i, a := range callArgs {
args[i], err = ExprToProto(a)
if err != nil {
return nil, err
}
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: call.FunctionName(),
Target: target,
Args: args,
},
},
}, nil
}
func protoComprehension(id int64, comp ComprehensionExpr) (*exprpb.Expr, error) {
iterRange, err := ExprToProto(comp.IterRange())
if err != nil {
return nil, err
}
accuInit, err := ExprToProto(comp.AccuInit())
if err != nil {
return nil, err
}
loopCond, err := ExprToProto(comp.LoopCondition())
if err != nil {
return nil, err
}
loopStep, err := ExprToProto(comp.LoopStep())
if err != nil {
return nil, err
}
result, err := ExprToProto(comp.Result())
if err != nil {
return nil, err
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: comp.IterVar(),
IterRange: iterRange,
AccuVar: comp.AccuVar(),
AccuInit: accuInit,
LoopCondition: loopCond,
LoopStep: loopStep,
Result: result,
},
},
}, nil
}
func protoIdent(id int64, name string) (*exprpb.Expr, error) {
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: name,
},
},
}, nil
}
func protoList(id int64, list ListExpr) (*exprpb.Expr, error) {
var err error
elems := make([]*exprpb.Expr, list.Size())
for i, e := range list.Elements() {
elems[i], err = ExprToProto(e)
if err != nil {
return nil, err
}
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: elems,
OptionalIndices: list.OptionalIndices(),
},
},
}, nil
}
func protoLiteral(id int64, val ref.Val) (*exprpb.Expr, error) {
c, err := ValToConstant(val)
if err != nil {
return nil, err
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: c,
},
}, nil
}
func protoMap(id int64, m MapExpr) (*exprpb.Expr, error) {
entries := make([]*exprpb.Expr_CreateStruct_Entry, len(m.Entries()))
var err error
for i, e := range m.Entries() {
entries[i], err = EntryExprToProto(e)
if err != nil {
return nil, err
}
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
Entries: entries,
},
},
}, nil
}
func protoMapEntry(id int64, e MapEntry) (*exprpb.Expr_CreateStruct_Entry, error) {
k, err := ExprToProto(e.Key())
if err != nil {
return nil, err
}
v, err := ExprToProto(e.Value())
if err != nil {
return nil, err
}
return &exprpb.Expr_CreateStruct_Entry{
Id: id,
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: k,
},
Value: v,
OptionalEntry: e.IsOptional(),
}, nil
}
func protoSelect(id int64, s SelectExpr) (*exprpb.Expr, error) {
op, err := ExprToProto(s.Operand())
if err != nil {
return nil, err
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{
Operand: op,
Field: s.FieldName(),
TestOnly: s.IsTestOnly(),
},
},
}, nil
}
func protoStruct(id int64, s StructExpr) (*exprpb.Expr, error) {
entries := make([]*exprpb.Expr_CreateStruct_Entry, len(s.Fields()))
var err error
for i, e := range s.Fields() {
entries[i], err = EntryExprToProto(e)
if err != nil {
return nil, err
}
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
MessageName: s.TypeName(),
Entries: entries,
},
},
}, nil
}
func protoStructField(id int64, f StructField) (*exprpb.Expr_CreateStruct_Entry, error) {
v, err := ExprToProto(f.Value())
if err != nil {
return nil, err
}
return &exprpb.Expr_CreateStruct_Entry{
Id: id,
KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{
FieldKey: f.Name(),
},
Value: v,
OptionalEntry: f.IsOptional(),
}, nil
}
// SourceInfoToProto serializes an ast.SourceInfo value to a protobuf SourceInfo object.
func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) {
if info == nil {
return &exprpb.SourceInfo{}, nil
}
sourceInfo := &exprpb.SourceInfo{
SyntaxVersion: info.SyntaxVersion(),
Location: info.Description(),
LineOffsets: info.LineOffsets(),
Positions: make(map[int64]int32, len(info.OffsetRanges())),
MacroCalls: make(map[int64]*exprpb.Expr, len(info.MacroCalls())),
}
for id, offset := range info.OffsetRanges() {
sourceInfo.Positions[id] = offset.Start
}
for id, e := range info.MacroCalls() {
call, err := ExprToProto(e)
if err != nil {
return nil, err
}
sourceInfo.MacroCalls[id] = call
}
return sourceInfo, nil
}
// ProtoToSourceInfo deserializes the protobuf into a native SourceInfo value.
func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) {
sourceInfo := &SourceInfo{
syntax: info.GetSyntaxVersion(),
desc: info.GetLocation(),
lines: info.GetLineOffsets(),
offsetRanges: make(map[int64]OffsetRange, len(info.GetPositions())),
macroCalls: make(map[int64]Expr, len(info.GetMacroCalls())),
}
for id, offset := range info.GetPositions() {
sourceInfo.SetOffsetRange(id, OffsetRange{Start: offset, Stop: offset})
}
for id, e := range info.GetMacroCalls() {
call, err := ProtoToExpr(e)
if err != nil {
return nil, err
}
sourceInfo.SetMacroCall(id, call)
}
return sourceInfo, nil
}
// ReferenceInfoToProto converts a ReferenceInfo instance to a protobuf Reference suitable for serialization.
func ReferenceInfoToProto(info *ReferenceInfo) (*exprpb.Reference, error) {
c, err := ValToConstant(info.Value)
if err != nil {
return nil, err
}
return &exprpb.Reference{
Name: info.Name,
OverloadId: info.OverloadIDs,
Value: c,
}, nil
}
// ProtoToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance.
func ProtoToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) {
v, err := ConstantToVal(ref.GetValue())
if err != nil {
return nil, err
}
return &ReferenceInfo{
Name: ref.GetName(),
OverloadIDs: ref.GetOverloadId(),
Value: v,
}, nil
}
// ValToConstant converts a CEL-native ref.Val to a protobuf Constant.
//
// Only simple scalar types are supported by this method.
func ValToConstant(v ref.Val) (*exprpb.Constant, error) {
if v == nil {
return nil, nil
}
switch v.Type() {
case types.BoolType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil
case types.BytesType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil
case types.DoubleType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil
case types.IntType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil
case types.NullType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil
case types.StringType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil
case types.UintType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", v.Type())
}
// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val.
func ConstantToVal(c *exprpb.Constant) (ref.Val, error) {
if c == nil {
return nil, nil
}
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return types.Bool(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
return types.Bytes(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
return types.Double(c.GetDoubleValue()), nil
case *exprpb.Constant_Int64Value:
return types.Int(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
return types.NullValue, nil
case *exprpb.Constant_StringValue:
return types.String(c.GetStringValue()), nil
case *exprpb.Constant_Uint64Value:
return types.Uint(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind())
}

File diff suppressed because it is too large Load Diff

303
vendor/github.com/google/cel-go/common/ast/factory.go generated vendored Normal file
View File

@@ -0,0 +1,303 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import "github.com/google/cel-go/common/types/ref"
// ExprFactory interfaces defines a set of methods necessary for building native expression values.
type ExprFactory interface {
// CopyExpr creates a deep copy of the input Expr value.
CopyExpr(Expr) Expr
// CopyEntryExpr creates a deep copy of the input EntryExpr value.
CopyEntryExpr(EntryExpr) EntryExpr
// NewCall creates an Expr value representing a global function call.
NewCall(id int64, function string, args ...Expr) Expr
// NewComprehension creates an Expr value representing a comprehension over a value range.
NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr
// NewMemberCall creates an Expr value representing a member function call.
NewMemberCall(id int64, function string, receiver Expr, args ...Expr) Expr
// NewIdent creates an Expr value representing an identifier.
NewIdent(id int64, name string) Expr
// NewAccuIdent creates an Expr value representing an accumulator identifier within a
//comprehension.
NewAccuIdent(id int64) Expr
// NewLiteral creates an Expr value representing a literal value, such as a string or integer.
NewLiteral(id int64, value ref.Val) Expr
// NewList creates an Expr value representing a list literal expression with optional indices.
//
// Optional indicies will typically be empty unless the CEL optional types are enabled.
NewList(id int64, elems []Expr, optIndices []int32) Expr
// NewMap creates an Expr value representing a map literal expression
NewMap(id int64, entries []EntryExpr) Expr
// NewMapEntry creates a MapEntry with a given key, value, and a flag indicating whether
// the key is optionally set.
NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr
// NewPresenceTest creates an Expr representing a field presence test on an operand expression.
NewPresenceTest(id int64, operand Expr, field string) Expr
// NewSelect creates an Expr representing a field selection on an operand expression.
NewSelect(id int64, operand Expr, field string) Expr
// NewStruct creates an Expr value representing a struct literal with a given type name and a
// set of field initializers.
NewStruct(id int64, typeName string, fields []EntryExpr) Expr
// NewStructField creates a StructField with a given field name, value, and a flag indicating
// whether the field is optionally set.
NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr
// NewUnspecifiedExpr creates an empty expression node.
NewUnspecifiedExpr(id int64) Expr
isExprFactory()
}
type baseExprFactory struct{}
// NewExprFactory creates an ExprFactory instance.
func NewExprFactory() ExprFactory {
return &baseExprFactory{}
}
func (fac *baseExprFactory) NewCall(id int64, function string, args ...Expr) Expr {
if len(args) == 0 {
args = []Expr{}
}
return fac.newExpr(
id,
&baseCallExpr{
function: function,
target: nilExpr,
args: args,
isMember: false,
})
}
func (fac *baseExprFactory) NewMemberCall(id int64, function string, target Expr, args ...Expr) Expr {
if len(args) == 0 {
args = []Expr{}
}
return fac.newExpr(
id,
&baseCallExpr{
function: function,
target: target,
args: args,
isMember: true,
})
}
func (fac *baseExprFactory) NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
return fac.newExpr(
id,
&baseComprehensionExpr{
iterRange: iterRange,
iterVar: iterVar,
accuVar: accuVar,
accuInit: accuInit,
loopCond: loopCond,
loopStep: loopStep,
result: result,
})
}
func (fac *baseExprFactory) NewIdent(id int64, name string) Expr {
return fac.newExpr(id, baseIdentExpr(name))
}
func (fac *baseExprFactory) NewAccuIdent(id int64) Expr {
return fac.NewIdent(id, "__result__")
}
func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr {
return fac.newExpr(id, &baseLiteral{Val: value})
}
func (fac *baseExprFactory) NewList(id int64, elems []Expr, optIndices []int32) Expr {
optIndexMap := make(map[int32]struct{}, len(optIndices))
for _, idx := range optIndices {
optIndexMap[idx] = struct{}{}
}
return fac.newExpr(id,
&baseListExpr{
elements: elems,
optIndices: optIndices,
optIndexMap: optIndexMap,
})
}
func (fac *baseExprFactory) NewMap(id int64, entries []EntryExpr) Expr {
return fac.newExpr(id, &baseMapExpr{entries: entries})
}
func (fac *baseExprFactory) NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr {
return fac.newEntryExpr(
id,
&baseMapEntry{
key: key,
value: value,
isOptional: isOptional,
})
}
func (fac *baseExprFactory) NewPresenceTest(id int64, operand Expr, field string) Expr {
return fac.newExpr(
id,
&baseSelectExpr{
operand: operand,
field: field,
testOnly: true,
})
}
func (fac *baseExprFactory) NewSelect(id int64, operand Expr, field string) Expr {
return fac.newExpr(
id,
&baseSelectExpr{
operand: operand,
field: field,
})
}
func (fac *baseExprFactory) NewStruct(id int64, typeName string, fields []EntryExpr) Expr {
return fac.newExpr(
id,
&baseStructExpr{
typeName: typeName,
fields: fields,
})
}
func (fac *baseExprFactory) NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr {
return fac.newEntryExpr(
id,
&baseStructField{
field: field,
value: value,
isOptional: isOptional,
})
}
func (fac *baseExprFactory) NewUnspecifiedExpr(id int64) Expr {
return fac.newExpr(id, nil)
}
func (fac *baseExprFactory) CopyExpr(e Expr) Expr {
// unwrap navigable expressions to avoid unnecessary allocations during copying.
if nav, ok := e.(*navigableExprImpl); ok {
e = nav.Expr
}
switch e.Kind() {
case CallKind:
c := e.AsCall()
argsCopy := make([]Expr, len(c.Args()))
for i, arg := range c.Args() {
argsCopy[i] = fac.CopyExpr(arg)
}
if !c.IsMemberFunction() {
return fac.NewCall(e.ID(), c.FunctionName(), argsCopy...)
}
return fac.NewMemberCall(e.ID(), c.FunctionName(), fac.CopyExpr(c.Target()), argsCopy...)
case ComprehensionKind:
compre := e.AsComprehension()
return fac.NewComprehension(e.ID(),
fac.CopyExpr(compre.IterRange()),
compre.IterVar(),
compre.AccuVar(),
fac.CopyExpr(compre.AccuInit()),
fac.CopyExpr(compre.LoopCondition()),
fac.CopyExpr(compre.LoopStep()),
fac.CopyExpr(compre.Result()))
case IdentKind:
return fac.NewIdent(e.ID(), e.AsIdent())
case ListKind:
l := e.AsList()
elemsCopy := make([]Expr, l.Size())
for i, elem := range l.Elements() {
elemsCopy[i] = fac.CopyExpr(elem)
}
return fac.NewList(e.ID(), elemsCopy, l.OptionalIndices())
case LiteralKind:
return fac.NewLiteral(e.ID(), e.AsLiteral())
case MapKind:
m := e.AsMap()
entriesCopy := make([]EntryExpr, m.Size())
for i, entry := range m.Entries() {
entriesCopy[i] = fac.CopyEntryExpr(entry)
}
return fac.NewMap(e.ID(), entriesCopy)
case SelectKind:
s := e.AsSelect()
if s.IsTestOnly() {
return fac.NewPresenceTest(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName())
}
return fac.NewSelect(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName())
case StructKind:
s := e.AsStruct()
fieldsCopy := make([]EntryExpr, len(s.Fields()))
for i, field := range s.Fields() {
fieldsCopy[i] = fac.CopyEntryExpr(field)
}
return fac.NewStruct(e.ID(), s.TypeName(), fieldsCopy)
default:
return fac.NewUnspecifiedExpr(e.ID())
}
}
func (fac *baseExprFactory) CopyEntryExpr(e EntryExpr) EntryExpr {
switch e.Kind() {
case MapEntryKind:
entry := e.AsMapEntry()
return fac.NewMapEntry(e.ID(),
fac.CopyExpr(entry.Key()), fac.CopyExpr(entry.Value()), entry.IsOptional())
case StructFieldKind:
field := e.AsStructField()
return fac.NewStructField(e.ID(),
field.Name(), fac.CopyExpr(field.Value()), field.IsOptional())
default:
return fac.newEntryExpr(e.ID(), nil)
}
}
func (*baseExprFactory) isExprFactory() {}
func (fac *baseExprFactory) newExpr(id int64, e exprKindCase) Expr {
return &expr{
id: id,
exprKindCase: e,
}
}
func (fac *baseExprFactory) newEntryExpr(id int64, e entryExprKindCase) EntryExpr {
return &entryExpr{
id: id,
entryExprKindCase: e,
}
}
var (
defaultFactory = &baseExprFactory{}
)

652
vendor/github.com/google/cel-go/common/ast/navigable.go generated vendored Normal file
View File

@@ -0,0 +1,652 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// NavigableExpr represents the base navigable expression value with methods to inspect the
// parent and child expressions.
type NavigableExpr interface {
Expr
// Type of the expression.
//
// If the expression is type-checked, the type check metadata is returned. If the expression
// has not been type-checked, the types.DynType value is returned.
Type() *types.Type
// Parent returns the parent expression node, if one exists.
Parent() (NavigableExpr, bool)
// Children returns a list of child expression nodes.
Children() []NavigableExpr
// Depth indicates the depth in the expression tree.
//
// The root expression has depth 0.
Depth() int
}
// NavigateAST converts an AST to a NavigableExpr
func NavigateAST(ast *AST) NavigableExpr {
return NavigateExpr(ast, ast.Expr())
}
// NavigateExpr creates a NavigableExpr whose type information is backed by the input AST.
//
// If the expression is already a NavigableExpr, the parent and depth information will be
// propagated on the new NavigableExpr value; otherwise, the expr value will be treated
// as though it is the root of the expression graph with a depth of 0.
func NavigateExpr(ast *AST, expr Expr) NavigableExpr {
depth := 0
var parent NavigableExpr = nil
if nav, ok := expr.(NavigableExpr); ok {
depth = nav.Depth()
parent, _ = nav.Parent()
}
return newNavigableExpr(ast, parent, expr, depth)
}
// ExprMatcher takes a NavigableExpr in and indicates whether the value is a match.
//
// This function type should be use with the `Match` and `MatchList` calls.
type ExprMatcher func(NavigableExpr) bool
// ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr
// is comprised of all constant values, such as a simple literal or even list and map literal.
func ConstantValueMatcher() ExprMatcher {
return matchIsConstantValue
}
// KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches
// the specified `kind`.
func KindMatcher(kind ExprKind) ExprMatcher {
return func(e NavigableExpr) bool {
return e.Kind() == kind
}
}
// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose
// function name is equal to `funcName`.
func FunctionMatcher(funcName string) ExprMatcher {
return func(e NavigableExpr) bool {
if e.Kind() != CallKind {
return false
}
return e.AsCall().FunctionName() == funcName
}
}
// AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list.
//
// Such a result would work well with subsequent MatchList calls.
func AllMatcher() ExprMatcher {
return func(NavigableExpr) bool {
return true
}
}
// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values
// matching the input criteria in post-order (bottom up).
func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr {
matches := []NavigableExpr{}
navVisitor := &baseVisitor{
visitExpr: func(e Expr) {
nav := e.(NavigableExpr)
if matcher(nav) {
matches = append(matches, nav)
}
},
}
visit(expr, navVisitor, postOrder, 0, 0)
return matches
}
// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a
// subset of NavigableExpr values which match.
func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr {
matches := []NavigableExpr{}
navVisitor := &baseVisitor{
visitExpr: func(e Expr) {
nav := e.(NavigableExpr)
if matcher(nav) {
matches = append(matches, nav)
}
},
}
for _, expr := range exprs {
visit(expr, navVisitor, postOrder, 0, 1)
}
return matches
}
// Visitor defines an object for visiting Expr and EntryExpr nodes within an expression graph.
type Visitor interface {
// VisitExpr visits the input expression.
VisitExpr(Expr)
// VisitEntryExpr visits the input entry expression, i.e. a struct field or map entry.
VisitEntryExpr(EntryExpr)
}
type baseVisitor struct {
visitExpr func(Expr)
visitEntryExpr func(EntryExpr)
}
// VisitExpr visits the Expr if the internal expr visitor has been configured.
func (v *baseVisitor) VisitExpr(e Expr) {
if v.visitExpr != nil {
v.visitExpr(e)
}
}
// VisitEntryExpr visits the entry if the internal expr entry visitor has been configured.
func (v *baseVisitor) VisitEntryExpr(e EntryExpr) {
if v.visitEntryExpr != nil {
v.visitEntryExpr(e)
}
}
// NewExprVisitor creates a visitor which only visits expression nodes.
func NewExprVisitor(v func(Expr)) Visitor {
return &baseVisitor{
visitExpr: v,
visitEntryExpr: nil,
}
}
// PostOrderVisit walks the expression graph and calls the visitor in post-order (bottom-up).
func PostOrderVisit(expr Expr, visitor Visitor) {
visit(expr, visitor, postOrder, 0, 0)
}
// PreOrderVisit walks the expression graph and calls the visitor in pre-order (top-down).
func PreOrderVisit(expr Expr, visitor Visitor) {
visit(expr, visitor, preOrder, 0, 0)
}
type visitOrder int
const (
preOrder = iota + 1
postOrder
)
// TODO: consider exposing a way to configure a limit for the max visit depth.
// It's possible that we could want to configure this on the NewExprVisitor()
// and through MatchDescendents() / MaxID().
func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) {
if maxDepth > 0 && depth == maxDepth {
return
}
if order == preOrder {
visitor.VisitExpr(expr)
}
switch expr.Kind() {
case CallKind:
c := expr.AsCall()
if c.IsMemberFunction() {
visit(c.Target(), visitor, order, depth+1, maxDepth)
}
for _, arg := range c.Args() {
visit(arg, visitor, order, depth+1, maxDepth)
}
case ComprehensionKind:
c := expr.AsComprehension()
visit(c.IterRange(), visitor, order, depth+1, maxDepth)
visit(c.AccuInit(), visitor, order, depth+1, maxDepth)
visit(c.LoopCondition(), visitor, order, depth+1, maxDepth)
visit(c.LoopStep(), visitor, order, depth+1, maxDepth)
visit(c.Result(), visitor, order, depth+1, maxDepth)
case ListKind:
l := expr.AsList()
for _, elem := range l.Elements() {
visit(elem, visitor, order, depth+1, maxDepth)
}
case MapKind:
m := expr.AsMap()
for _, e := range m.Entries() {
if order == preOrder {
visitor.VisitEntryExpr(e)
}
entry := e.AsMapEntry()
visit(entry.Key(), visitor, order, depth+1, maxDepth)
visit(entry.Value(), visitor, order, depth+1, maxDepth)
if order == postOrder {
visitor.VisitEntryExpr(e)
}
}
case SelectKind:
visit(expr.AsSelect().Operand(), visitor, order, depth+1, maxDepth)
case StructKind:
s := expr.AsStruct()
for _, f := range s.Fields() {
visitor.VisitEntryExpr(f)
visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth)
}
}
if order == postOrder {
visitor.VisitExpr(expr)
}
}
func matchIsConstantValue(e NavigableExpr) bool {
if e.Kind() == LiteralKind {
return true
}
if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind {
for _, child := range e.Children() {
if !matchIsConstantValue(child) {
return false
}
}
return true
}
return false
}
func newNavigableExpr(ast *AST, parent NavigableExpr, expr Expr, depth int) NavigableExpr {
// Reduce navigable expression nesting by unwrapping the embedded Expr value.
if nav, ok := expr.(*navigableExprImpl); ok {
expr = nav.Expr
}
nav := &navigableExprImpl{
Expr: expr,
depth: depth,
ast: ast,
parent: parent,
createChildren: getChildFactory(expr),
}
return nav
}
type navigableExprImpl struct {
Expr
depth int
ast *AST
parent NavigableExpr
createChildren childFactory
}
func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) {
if nav.parent != nil {
return nav.parent, true
}
return nil, false
}
func (nav *navigableExprImpl) ID() int64 {
return nav.Expr.ID()
}
func (nav *navigableExprImpl) Kind() ExprKind {
return nav.Expr.Kind()
}
func (nav *navigableExprImpl) Type() *types.Type {
return nav.ast.GetType(nav.ID())
}
func (nav *navigableExprImpl) Children() []NavigableExpr {
return nav.createChildren(nav)
}
func (nav *navigableExprImpl) Depth() int {
return nav.depth
}
func (nav *navigableExprImpl) AsCall() CallExpr {
return navigableCallImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsComprehension() ComprehensionExpr {
return navigableComprehensionImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsIdent() string {
return nav.Expr.AsIdent()
}
func (nav *navigableExprImpl) AsList() ListExpr {
return navigableListImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsLiteral() ref.Val {
return nav.Expr.AsLiteral()
}
func (nav *navigableExprImpl) AsMap() MapExpr {
return navigableMapImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsSelect() SelectExpr {
return navigableSelectImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsStruct() StructExpr {
return navigableStructImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) createChild(e Expr) NavigableExpr {
return newNavigableExpr(nav.ast, nav, e, nav.depth+1)
}
func (nav *navigableExprImpl) isExpr() {}
type navigableCallImpl struct {
*navigableExprImpl
}
func (call navigableCallImpl) FunctionName() string {
return call.Expr.AsCall().FunctionName()
}
func (call navigableCallImpl) IsMemberFunction() bool {
return call.Expr.AsCall().IsMemberFunction()
}
func (call navigableCallImpl) Target() Expr {
t := call.Expr.AsCall().Target()
if t != nil {
return call.createChild(t)
}
return nil
}
func (call navigableCallImpl) Args() []Expr {
args := call.Expr.AsCall().Args()
navArgs := make([]Expr, len(args))
for i, a := range args {
navArgs[i] = call.createChild(a)
}
return navArgs
}
type navigableComprehensionImpl struct {
*navigableExprImpl
}
func (comp navigableComprehensionImpl) IterRange() Expr {
return comp.createChild(comp.Expr.AsComprehension().IterRange())
}
func (comp navigableComprehensionImpl) IterVar() string {
return comp.Expr.AsComprehension().IterVar()
}
func (comp navigableComprehensionImpl) AccuVar() string {
return comp.Expr.AsComprehension().AccuVar()
}
func (comp navigableComprehensionImpl) AccuInit() Expr {
return comp.createChild(comp.Expr.AsComprehension().AccuInit())
}
func (comp navigableComprehensionImpl) LoopCondition() Expr {
return comp.createChild(comp.Expr.AsComprehension().LoopCondition())
}
func (comp navigableComprehensionImpl) LoopStep() Expr {
return comp.createChild(comp.Expr.AsComprehension().LoopStep())
}
func (comp navigableComprehensionImpl) Result() Expr {
return comp.createChild(comp.Expr.AsComprehension().Result())
}
type navigableListImpl struct {
*navigableExprImpl
}
func (l navigableListImpl) Elements() []Expr {
pbElems := l.Expr.AsList().Elements()
elems := make([]Expr, len(pbElems))
for i := 0; i < len(pbElems); i++ {
elems[i] = l.createChild(pbElems[i])
}
return elems
}
func (l navigableListImpl) IsOptional(index int32) bool {
return l.Expr.AsList().IsOptional(index)
}
func (l navigableListImpl) OptionalIndices() []int32 {
return l.Expr.AsList().OptionalIndices()
}
func (l navigableListImpl) Size() int {
return l.Expr.AsList().Size()
}
type navigableMapImpl struct {
*navigableExprImpl
}
func (m navigableMapImpl) Entries() []EntryExpr {
mapExpr := m.Expr.AsMap()
entries := make([]EntryExpr, len(mapExpr.Entries()))
for i, e := range mapExpr.Entries() {
entry := e.AsMapEntry()
entries[i] = &entryExpr{
id: e.ID(),
entryExprKindCase: navigableEntryImpl{
key: m.createChild(entry.Key()),
val: m.createChild(entry.Value()),
isOpt: entry.IsOptional(),
},
}
}
return entries
}
func (m navigableMapImpl) Size() int {
return m.Expr.AsMap().Size()
}
type navigableEntryImpl struct {
key NavigableExpr
val NavigableExpr
isOpt bool
}
func (e navigableEntryImpl) Kind() EntryExprKind {
return MapEntryKind
}
func (e navigableEntryImpl) Key() Expr {
return e.key
}
func (e navigableEntryImpl) Value() Expr {
return e.val
}
func (e navigableEntryImpl) IsOptional() bool {
return e.isOpt
}
func (e navigableEntryImpl) renumberIDs(IDGenerator) {}
func (e navigableEntryImpl) isEntryExpr() {}
type navigableSelectImpl struct {
*navigableExprImpl
}
func (sel navigableSelectImpl) FieldName() string {
return sel.Expr.AsSelect().FieldName()
}
func (sel navigableSelectImpl) IsTestOnly() bool {
return sel.Expr.AsSelect().IsTestOnly()
}
func (sel navigableSelectImpl) Operand() Expr {
return sel.createChild(sel.Expr.AsSelect().Operand())
}
type navigableStructImpl struct {
*navigableExprImpl
}
func (s navigableStructImpl) TypeName() string {
return s.Expr.AsStruct().TypeName()
}
func (s navigableStructImpl) Fields() []EntryExpr {
fieldInits := s.Expr.AsStruct().Fields()
fields := make([]EntryExpr, len(fieldInits))
for i, f := range fieldInits {
field := f.AsStructField()
fields[i] = &entryExpr{
id: f.ID(),
entryExprKindCase: navigableFieldImpl{
name: field.Name(),
val: s.createChild(field.Value()),
isOpt: field.IsOptional(),
},
}
}
return fields
}
type navigableFieldImpl struct {
name string
val NavigableExpr
isOpt bool
}
func (f navigableFieldImpl) Kind() EntryExprKind {
return StructFieldKind
}
func (f navigableFieldImpl) Name() string {
return f.name
}
func (f navigableFieldImpl) Value() Expr {
return f.val
}
func (f navigableFieldImpl) IsOptional() bool {
return f.isOpt
}
func (f navigableFieldImpl) renumberIDs(IDGenerator) {}
func (f navigableFieldImpl) isEntryExpr() {}
func getChildFactory(expr Expr) childFactory {
if expr == nil {
return noopFactory
}
switch expr.Kind() {
case LiteralKind:
return noopFactory
case IdentKind:
return noopFactory
case SelectKind:
return selectFactory
case CallKind:
return callArgFactory
case ListKind:
return listElemFactory
case MapKind:
return mapEntryFactory
case StructKind:
return structEntryFactory
case ComprehensionKind:
return comprehensionFactory
default:
return noopFactory
}
}
type childFactory func(*navigableExprImpl) []NavigableExpr
func noopFactory(*navigableExprImpl) []NavigableExpr {
return nil
}
func selectFactory(nav *navigableExprImpl) []NavigableExpr {
return []NavigableExpr{nav.createChild(nav.AsSelect().Operand())}
}
func callArgFactory(nav *navigableExprImpl) []NavigableExpr {
call := nav.Expr.AsCall()
argCount := len(call.Args())
if call.IsMemberFunction() {
argCount++
}
navExprs := make([]NavigableExpr, argCount)
i := 0
if call.IsMemberFunction() {
navExprs[i] = nav.createChild(call.Target())
i++
}
for _, arg := range call.Args() {
navExprs[i] = nav.createChild(arg)
i++
}
return navExprs
}
func listElemFactory(nav *navigableExprImpl) []NavigableExpr {
l := nav.Expr.AsList()
navExprs := make([]NavigableExpr, len(l.Elements()))
for i, e := range l.Elements() {
navExprs[i] = nav.createChild(e)
}
return navExprs
}
func structEntryFactory(nav *navigableExprImpl) []NavigableExpr {
s := nav.Expr.AsStruct()
entries := make([]NavigableExpr, len(s.Fields()))
for i, e := range s.Fields() {
f := e.AsStructField()
entries[i] = nav.createChild(f.Value())
}
return entries
}
func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr {
m := nav.Expr.AsMap()
entries := make([]NavigableExpr, len(m.Entries())*2)
j := 0
for _, e := range m.Entries() {
mapEntry := e.AsMapEntry()
entries[j] = nav.createChild(mapEntry.Key())
entries[j+1] = nav.createChild(mapEntry.Value())
j += 2
}
return entries
}
func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr {
compre := nav.Expr.AsComprehension()
return []NavigableExpr{
nav.createChild(compre.IterRange()),
nav.createChild(compre.AccuInit()),
nav.createChild(compre.LoopCondition()),
nav.createChild(compre.LoopStep()),
nav.createChild(compre.Result()),
}
}

View File

@@ -12,7 +12,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/containers",
deps = [
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
],
)
@@ -26,6 +26,6 @@ go_test(
":go_default_library",
],
deps = [
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
],
)

View File

@@ -20,7 +20,7 @@ import (
"fmt"
"strings"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
)
var (
@@ -297,19 +297,19 @@ func Name(name string) ContainerOption {
// ToQualifiedName converts an expression AST into a qualified name if possible, with a boolean
// 'found' value that indicates if the conversion is successful.
func ToQualifiedName(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
id := e.GetIdentExpr()
return id.GetName(), true
case *exprpb.Expr_SelectExpr:
sel := e.GetSelectExpr()
func ToQualifiedName(e ast.Expr) (string, bool) {
switch e.Kind() {
case ast.IdentKind:
id := e.AsIdent()
return id, true
case ast.SelectKind:
sel := e.AsSelect()
// Test only expressions are not valid as qualified names.
if sel.GetTestOnly() {
if sel.IsTestOnly() {
return "", false
}
if qual, found := ToQualifiedName(sel.GetOperand()); found {
return qual + "." + sel.GetField(), true
if qual, found := ToQualifiedName(sel.Operand()); found {
return qual + "." + sel.FieldName(), true
}
}
return "", false

View File

@@ -13,6 +13,8 @@ go_library(
importpath = "github.com/google/cel-go/common/debug",
deps = [
"//common:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
],
)

View File

@@ -22,7 +22,9 @@ import (
"strconv"
"strings"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// Adorner returns debug metadata that will be tacked on to the string
@@ -38,7 +40,7 @@ type Writer interface {
// Buffer pushes an expression into an internal queue of expressions to
// write to a string.
Buffer(e *exprpb.Expr)
Buffer(e ast.Expr)
}
type emptyDebugAdorner struct {
@@ -51,12 +53,12 @@ func (a *emptyDebugAdorner) GetMetadata(e any) string {
}
// ToDebugString gives the unadorned string representation of the Expr.
func ToDebugString(e *exprpb.Expr) string {
func ToDebugString(e ast.Expr) string {
return ToAdornedDebugString(e, emptyAdorner)
}
// ToAdornedDebugString gives the adorned string representation of the Expr.
func ToAdornedDebugString(e *exprpb.Expr, adorner Adorner) string {
func ToAdornedDebugString(e ast.Expr, adorner Adorner) string {
w := newDebugWriter(adorner)
w.Buffer(e)
return w.String()
@@ -78,49 +80,51 @@ func newDebugWriter(a Adorner) *debugWriter {
}
}
func (w *debugWriter) Buffer(e *exprpb.Expr) {
func (w *debugWriter) Buffer(e ast.Expr) {
if e == nil {
return
}
switch e.ExprKind.(type) {
case *exprpb.Expr_ConstExpr:
w.append(formatLiteral(e.GetConstExpr()))
case *exprpb.Expr_IdentExpr:
w.append(e.GetIdentExpr().Name)
case *exprpb.Expr_SelectExpr:
w.appendSelect(e.GetSelectExpr())
case *exprpb.Expr_CallExpr:
w.appendCall(e.GetCallExpr())
case *exprpb.Expr_ListExpr:
w.appendList(e.GetListExpr())
case *exprpb.Expr_StructExpr:
w.appendStruct(e.GetStructExpr())
case *exprpb.Expr_ComprehensionExpr:
w.appendComprehension(e.GetComprehensionExpr())
switch e.Kind() {
case ast.LiteralKind:
w.append(formatLiteral(e.AsLiteral()))
case ast.IdentKind:
w.append(e.AsIdent())
case ast.SelectKind:
w.appendSelect(e.AsSelect())
case ast.CallKind:
w.appendCall(e.AsCall())
case ast.ListKind:
w.appendList(e.AsList())
case ast.MapKind:
w.appendMap(e.AsMap())
case ast.StructKind:
w.appendStruct(e.AsStruct())
case ast.ComprehensionKind:
w.appendComprehension(e.AsComprehension())
}
w.adorn(e)
}
func (w *debugWriter) appendSelect(sel *exprpb.Expr_Select) {
w.Buffer(sel.GetOperand())
func (w *debugWriter) appendSelect(sel ast.SelectExpr) {
w.Buffer(sel.Operand())
w.append(".")
w.append(sel.GetField())
if sel.TestOnly {
w.append(sel.FieldName())
if sel.IsTestOnly() {
w.append("~test-only~")
}
}
func (w *debugWriter) appendCall(call *exprpb.Expr_Call) {
if call.Target != nil {
w.Buffer(call.GetTarget())
func (w *debugWriter) appendCall(call ast.CallExpr) {
if call.IsMemberFunction() {
w.Buffer(call.Target())
w.append(".")
}
w.append(call.GetFunction())
w.append(call.FunctionName())
w.append("(")
if len(call.GetArgs()) > 0 {
if len(call.Args()) > 0 {
w.addIndent()
w.appendLine()
for i, arg := range call.GetArgs() {
for i, arg := range call.Args() {
if i > 0 {
w.append(",")
w.appendLine()
@@ -133,12 +137,12 @@ func (w *debugWriter) appendCall(call *exprpb.Expr_Call) {
w.append(")")
}
func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) {
func (w *debugWriter) appendList(list ast.ListExpr) {
w.append("[")
if len(list.GetElements()) > 0 {
if len(list.Elements()) > 0 {
w.appendLine()
w.addIndent()
for i, elem := range list.GetElements() {
for i, elem := range list.Elements() {
if i > 0 {
w.append(",")
w.appendLine()
@@ -151,32 +155,25 @@ func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) {
w.append("]")
}
func (w *debugWriter) appendStruct(obj *exprpb.Expr_CreateStruct) {
if obj.MessageName != "" {
w.appendObject(obj)
} else {
w.appendMap(obj)
}
}
func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append(obj.GetMessageName())
func (w *debugWriter) appendStruct(obj ast.StructExpr) {
w.append(obj.TypeName())
w.append("{")
if len(obj.GetEntries()) > 0 {
if len(obj.Fields()) > 0 {
w.appendLine()
w.addIndent()
for i, entry := range obj.GetEntries() {
for i, f := range obj.Fields() {
field := f.AsStructField()
if i > 0 {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
if field.IsOptional() {
w.append("?")
}
w.append(entry.GetFieldKey())
w.append(field.Name())
w.append(":")
w.Buffer(entry.GetValue())
w.adorn(entry)
w.Buffer(field.Value())
w.adorn(f)
}
w.removeIndent()
w.appendLine()
@@ -184,23 +181,24 @@ func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append("}")
}
func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
func (w *debugWriter) appendMap(m ast.MapExpr) {
w.append("{")
if len(obj.GetEntries()) > 0 {
if m.Size() > 0 {
w.appendLine()
w.addIndent()
for i, entry := range obj.GetEntries() {
for i, e := range m.Entries() {
entry := e.AsMapEntry()
if i > 0 {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
if entry.IsOptional() {
w.append("?")
}
w.Buffer(entry.GetMapKey())
w.Buffer(entry.Key())
w.append(":")
w.Buffer(entry.GetValue())
w.adorn(entry)
w.Buffer(entry.Value())
w.adorn(e)
}
w.removeIndent()
w.appendLine()
@@ -208,62 +206,62 @@ func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
w.append("}")
}
func (w *debugWriter) appendComprehension(comprehension *exprpb.Expr_Comprehension) {
func (w *debugWriter) appendComprehension(comprehension ast.ComprehensionExpr) {
w.append("__comprehension__(")
w.addIndent()
w.appendLine()
w.append("// Variable")
w.appendLine()
w.append(comprehension.GetIterVar())
w.append(comprehension.IterVar())
w.append(",")
w.appendLine()
w.append("// Target")
w.appendLine()
w.Buffer(comprehension.GetIterRange())
w.Buffer(comprehension.IterRange())
w.append(",")
w.appendLine()
w.append("// Accumulator")
w.appendLine()
w.append(comprehension.GetAccuVar())
w.append(comprehension.AccuVar())
w.append(",")
w.appendLine()
w.append("// Init")
w.appendLine()
w.Buffer(comprehension.GetAccuInit())
w.Buffer(comprehension.AccuInit())
w.append(",")
w.appendLine()
w.append("// LoopCondition")
w.appendLine()
w.Buffer(comprehension.GetLoopCondition())
w.Buffer(comprehension.LoopCondition())
w.append(",")
w.appendLine()
w.append("// LoopStep")
w.appendLine()
w.Buffer(comprehension.GetLoopStep())
w.Buffer(comprehension.LoopStep())
w.append(",")
w.appendLine()
w.append("// Result")
w.appendLine()
w.Buffer(comprehension.GetResult())
w.Buffer(comprehension.Result())
w.append(")")
w.removeIndent()
}
func formatLiteral(c *exprpb.Constant) string {
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return fmt.Sprintf("%t", c.GetBoolValue())
case *exprpb.Constant_BytesValue:
return fmt.Sprintf("b\"%s\"", string(c.GetBytesValue()))
case *exprpb.Constant_DoubleValue:
return fmt.Sprintf("%v", c.GetDoubleValue())
case *exprpb.Constant_Int64Value:
return fmt.Sprintf("%d", c.GetInt64Value())
case *exprpb.Constant_StringValue:
return strconv.Quote(c.GetStringValue())
case *exprpb.Constant_Uint64Value:
return fmt.Sprintf("%du", c.GetUint64Value())
case *exprpb.Constant_NullValue:
func formatLiteral(c ref.Val) string {
switch v := c.(type) {
case types.Bool:
return fmt.Sprintf("%t", v)
case types.Bytes:
return fmt.Sprintf("b\"%s\"", string(v))
case types.Double:
return fmt.Sprintf("%v", float64(v))
case types.Int:
return fmt.Sprintf("%d", int64(v))
case types.String:
return strconv.Quote(string(v))
case types.Uint:
return fmt.Sprintf("%du", uint64(v))
case types.Null:
return "null"
default:
panic("Unknown constant type")

View File

@@ -64,7 +64,7 @@ func (e *Errors) GetErrors() []*Error {
// Append creates a new Errors object with the current and input errors.
func (e *Errors) Append(errs []*Error) *Errors {
return &Errors{
errors: append(e.errors, errs...),
errors: append(e.errors[:], errs...),
source: e.source,
numErrors: e.numErrors + len(errs),
maxErrorsToReport: e.maxErrorsToReport,

View File

@@ -31,6 +31,7 @@ type Error interface {
// Err type which extends the built-in go error and implements ref.Val.
type Err struct {
error
id int64
}
var (
@@ -58,7 +59,24 @@ var (
// NewErr creates a new Err described by the format string and args.
// TODO: Audit the use of this function and standardize the error messages and codes.
func NewErr(format string, args ...any) ref.Val {
return &Err{fmt.Errorf(format, args...)}
return &Err{error: fmt.Errorf(format, args...)}
}
// NewErrWithNodeID creates a new Err described by the format string and args.
// TODO: Audit the use of this function and standardize the error messages and codes.
func NewErrWithNodeID(id int64, format string, args ...any) ref.Val {
return &Err{error: fmt.Errorf(format, args...), id: id}
}
// LabelErrNode returns val unaltered it is not an Err or if the error has a non-zero
// AST node ID already present. Otherwise the id is added to the error for
// recovery with the Err.NodeID method.
func LabelErrNode(id int64, val ref.Val) ref.Val {
if err, ok := val.(*Err); ok && err.id == 0 {
err.id = id
return err
}
return val
}
// NoSuchOverloadErr returns a new types.Err instance with a no such overload message.
@@ -124,6 +142,11 @@ func (e *Err) Value() any {
return e.error
}
// NodeID returns the AST node ID of the expression that returned the error.
func (e *Err) NodeID() int64 {
return e.id
}
// Is implements errors.Is.
func (e *Err) Is(target error) bool {
return e.error.Error() == target.Error()

View File

@@ -90,6 +90,18 @@ func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int8:
v, err := int64ToInt8Checked(int64(i))
if err != nil {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int16:
v, err := int64ToInt16Checked(int64(i))
if err != nil {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int64:
return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
case reflect.Ptr:

View File

@@ -190,7 +190,13 @@ func (l *baseList) ConvertToNative(typeDesc reflect.Type) (any, error) {
// Allow the element ConvertToNative() function to determine whether conversion is possible.
otherElemType := typeDesc.Elem()
elemCount := l.size
nativeList := reflect.MakeSlice(typeDesc, elemCount, elemCount)
var nativeList reflect.Value
if typeDesc.Kind() == reflect.Array {
nativeList = reflect.New(reflect.ArrayOf(elemCount, typeDesc)).Elem().Index(0)
} else {
nativeList = reflect.MakeSlice(typeDesc, elemCount, elemCount)
}
for i := 0; i < elemCount; i++ {
elem := l.NativeToValue(l.get(i))
nativeElemVal, err := elem.ConvertToNative(otherElemType)

View File

@@ -24,7 +24,7 @@ import (
var (
// OptionalType indicates the runtime type of an optional value.
OptionalType = NewOpaqueType("optional")
OptionalType = NewOpaqueType("optional_type")
// OptionalNone is a sentinel value which is used to indicate an empty optional value.
OptionalNone = &Optional{}

Some files were not shown because too many files have changed in this diff Show More