diff --git a/oci/spec_opts.go b/oci/spec_opts.go index ccdf7d311..b9949f837 100644 --- a/oci/spec_opts.go +++ b/oci/spec_opts.go @@ -141,8 +141,10 @@ func WithEnv(environmentVariables []string) SpecOpts { // replaced by env key or appended to the list func replaceOrAppendEnvValues(defaults, overrides []string) []string { cache := make(map[string]int, len(defaults)) + results := make([]string, 0, len(defaults)) for i, e := range defaults { parts := strings.SplitN(e, "=", 2) + results = append(results, e) cache[parts[0]] = i } @@ -150,7 +152,7 @@ func replaceOrAppendEnvValues(defaults, overrides []string) []string { // Values w/o = means they want this env to be removed/unset. if !strings.Contains(value, "=") { if i, exists := cache[value]; exists { - defaults[i] = "" // Used to indicate it should be removed + results[i] = "" // Used to indicate it should be removed } continue } @@ -158,21 +160,21 @@ func replaceOrAppendEnvValues(defaults, overrides []string) []string { // Just do a normal set/update parts := strings.SplitN(value, "=", 2) if i, exists := cache[parts[0]]; exists { - defaults[i] = value + results[i] = value } else { - defaults = append(defaults, value) + results = append(results, value) } } // Now remove all entries that we want to "unset" - for i := 0; i < len(defaults); i++ { - if defaults[i] == "" { - defaults = append(defaults[:i], defaults[i+1:]...) + for i := 0; i < len(results); i++ { + if results[i] == "" { + results = append(results[:i], results[i+1:]...) i-- } } - return defaults + return results } // WithProcessArgs replaces the args on the generated spec @@ -310,7 +312,7 @@ func WithImageConfigArgs(image Image, args []string) SpecOpts { setProcess(s) if s.Linux != nil { - s.Process.Env = append(s.Process.Env, config.Env...) + s.Process.Env = replaceOrAppendEnvValues(s.Process.Env, config.Env) cmd := config.Cmd if len(args) > 0 { cmd = args @@ -332,8 +334,14 @@ func WithImageConfigArgs(image Image, args []string) SpecOpts { // even if there is no specified user in the image config return WithAdditionalGIDs("root")(ctx, client, c, s) } else if s.Windows != nil { - s.Process.Env = config.Env - s.Process.Args = append(config.Entrypoint, config.Cmd...) + s.Process.Env = replaceOrAppendEnvValues(s.Process.Env, config.Env) + cmd := config.Cmd + if len(args) > 0 { + cmd = args + } + s.Process.Args = append(config.Entrypoint, cmd...) + + s.Process.Cwd = config.WorkingDir s.Process.User = specs.User{ Username: config.User, } diff --git a/oci/spec_opts_test.go b/oci/spec_opts_test.go index 1bacce817..74b246566 100644 --- a/oci/spec_opts_test.go +++ b/oci/spec_opts_test.go @@ -19,6 +19,9 @@ package oci import ( "context" "encoding/json" + "errors" + "fmt" + "io" "io/ioutil" "log" "os" @@ -26,11 +29,135 @@ import ( "runtime" "testing" + "github.com/containerd/containerd/content" + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/containerd/containerd/containers" "github.com/containerd/containerd/namespaces" specs "github.com/opencontainers/runtime-spec/specs-go" ) +type blob []byte + +func (b blob) ReadAt(p []byte, off int64) (int, error) { + if off >= int64(len(b)) { + return 0, io.EOF + } + return copy(p, b[off:]), nil +} + +func (b blob) Close() error { + return nil +} + +func (b blob) Size() int64 { + return int64(len(b)) +} + +type fakeImage struct { + config ocispec.Descriptor + blobs map[string]blob +} + +func newFakeImage(config ocispec.Image) (Image, error) { + configBlob, err := json.Marshal(config) + if err != nil { + return nil, err + } + configDescriptor := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageConfig, + Digest: digest.NewDigestFromBytes(digest.SHA256, configBlob), + } + + return fakeImage{ + config: configDescriptor, + blobs: map[string]blob{ + configDescriptor.Digest.String(): configBlob, + }, + }, nil +} + +func (i fakeImage) Config(ctx context.Context) (ocispec.Descriptor, error) { + return i.config, nil +} + +func (i fakeImage) ContentStore() content.Store { + return i +} + +func (i fakeImage) ReaderAt(ctx context.Context, dec ocispec.Descriptor) (content.ReaderAt, error) { + blob, found := i.blobs[dec.Digest.String()] + if !found { + return nil, errors.New("not found") + } + return blob, nil +} + +func (i fakeImage) Info(ctx context.Context, dgst digest.Digest) (content.Info, error) { + return content.Info{}, errors.New("not implemented") +} + +func (i fakeImage) Update(ctx context.Context, info content.Info, fieldpaths ...string) (content.Info, error) { + return content.Info{}, errors.New("not implemented") +} + +func (i fakeImage) Walk(ctx context.Context, fn content.WalkFunc, filters ...string) error { + return errors.New("not implemented") +} + +func (i fakeImage) Delete(ctx context.Context, dgst digest.Digest) error { + return errors.New("not implemented") +} + +func (i fakeImage) Status(ctx context.Context, ref string) (content.Status, error) { + return content.Status{}, errors.New("not implemented") +} + +func (i fakeImage) ListStatuses(ctx context.Context, filters ...string) ([]content.Status, error) { + return nil, errors.New("not implemented") +} + +func (i fakeImage) Abort(ctx context.Context, ref string) error { + return errors.New("not implemented") +} + +func (i fakeImage) Writer(ctx context.Context, opts ...content.WriterOpt) (content.Writer, error) { + return nil, errors.New("not implemented") +} + +func TestReplaceOrAppendEnvValues(t *testing.T) { + t.Parallel() + + defaults := []string{ + "o=ups", "p=$e", "x=foo", "y=boo", "z", "t=", + } + overrides := []string{ + "x=bar", "y", "a=42", "o=", "e", "s=", + } + expected := []string{ + "o=", "p=$e", "x=bar", "z", "t=", "a=42", "s=", + } + + defaultsOrig := make([]string, len(defaults)) + copy(defaultsOrig, defaults) + overridesOrig := make([]string, len(overrides)) + copy(overridesOrig, overrides) + + results := replaceOrAppendEnvValues(defaults, overrides) + + if err := assertEqualsStringArrays(defaults, defaultsOrig); err != nil { + t.Fatal(err) + } + if err := assertEqualsStringArrays(overrides, overridesOrig); err != nil { + t.Fatal(err) + } + + if err := assertEqualsStringArrays(results, expected); err != nil { + t.Fatal(err) + } +} + func TestWithEnv(t *testing.T) { t.Parallel() @@ -232,3 +359,66 @@ func TestWithMemoryLimit(t *testing.T) { } } } + +func isEqualStringArrays(values, expected []string) bool { + if len(values) != len(expected) { + return false + } + + for i, x := range expected { + if values[i] != x { + return false + } + } + return true +} + +func assertEqualsStringArrays(values, expected []string) error { + if !isEqualStringArrays(values, expected) { + return fmt.Errorf("expected %s, but found %s", expected, values) + } + return nil +} + +func TestWithImageConfigArgs(t *testing.T) { + t.Parallel() + + img, err := newFakeImage(ocispec.Image{ + Config: ocispec.ImageConfig{ + Env: []string{"z=bar", "y=baz"}, + Entrypoint: []string{"create", "--namespace=test"}, + Cmd: []string{"", "--debug"}, + }, + }) + if err != nil { + t.Fatal(err) + } + + s := Spec{ + Version: specs.Version, + Root: &specs.Root{}, + Windows: &specs.Windows{}, + } + + opts := []SpecOpts{ + WithEnv([]string{"x=foo", "y=boo"}), + WithProcessArgs("run", "--foo", "xyz", "--bar"), + WithImageConfigArgs(img, []string{"--boo", "bar"}), + } + + expectedEnv := []string{"x=foo", "y=baz", "z=bar"} + expectedArgs := []string{"create", "--namespace=test", "--boo", "bar"} + + for _, opt := range opts { + if err := opt(nil, nil, nil, &s); err != nil { + t.Fatal(err) + } + } + + if err := assertEqualsStringArrays(s.Process.Env, expectedEnv); err != nil { + t.Fatal(err) + } + if err := assertEqualsStringArrays(s.Process.Args, expectedArgs); err != nil { + t.Fatal(err) + } +}