diff --git a/pkg/failpoint/fail.go b/pkg/failpoint/fail.go new file mode 100644 index 000000000..7056f5c41 --- /dev/null +++ b/pkg/failpoint/fail.go @@ -0,0 +1,293 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package failpoint + +import ( + "bytes" + "fmt" + "strconv" + "strings" + "sync" + "time" +) + +// Type is the type of failpoint to specifies which action to take. +type Type int + +const ( + // TypeInvalid is invalid type + TypeInvalid Type = iota + // TypeOff takes no action + TypeOff + // TypeError triggers failpoint error with specified argument + TypeError + // TypePanic triggers panic with specified argument + TypePanic + // TypeDelay sleeps with the specified number of milliseconds + TypeDelay +) + +// String returns the name of type. +func (t Type) String() string { + switch t { + case TypeOff: + return "off" + case TypeError: + return "error" + case TypePanic: + return "panic" + case TypeDelay: + return "delay" + default: + return "invalid" + } +} + +// Failpoint is used to add code points where error or panic may be injected by +// user. The user controlled variable will be parsed for how the error injected +// code should fire. There is the way to set the rule for failpoint. +// +// *[(arg)][->] +// +// The argument specifies which action to take; it can be one of: +// +// off: Takes no action (does not trigger failpoint and no argument) +// error: Triggers failpoint error with specified argument(string) +// panic: Triggers panic with specified argument(string) +// delay: Sleep the specified number of milliseconds +// +// The * modifiers prior to control when is executed. For +// example, "5*error(oops)" means "return error oops 5 times total". The +// operator -> can be used to express cascading terms. If you specify +// ->, it means that if does not execute, will +// be evaluated. If you want the error injected code should fire in second +// call, you can specify "1*off->1*error(oops)". +// +// Based on fail(9) freebsd: https://www.freebsd.org/cgi/man.cgi?query=fail&sektion=9&apropos=0&manpath=FreeBSD%2B10.0-RELEASE +type Failpoint struct { + sync.Mutex + + fnName string + entries []*failpointEntry +} + +// NewFailpoint returns failpoint control. +func NewFailpoint(fnName string, terms string) (*Failpoint, error) { + entries, err := parseTerms([]byte(terms)) + if err != nil { + return nil, err + } + + return &Failpoint{ + fnName: fnName, + entries: entries, + }, nil +} + +// Evaluate evaluates a failpoint. +func (fp *Failpoint) Evaluate() error { + var target *failpointEntry + + func() { + fp.Lock() + defer fp.Unlock() + + for _, entry := range fp.entries { + if entry.count == 0 { + continue + } + + entry.count-- + target = entry + break + } + }() + + if target == nil { + return nil + } + return target.evaluate() +} + +// Failpoint returns the current state of control in string format. +func (fp *Failpoint) Marshal() string { + fp.Lock() + defer fp.Unlock() + + res := make([]string, 0, len(fp.entries)) + for _, entry := range fp.entries { + res = append(res, entry.marshal()) + } + return strings.Join(res, "->") +} + +type failpointEntry struct { + typ Type + arg interface{} + count int64 +} + +func newFailpointEntry() *failpointEntry { + return &failpointEntry{ + typ: TypeInvalid, + count: 0, + } +} + +func (fpe *failpointEntry) marshal() string { + base := fmt.Sprintf("%d*%s", fpe.count, fpe.typ) + switch fpe.typ { + case TypeOff: + return base + case TypeError, TypePanic: + return fmt.Sprintf("%s(%s)", base, fpe.arg.(string)) + case TypeDelay: + return fmt.Sprintf("%s(%d)", base, fpe.arg.(time.Duration)/time.Millisecond) + default: + return base + } +} + +func (fpe *failpointEntry) evaluate() error { + switch fpe.typ { + case TypeOff: + return nil + case TypeError: + return fmt.Errorf("%v", fpe.arg) + case TypePanic: + panic(fpe.arg) + case TypeDelay: + time.Sleep(fpe.arg.(time.Duration)) + return nil + default: + panic("invalid failpoint type") + } +} + +func parseTerms(term []byte) ([]*failpointEntry, error) { + var entry *failpointEntry + var err error + + // count*type[(arg)] + term, entry, err = parseTerm(term) + if err != nil { + return nil, err + } + + res := []*failpointEntry{entry} + + // cascading terms + for len(term) > 0 { + if !bytes.HasPrefix(term, []byte("->")) { + return nil, fmt.Errorf("invalid cascading terms: %s", string(term)) + } + + term = term[2:] + term, entry, err = parseTerm(term) + if err != nil { + return nil, fmt.Errorf("failed to parse cascading term: %w", err) + } + + res = append(res, entry) + } + return res, nil +} + +func parseTerm(term []byte) ([]byte, *failpointEntry, error) { + var err error + var entry = newFailpointEntry() + + // count* + term, err = parseInt64(term, '*', &entry.count) + if err != nil { + return nil, nil, err + } + + // type[(arg)] + term, err = parseType(term, entry) + return term, entry, err +} + +func parseType(term []byte, entry *failpointEntry) ([]byte, error) { + var nameToTyp = map[string]Type{ + "off": TypeOff, + "error(": TypeError, + "panic(": TypePanic, + "delay(": TypeDelay, + } + + var found bool + for name, typ := range nameToTyp { + if bytes.HasPrefix(term, []byte(name)) { + found = true + term = term[len(name):] + entry.typ = typ + break + } + } + + if !found { + return nil, fmt.Errorf("invalid type format: %s", string(term)) + } + + switch entry.typ { + case TypePanic, TypeError: + endIdx := bytes.IndexByte(term, ')') + if endIdx <= 0 { + return nil, fmt.Errorf("invalid argument for %s type", entry.typ) + } + entry.arg = string(term[:endIdx]) + return term[endIdx+1:], nil + case TypeOff: + // do nothing + return term, nil + case TypeDelay: + var msVal int64 + var err error + + term, err = parseInt64(term, ')', &msVal) + if err != nil { + return nil, err + } + entry.arg = time.Millisecond * time.Duration(msVal) + return term, nil + default: + panic("unreachable") + } +} + +func parseInt64(term []byte, terminate byte, val *int64) ([]byte, error) { + i := 0 + + for ; i < len(term); i++ { + if b := term[i]; b < '0' || b > '9' { + break + } + } + + if i == 0 || i == len(term) || term[i] != terminate { + return nil, fmt.Errorf("failed to parse int64 because of invalid terminate byte: %s", string(term)) + } + + v, err := strconv.ParseInt(string(term[:i]), 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse int64 from %s: %v", string(term[:i]), err) + } + + *val = v + return term[i+1:], nil +} diff --git a/pkg/failpoint/fail_test.go b/pkg/failpoint/fail_test.go new file mode 100644 index 000000000..1b79ff6eb --- /dev/null +++ b/pkg/failpoint/fail_test.go @@ -0,0 +1,134 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package failpoint + +import ( + "reflect" + "testing" + "time" +) + +func TestParseTerms(t *testing.T) { + cases := []struct { + terms string + hasError bool + }{ + // off + {"5", true}, + {"*off()", true}, + {"5*off()", true}, + {"5*off(nothing)", true}, + {"5*off(", true}, + {"5*off", false}, + + // error + {"10000error(oops)", true}, + {"10*error(oops)", false}, + {"1234*error(oops))", true}, + {"12342*error()", true}, + + // panic + {"1panic(oops)", true}, + {"1000000*panic(oops)", false}, + {"12345*panic(oops))", true}, + {"12*panic()", true}, + + // delay + {"1*delay(oops)", true}, + {"1000000*delay(-1)", true}, + {"1000000*delay(1)", false}, + + // cascading terms + {"1*delay(1)-", true}, + {"10*delay(2)->", true}, + {"11*delay(3)->10*off(", true}, + {"12*delay(4)->10*of", true}, + {"13*delay(5)->10*off->1000*panic(oops)", false}, + } + + for i, c := range cases { + fp, err := NewFailpoint(t.Name(), c.terms) + + if (err != nil && !c.hasError) || + (err == nil && c.hasError) { + + t.Fatalf("[%v - %s] expected hasError=%v, but got %v", i, c.terms, c.hasError, err) + } + + if err != nil { + continue + } + + if got := fp.Marshal(); !reflect.DeepEqual(got, c.terms) { + t.Fatalf("[%v] expected %v, but got %v", i, c.terms, got) + } + } +} + +func TestEvaluate(t *testing.T) { + terms := "1*error(oops-)->1*off->1*delay(1000)->1*panic(panic)" + + fp, err := NewFailpoint(t.Name(), terms) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + injectedFn := func() error { + if err := fp.Evaluate(); err != nil { + return err + } + return nil + } + + // should return oops- error + if err := injectedFn(); err == nil || err.Error() != "oops-" { + t.Fatalf("expected error %v, but got %v", "oops-", err) + } + + // should return nil + if err := injectedFn(); err != nil { + t.Fatalf("expected nil, but got %v", err) + } + + // should sleep 1s and return nil + now := time.Now() + err = injectedFn() + du := time.Since(now) + if err != nil { + t.Fatalf("expected nil, but got %v", err) + } + if du < 1*time.Second { + t.Fatalf("expected sleep 1s, but got %v", du) + } + + // should panic + defer func() { + if err := recover(); err == nil || err.(string) != "panic" { + t.Fatalf("should panic(panic), but got %v", err) + } + + expected := "0*error(oops-)->0*off->0*delay(1000)->0*panic(panic)" + if got := fp.Marshal(); got != expected { + t.Fatalf("expected %v, but got %v", expected, got) + } + + if err := injectedFn(); err != nil { + t.Fatalf("expected nil, but got %v", err) + } + }() + injectedFn() +}