diff --git a/container.go b/container.go index d418ec27d..fb4b29435 100644 --- a/container.go +++ b/container.go @@ -45,6 +45,8 @@ type Container interface { SetLabels(context.Context, map[string]string) (map[string]string, error) // Extensions returns the extensions set on the container Extensions() map[string]prototypes.Any + // Update a container + Update(context.Context, ...UpdateContainerOpts) error } func containerFromRecord(client *Client, c containers.Container) *container { @@ -238,6 +240,27 @@ func (c *container) NewTask(ctx context.Context, ioCreate IOCreation, opts ...Ne return t, nil } +func (c *container) Update(ctx context.Context, opts ...UpdateContainerOpts) error { + c.mu.Lock() + defer c.mu.Unlock() + // fetch the current container config before updating it + current, err := c.client.ContainerService().Get(ctx, c.ID()) + if err != nil { + return err + } + for _, o := range opts { + if err := o(ctx, c.client, ¤t); err != nil { + return err + } + } + nc, err := c.client.ContainerService().Update(ctx, current) + if err != nil { + return errdefs.FromGRPC(err) + } + c.c = nc + return nil +} + func (c *container) loadTask(ctx context.Context, ioAttach IOAttach) (Task, error) { response, err := c.client.TaskService().Get(ctx, &tasks.GetRequest{ ContainerID: c.c.ID, diff --git a/container_linux_test.go b/container_linux_test.go index b04c57810..c161a3b40 100644 --- a/container_linux_test.go +++ b/container_linux_test.go @@ -22,7 +22,7 @@ import ( "golang.org/x/sys/unix" ) -func TestContainerUpdate(t *testing.T) { +func TestTaskUpdate(t *testing.T) { t.Parallel() client, err := newClient(t, address) diff --git a/container_opts.go b/container_opts.go index 5ad0a9739..40d922a64 100644 --- a/container_opts.go +++ b/container_opts.go @@ -15,6 +15,9 @@ import ( // NewContainerOpts allows the caller to set additional options when creating a container type NewContainerOpts func(ctx context.Context, client *Client, c *containers.Container) error +// UpdateContainerOpts allows the caller to set additional options when updating a container +type UpdateContainerOpts func(ctx context.Context, client *Client, c *containers.Container) error + // WithRuntime allows a user to specify the runtime name and additional options that should // be used to create tasks for the container func WithRuntime(name string, options interface{}) NewContainerOpts { diff --git a/container_test.go b/container_test.go index ef8f88cc0..d1792fd12 100644 --- a/container_test.go +++ b/container_test.go @@ -2,6 +2,7 @@ package containerd import ( "bytes" + "context" "io/ioutil" "os" "runtime" @@ -11,7 +12,9 @@ import ( "time" // Register the typeurl + "github.com/containerd/containerd/containers" _ "github.com/containerd/containerd/runtime" + "github.com/containerd/typeurl" "github.com/containerd/containerd/errdefs" gogotypes "github.com/gogo/protobuf/types" @@ -1451,20 +1454,21 @@ func TestContainerExtensions(t *testing.T) { ext := gogotypes.Any{TypeUrl: "test.ext.url", Value: []byte("hello")} container, err := client.NewContainer(ctx, id, WithNewSpec(), WithContainerExtension("hello", &ext)) if err != nil { - t.Fatal(err) + t.Error(err) + return } defer container.Delete(ctx) checkExt := func(container Container) { cExts := container.Extensions() if len(cExts) != 1 { - t.Fatal("expected 1 container extension") + t.Errorf("expected 1 container extension") } if cExts["hello"].TypeUrl != ext.TypeUrl { - t.Fatalf("got unexpected type url for extension: %s", cExts["hello"].TypeUrl) + t.Errorf("got unexpected type url for extension: %s", cExts["hello"].TypeUrl) } if !bytes.Equal(cExts["hello"].Value, ext.Value) { - t.Fatalf("expected extension value %q, got: %q", ext.Value, cExts["hello"].Value) + t.Errorf("expected extension value %q, got: %q", ext.Value, cExts["hello"].Value) } } @@ -1472,7 +1476,57 @@ func TestContainerExtensions(t *testing.T) { container, err = client.LoadContainer(ctx, container.ID()) if err != nil { - t.Fatal(err) + t.Error(err) + return } checkExt(container) } + +func TestContainerUpdate(t *testing.T) { + t.Parallel() + + ctx, cancel := testContext() + defer cancel() + id := t.Name() + + client, err := newClient(t, address) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + container, err := client.NewContainer(ctx, id, WithNewSpec()) + if err != nil { + t.Error(err) + return + } + defer container.Delete(ctx) + + spec, err := container.Spec() + if err != nil { + t.Error(err) + return + } + + const hostname = "updated-hostname" + spec.Hostname = hostname + + if err := container.Update(ctx, func(ctx context.Context, client *Client, c *containers.Container) error { + a, err := typeurl.MarshalAny(spec) + if err != nil { + return err + } + c.Spec = a + return nil + }); err != nil { + t.Error(err) + return + } + if spec, err = container.Spec(); err != nil { + t.Error(err) + return + } + if spec.Hostname != hostname { + t.Errorf("hostname %q != %q", spec.Hostname, hostname) + } +}