pkg/proxy/nftables: refactor packet tracer address matching

Use bool instead of not-equal-operator as string in
tracer.addressMatches of helpers_test.go

Signed-off-by: Daman Arora <aroradaman@gmail.com>
This commit is contained in:
Daman Arora
2024-01-21 12:53:55 +05:30
parent d23483dd7c
commit 80ca91c90c

View File

@@ -174,7 +174,7 @@ func newNFTablesTracer(t *testing.T, nft *knftables.Fake, nodeIPs []string) *nft
} }
} }
func (tracer *nftablesTracer) addressMatches(ipStr, not, ruleAddress string) bool { func (tracer *nftablesTracer) addressMatches(ipStr string, wantMatch bool, ruleAddress string) bool {
ip := netutils.ParseIPSloppy(ipStr) ip := netutils.ParseIPSloppy(ipStr)
if ip == nil { if ip == nil {
tracer.t.Fatalf("Bad IP in test case: %s", ipStr) tracer.t.Fatalf("Bad IP in test case: %s", ipStr)
@@ -195,18 +195,14 @@ func (tracer *nftablesTracer) addressMatches(ipStr, not, ruleAddress string) boo
match = ip.Equal(ip2) match = ip.Equal(ip2)
} }
if not == "!= " { return match == wantMatch
return !match
} else {
return match
}
} }
func (tracer *nftablesTracer) noneAddressesMatch(ipStr, ruleAddress string) bool { func (tracer *nftablesTracer) noneAddressesMatch(ipStr, ruleAddress string) bool {
ruleAddress = strings.ReplaceAll(ruleAddress, " ", "") ruleAddress = strings.ReplaceAll(ruleAddress, " ", "")
addresses := strings.Split(ruleAddress, ",") addresses := strings.Split(ruleAddress, ",")
for _, address := range addresses { for _, address := range addresses {
if tracer.addressMatches(ipStr, "", address) { if tracer.addressMatches(ipStr, true, address) {
return false return false
} }
} }
@@ -240,7 +236,7 @@ func (tracer *nftablesTracer) matchDest(elements []*knftables.Element, destIP, p
// found. // found.
func (tracer *nftablesTracer) matchDestAndSource(elements []*knftables.Element, destIP, protocol, destPort, sourceIP string) *knftables.Element { func (tracer *nftablesTracer) matchDestAndSource(elements []*knftables.Element, destIP, protocol, destPort, sourceIP string) *knftables.Element {
for _, element := range elements { for _, element := range elements {
if element.Key[0] == destIP && element.Key[1] == protocol && element.Key[2] == destPort && tracer.addressMatches(sourceIP, "", element.Key[3]) { if element.Key[0] == destIP && element.Key[1] == protocol && element.Key[2] == destPort && tracer.addressMatches(sourceIP, true, element.Key[3]) {
return element return element
} }
} }
@@ -416,8 +412,8 @@ func (tracer *nftablesTracer) runChain(chname, sourceIP, protocol, destIP, destP
// Tests whether destIP does/doesn't match a literal. // Tests whether destIP does/doesn't match a literal.
match := destAddrRegexp.FindStringSubmatch(rule) match := destAddrRegexp.FindStringSubmatch(rule)
rule = strings.TrimPrefix(rule, match[0]) rule = strings.TrimPrefix(rule, match[0])
not, ip := match[1], match[2] wantMatch, ip := match[1] != "!= ", match[2]
if !tracer.addressMatches(destIP, not, ip) { if !tracer.addressMatches(destIP, wantMatch, ip) {
rule = "" rule = ""
break break
} }
@@ -458,8 +454,8 @@ func (tracer *nftablesTracer) runChain(chname, sourceIP, protocol, destIP, destP
// Tests whether sourceIP does/doesn't match a literal. // Tests whether sourceIP does/doesn't match a literal.
match := sourceAddrRegexp.FindStringSubmatch(rule) match := sourceAddrRegexp.FindStringSubmatch(rule)
rule = strings.TrimPrefix(rule, match[0]) rule = strings.TrimPrefix(rule, match[0])
not, ip := match[1], match[2] wantMatch, ip := match[1] != "!= ", match[2]
if !tracer.addressMatches(sourceIP, not, ip) { if !tracer.addressMatches(sourceIP, wantMatch, ip) {
rule = "" rule = ""
break break
} }