From 550a6f1d733ad4960ecbe8247c6cb2a545d7dd66 Mon Sep 17 00:00:00 2001 From: Maksym Pavlenko Date: Thu, 11 Jul 2019 11:54:48 -0700 Subject: [PATCH] Fix integration tests Signed-off-by: Maksym Pavlenko --- client.go | 51 +++++++++++++++++++++++++++--------------- container_opts.go | 13 +++++++++++ container_opts_unix.go | 4 ++++ image.go | 6 ++++- 4 files changed, 55 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 82505889e..9b0888c79 100644 --- a/client.go +++ b/client.go @@ -143,8 +143,10 @@ func New(address string, opts ...ClientOpt) (*Client, error) { // check namespace labels for default runtime if copts.defaultRuntime == "" && copts.defaultns != "" { ctx := namespaces.WithNamespace(context.Background(), copts.defaultns) - if err := c.GetLabel(ctx, defaults.DefaultRuntimeNSLabel, &c.runtime, ""); err != nil { + if label, err := c.GetLabel(ctx, defaults.DefaultRuntimeNSLabel); err != nil { return nil, err + } else if label != "" { + c.runtime = label } } @@ -168,8 +170,10 @@ func NewWithConn(conn *grpc.ClientConn, opts ...ClientOpt) (*Client, error) { // check namespace labels for default runtime if copts.defaultRuntime == "" && copts.defaultns != "" { ctx := namespaces.WithNamespace(context.Background(), copts.defaultns) - if err := c.GetLabel(ctx, defaults.DefaultRuntimeNSLabel, &c.runtime, ""); err != nil { + if label, err := c.GetLabel(ctx, defaults.DefaultRuntimeNSLabel); err != nil { return nil, err + } else if label != "" { + c.runtime = label } } @@ -482,28 +486,22 @@ func writeIndex(ctx context.Context, index *ocispec.Index, client *Client, ref s return writeContent(ctx, client.ContentStore(), ocispec.MediaTypeImageIndex, ref, bytes.NewReader(data), content.WithLabels(labels)) } -// GetLabel gets a label value from namespace store and saves it in 'out' variable. -// If there is no value, a fallback value will be used instead. -func (c *Client) GetLabel(ctx context.Context, label string, out *string, fallback string) error { +// GetLabel gets a label value from namespace store +// If there is no default label, an empty string returned with nil error +func (c *Client) GetLabel(ctx context.Context, label string) (string, error) { ns, err := namespaces.NamespaceRequired(ctx) if err != nil { - return err + return "", err } srv := c.NamespaceService() labels, err := srv.Labels(ctx, ns) if err != nil { - return err + return "", err } - value, ok := labels[label] - if ok { - *out = value - } else { - *out = fallback - } - - return nil + value := labels[label] + return value, nil } // Subscribe to events that match one or more of the provided filters. @@ -671,17 +669,34 @@ func (c *Client) Version(ctx context.Context) (Version, error) { }, nil } -func (c *Client) getSnapshotter(ctx context.Context, name string) (snapshots.Snapshotter, error) { +func (c *Client) resolveSnapshotterName(ctx context.Context, name string) (string, error) { if name == "" { - if err := c.GetLabel(ctx, defaults.DefaultSnapshotterNSLabel, &name, DefaultSnapshotter); err != nil { - return nil, err + label, err := c.GetLabel(ctx, defaults.DefaultSnapshotterNSLabel) + if err != nil { + return "", err } + + if label != "" { + name = label + } else { + name = DefaultSnapshotter + } + } + + return name, nil +} + +func (c *Client) getSnapshotter(ctx context.Context, name string) (snapshots.Snapshotter, error) { + name, err := c.resolveSnapshotterName(ctx, name) + if err != nil { + return nil, err } s := c.SnapshotService(name) if s == nil { return nil, errors.Wrapf(errdefs.ErrNotFound, "snapshotter %s was not found", name) } + return s, nil } diff --git a/container_opts.go b/container_opts.go index c8a0933b7..4c8a40489 100644 --- a/container_opts.go +++ b/container_opts.go @@ -117,6 +117,11 @@ func WithSnapshotter(name string) NewContainerOpts { func WithSnapshot(id string) NewContainerOpts { return func(ctx context.Context, client *Client, c *containers.Container) error { // check that the snapshot exists, if not, fail on creation + var err error + c.Snapshotter, err = client.resolveSnapshotterName(ctx, c.Snapshotter) + if err != nil { + return err + } s, err := client.getSnapshotter(ctx, c.Snapshotter) if err != nil { return err @@ -139,6 +144,10 @@ func WithNewSnapshot(id string, i Image, opts ...snapshots.Opt) NewContainerOpts } parent := identity.ChainID(diffIDs).String() + c.Snapshotter, err = client.resolveSnapshotterName(ctx, c.Snapshotter) + if err != nil { + return err + } s, err := client.getSnapshotter(ctx, c.Snapshotter) if err != nil { return err @@ -177,6 +186,10 @@ func WithNewSnapshotView(id string, i Image, opts ...snapshots.Opt) NewContainer } parent := identity.ChainID(diffIDs).String() + c.Snapshotter, err = client.resolveSnapshotterName(ctx, c.Snapshotter) + if err != nil { + return err + } s, err := client.getSnapshotter(ctx, c.Snapshotter) if err != nil { return err diff --git a/container_opts_unix.go b/container_opts_unix.go index c8b6247b5..af52d0422 100644 --- a/container_opts_unix.go +++ b/container_opts_unix.go @@ -54,6 +54,10 @@ func withRemappedSnapshotBase(id string, i Image, uid, gid uint32, readonly bool parent = identity.ChainID(diffIDs).String() usernsID = fmt.Sprintf("%s-%d-%d", parent, uid, gid) ) + c.Snapshotter, err = client.resolveSnapshotterName(ctx, c.Snapshotter) + if err != nil { + return err + } snapshotter, err := client.getSnapshotter(ctx, c.Snapshotter) if err != nil { return err diff --git a/image.go b/image.go index 3c820840c..77c95eaa4 100644 --- a/image.go +++ b/image.go @@ -25,7 +25,7 @@ import ( "github.com/containerd/containerd/images" "github.com/containerd/containerd/platforms" "github.com/containerd/containerd/rootfs" - digest "github.com/opencontainers/go-digest" + "github.com/opencontainers/go-digest" "github.com/opencontainers/image-spec/identity" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/pkg/errors" @@ -149,6 +149,10 @@ func (i *image) Unpack(ctx context.Context, snapshotterName string) error { chain []digest.Digest unpacked bool ) + snapshotterName, err = i.client.resolveSnapshotterName(ctx, snapshotterName) + if err != nil { + return err + } sn, err := i.client.getSnapshotter(ctx, snapshotterName) if err != nil { return err