diff --git a/integration/client/restart_monitor_test.go b/integration/client/restart_monitor_test.go index ad1183927..08eb6dee4 100644 --- a/integration/client/restart_monitor_test.go +++ b/integration/client/restart_monitor_test.go @@ -19,20 +19,24 @@ package client import ( "bytes" "context" + "errors" "fmt" "os" "path/filepath" "runtime" + "strconv" "syscall" "testing" "time" . "github.com/containerd/containerd" - "github.com/containerd/containerd/containers" + eventtypes "github.com/containerd/containerd/api/events" "github.com/containerd/containerd/oci" "github.com/containerd/containerd/pkg/testutil" + "github.com/containerd/containerd/runtime/restart" srvconfig "github.com/containerd/containerd/services/server/config" "github.com/containerd/containerd/sys" + "github.com/containerd/typeurl" exec "golang.org/x/sys/execabs" ) @@ -148,7 +152,7 @@ version = 2 oci.WithImageConfig(image), longCommand, ), - withRestartStatus(Running), + restart.WithStatus(Running), ) if err != nil { t.Fatal(err) @@ -229,14 +233,115 @@ version = 2 t.Logf("%v: the task was restarted since %v", time.Now(), lastCheck) } -// withRestartStatus is a copy of "github.com/containerd/containerd/runtime/restart".WithStatus. -// This copy is needed because `go test` refuses circular imports. -func withRestartStatus(status ProcessStatus) func(context.Context, *Client, *containers.Container) error { - return func(_ context.Context, _ *Client, c *containers.Container) error { - if c.Labels == nil { - c.Labels = make(map[string]string) +func TestRestartMonitorWithOnFailurePolicy(t *testing.T) { + const ( + interval = 5 * time.Second + ) + configTOML := fmt.Sprintf(` +version = 2 +[plugins] + [plugins."io.containerd.internal.v1.restart"] + interval = "%s" +`, interval.String()) + client, _, cleanup := newDaemonWithConfig(t, configTOML) + defer cleanup() + + var ( + ctx, cancel = testContext(t) + id = t.Name() + ) + defer cancel() + + image, err := client.Pull(ctx, testImage, WithPullUnpack) + if err != nil { + t.Fatal(err) + } + + policy, _ := restart.NewPolicy("on-failure:1") + container, err := client.NewContainer(ctx, id, + WithNewSnapshot(id, image), + WithNewSpec( + oci.WithImageConfig(image), + // always exited with 1 + withExitStatus(1), + ), + restart.WithStatus(Running), + restart.WithPolicy(policy), + ) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := container.Delete(ctx, WithSnapshotCleanup); err != nil { + t.Logf("failed to delete container: %v", err) } - c.Labels["containerd.io/restart.status"] = string(status) - return nil + }() + + task, err := container.NewTask(ctx, empty()) + if err != nil { + t.Fatal(err) + } + defer func() { + if _, err := task.Delete(ctx, WithProcessKill); err != nil { + t.Logf("failed to delete task: %v", err) + } + }() + + if err := task.Start(ctx); err != nil { + t.Fatal(err) + } + + statusCh, err := task.Wait(ctx) + if err != nil { + t.Fatal(err) + } + + eventCh, eventErrCh := client.Subscribe(ctx, `topic=="/tasks/create"`) + + select { + case <-statusCh: + case <-time.After(30 * time.Second): + t.Fatal("should receive exit event in time") + } + + select { + case e := <-eventCh: + cid, err := convertTaskCreateEvent(e.Event) + if err != nil { + t.Fatal(err) + } + if cid != id { + t.Fatalf("expected task id = %s, but got %s", id, cid) + } + case err := <-eventErrCh: + t.Fatalf("unexpected error from event channel: %v", err) + case <-time.After(1 * time.Minute): + t.Fatal("should receive create event in time") + } + + labels, err := container.Labels(ctx) + if err != nil { + t.Fatal(err) + } + restartCount, _ := strconv.Atoi(labels[restart.CountLabel]) + if restartCount != 1 { + t.Fatalf("expected restart count to be 1, got %d", restartCount) } } + +func convertTaskCreateEvent(e typeurl.Any) (string, error) { + id := "" + + evt, err := typeurl.UnmarshalAny(e) + if err != nil { + return "", fmt.Errorf("failed to unmarshalany: %w", err) + } + + switch e := evt.(type) { + case *eventtypes.TaskCreate: + id = e.ContainerID + default: + return "", errors.New("unsupported event") + } + return id, nil +} diff --git a/runtime/restart/monitor/change.go b/runtime/restart/monitor/change.go index a74b3dcd1..7188e0a72 100644 --- a/runtime/restart/monitor/change.go +++ b/runtime/restart/monitor/change.go @@ -20,10 +20,12 @@ import ( "context" "fmt" "net/url" + "strconv" "syscall" "github.com/containerd/containerd" "github.com/containerd/containerd/cio" + "github.com/containerd/containerd/runtime/restart" "github.com/sirupsen/logrus" ) @@ -38,6 +40,7 @@ func (s *stopChange) apply(ctx context.Context, client *containerd.Client) error type startChange struct { container containerd.Container logURI string + count int // Deprecated(in release 1.5): but recognized now, prefer to use logURI logPath string @@ -61,6 +64,15 @@ func (s *startChange) apply(ctx context.Context, client *containerd.Client) erro s.logPath, s.logURI) } + if s.count > 0 { + labels := map[string]string{ + restart.CountLabel: strconv.Itoa(s.count), + } + opt := containerd.WithAdditionalContainerLabels(labels) + if err := s.container.Update(ctx, containerd.UpdateContainerOpts(opt)); err != nil { + return err + } + } killTask(ctx, s.container) task, err := s.container.NewTask(ctx, log) if err != nil { diff --git a/runtime/restart/monitor/monitor.go b/runtime/restart/monitor/monitor.go index a504ad9f4..ee2c1e16f 100644 --- a/runtime/restart/monitor/monitor.go +++ b/runtime/restart/monitor/monitor.go @@ -19,6 +19,7 @@ package monitor import ( "context" "fmt" + "strconv" "sync" "time" @@ -72,6 +73,7 @@ func init() { }, }, InitFn: func(ic *plugin.InitContext) (interface{}, error) { + ic.Meta.Capabilities = []string{"no", "always", "on-failure", "unless-stopped"} opts, err := getServicesOpts(ic) if err != nil { return nil, err @@ -213,15 +215,29 @@ func (m *monitor) monitor(ctx context.Context) ([]change, error) { return nil, err } desiredStatus := containerd.ProcessStatus(labels[restart.StatusLabel]) - if m.isSameStatus(ctx, desiredStatus, c) { + task, err := c.Task(ctx, nil) + if err != nil && desiredStatus == containerd.Stopped { continue } + status, err := task.Status(ctx) + if err != nil && desiredStatus == containerd.Stopped { + continue + } + if desiredStatus == status.Status { + continue + } + switch desiredStatus { case containerd.Running: + if !restart.Reconcile(status, labels) { + continue + } + restartCount, _ := strconv.Atoi(labels[restart.CountLabel]) changes = append(changes, &startChange{ container: c, logPath: labels[restart.LogPathLabel], logURI: labels[restart.LogURILabel], + count: restartCount + 1, }) case containerd.Stopped: changes = append(changes, &stopChange{ @@ -231,15 +247,3 @@ func (m *monitor) monitor(ctx context.Context) ([]change, error) { } return changes, nil } - -func (m *monitor) isSameStatus(ctx context.Context, desired containerd.ProcessStatus, container containerd.Container) bool { - task, err := container.Task(ctx, nil) - if err != nil { - return desired == containerd.Stopped - } - state, err := task.Status(ctx) - if err != nil { - return desired == containerd.Stopped - } - return desired == state.Status -} diff --git a/runtime/restart/restart.go b/runtime/restart/restart.go index e761ff01e..41f03f4e9 100644 --- a/runtime/restart/restart.go +++ b/runtime/restart/restart.go @@ -31,11 +31,15 @@ package restart import ( "context" + "fmt" "net/url" + "strconv" + "strings" "github.com/containerd/containerd" "github.com/containerd/containerd/cio" "github.com/containerd/containerd/containers" + "github.com/sirupsen/logrus" ) const ( @@ -44,12 +48,106 @@ const ( // LogURILabel sets the restart log uri label for a container LogURILabel = "containerd.io/restart.loguri" + // PolicyLabel sets the restart policy label for a container + PolicyLabel = "containerd.io/restart.policy" + // CountLabel sets the restart count label for a container + CountLabel = "containerd.io/restart.count" + // ExplicitlyStoppedLabel sets the restart explicitly stopped label for a container + ExplicitlyStoppedLabel = "containerd.io/restart.explicitly-stopped" + // LogPathLabel sets the restart log path label for a container // // Deprecated(in release 1.5): use LogURILabel LogPathLabel = "containerd.io/restart.logpath" ) +// Policy represents the restart policies of a container. +type Policy struct { + name string + maximumRetryCount int +} + +// NewPolicy creates a restart policy with the specified name. +// supports the following restart policies: +// - no, Do not restart the container. +// - always, Always restart the container regardless of the exit status. +// - on-failure[:max-retries], Restart only if the container exits with a non-zero exit status. +// - unless-stopped, Always restart the container unless it is stopped. +func NewPolicy(policy string) (*Policy, error) { + policySlice := strings.Split(policy, ":") + var ( + err error + retryCount int + ) + switch policySlice[0] { + case "", "no", "always", "unless-stopped": + policy = policySlice[0] + if policy == "" { + policy = "always" + } + if len(policySlice) > 1 { + return nil, fmt.Errorf("restart policy %q not support max retry count", policySlice[0]) + } + case "on-failure": + policy = policySlice[0] + if len(policySlice) > 1 { + retryCount, err = strconv.Atoi(policySlice[1]) + if err != nil { + return nil, fmt.Errorf("invalid max retry count: %s", policySlice[1]) + } + } + default: + return nil, fmt.Errorf("restart policy %q not supported", policy) + } + return &Policy{ + name: policy, + maximumRetryCount: retryCount, + }, nil +} + +func (rp *Policy) String() string { + if rp.maximumRetryCount > 0 { + return fmt.Sprintf("%s:%d", rp.name, rp.maximumRetryCount) + } + return rp.name +} + +func (rp *Policy) Name() string { + return rp.name +} + +func (rp *Policy) MaximumRetryCount() int { + return rp.maximumRetryCount +} + +// Reconcile reconciles the restart policy of a container. +func Reconcile(status containerd.Status, labels map[string]string) bool { + rp, err := NewPolicy(labels[PolicyLabel]) + if err != nil { + logrus.WithError(err).Error("policy reconcile") + return false + } + switch rp.Name() { + case "", "always": + return true + case "on-failure": + restartCount, err := strconv.Atoi(labels[CountLabel]) + if err != nil && labels[CountLabel] != "" { + logrus.WithError(err).Error("policy reconcile") + return false + } + if status.ExitStatus != 0 && (rp.maximumRetryCount == 0 || restartCount < rp.maximumRetryCount) { + return true + } + case "unless-stopped": + explicitlyStopped, _ := strconv.ParseBool(labels[ExplicitlyStoppedLabel]) + if !explicitlyStopped { + return true + } + } + return false +} + // WithLogURI sets the specified log uri for a container. func WithLogURI(uri *url.URL) func(context.Context, *containerd.Client, *containers.Container) error { return WithLogURIString(uri.String()) @@ -110,12 +208,22 @@ func WithStatus(status containerd.ProcessStatus) func(context.Context, *containe } } +// WithPolicy sets the restart policy for a container +func WithPolicy(policy *Policy) func(context.Context, *containerd.Client, *containers.Container) error { + return func(_ context.Context, _ *containerd.Client, c *containers.Container) error { + ensureLabels(c) + c.Labels[PolicyLabel] = policy.String() + return nil + } +} + // WithNoRestarts clears any restart information from the container func WithNoRestarts(_ context.Context, _ *containerd.Client, c *containers.Container) error { if c.Labels == nil { return nil } delete(c.Labels, StatusLabel) + delete(c.Labels, PolicyLabel) delete(c.Labels, LogPathLabel) delete(c.Labels, LogURILabel) return nil diff --git a/runtime/restart/restart_test.go b/runtime/restart/restart_test.go new file mode 100644 index 000000000..23958ed8a --- /dev/null +++ b/runtime/restart/restart_test.go @@ -0,0 +1,221 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package restart + +import ( + "testing" + + "github.com/containerd/containerd" + "github.com/stretchr/testify/assert" +) + +func TestNewRestartPolicy(t *testing.T) { + tests := []struct { + policy string + want *Policy + }{ + { + policy: "unknow", + want: nil, + }, + { + policy: "", + want: &Policy{name: "always"}, + }, + { + policy: "always", + want: &Policy{name: "always"}, + }, + { + policy: "always:3", + want: nil, + }, + { + policy: "on-failure", + want: &Policy{name: "on-failure"}, + }, + { + policy: "on-failure:10", + want: &Policy{ + name: "on-failure", + maximumRetryCount: 10, + }, + }, + { + policy: "unless-stopped", + want: &Policy{ + name: "unless-stopped", + }, + }, + } + + for _, testCase := range tests { + result, _ := NewPolicy(testCase.policy) + assert.Equal(t, testCase.want, result) + } +} + +func TestRestartPolicyToString(t *testing.T) { + tests := []struct { + policy string + want string + }{ + { + policy: "", + want: "always", + }, + { + policy: "always", + want: "always", + }, + { + policy: "on-failure", + want: "on-failure", + }, + { + policy: "on-failure:10", + want: "on-failure:10", + }, + { + policy: "unless-stopped", + want: "unless-stopped", + }, + } + + for _, testCase := range tests { + policy, err := NewPolicy(testCase.policy) + if err != nil { + t.Fatal(err) + } + result := policy.String() + assert.Equal(t, testCase.want, result) + } +} + +func TestRestartPolicyReconcile(t *testing.T) { + tests := []struct { + status containerd.Status + labels map[string]string + want bool + }{ + { + status: containerd.Status{ + Status: containerd.Stopped, + }, + labels: map[string]string{ + PolicyLabel: "always", + }, + want: true, + }, + { + status: containerd.Status{ + Status: containerd.Unknown, + }, + labels: map[string]string{ + PolicyLabel: "always", + }, + want: true, + }, + { + status: containerd.Status{ + Status: containerd.Stopped, + }, + labels: map[string]string{ + PolicyLabel: "on-failure:10", + CountLabel: "1", + }, + want: false, + }, + { + status: containerd.Status{ + Status: containerd.Unknown, + ExitStatus: 1, + }, + // test without count label + labels: map[string]string{ + PolicyLabel: "on-failure:10", + }, + want: true, + }, + { + status: containerd.Status{ + Status: containerd.Unknown, + ExitStatus: 1, + }, + // test without valid count label + labels: map[string]string{ + PolicyLabel: "on-failure:10", + CountLabel: "invalid", + }, + want: false, + }, + { + status: containerd.Status{ + Status: containerd.Unknown, + ExitStatus: 1, + }, + labels: map[string]string{ + PolicyLabel: "on-failure:10", + CountLabel: "1", + }, + want: true, + }, + { + status: containerd.Status{ + Status: containerd.Unknown, + ExitStatus: 1, + }, + labels: map[string]string{ + PolicyLabel: "on-failure:3", + CountLabel: "3", + }, + want: false, + }, + { + status: containerd.Status{ + Status: containerd.Unknown, + }, + labels: map[string]string{ + PolicyLabel: "unless-stopped", + }, + want: true, + }, + { + status: containerd.Status{ + Status: containerd.Stopped, + }, + labels: map[string]string{ + PolicyLabel: "unless-stopped", + }, + want: true, + }, + { + status: containerd.Status{ + Status: containerd.Stopped, + }, + labels: map[string]string{ + PolicyLabel: "unless-stopped", + ExplicitlyStoppedLabel: "true", + }, + want: false, + }, + } + for _, testCase := range tests { + result := Reconcile(testCase.status, testCase.labels) + assert.Equal(t, testCase.want, result, testCase) + } +}