diff --git a/cmd/ctr/commands/run/run.go b/cmd/ctr/commands/run/run.go index b63bb4890..da91429cc 100644 --- a/cmd/ctr/commands/run/run.go +++ b/cmd/ctr/commands/run/run.go @@ -19,9 +19,7 @@ package run import ( gocontext "context" "encoding/csv" - "encoding/json" "fmt" - "io/ioutil" "strings" "github.com/containerd/console" @@ -37,17 +35,6 @@ import ( "github.com/urfave/cli" ) -func loadSpec(path string, s *specs.Spec) error { - raw, err := ioutil.ReadFile(path) - if err != nil { - return errors.New("cannot load spec config file") - } - if err := json.Unmarshal(raw, s); err != nil { - return errors.Errorf("decoding spec config file failed, current supported OCI runtime-spec : v%s", specs.Version) - } - return nil -} - func withMounts(context *cli.Context) oci.SpecOpts { return func(ctx gocontext.Context, client oci.Client, container *containers.Container, s *specs.Spec) error { mounts := make([]specs.Mount, 0) diff --git a/cmd/ctr/commands/run/run_unix.go b/cmd/ctr/commands/run/run_unix.go index a6ec1a637..6b1f242d2 100644 --- a/cmd/ctr/commands/run/run_unix.go +++ b/cmd/ctr/commands/run/run_unix.go @@ -52,6 +52,13 @@ func NewContainer(ctx gocontext.Context, client *containerd.Client, context *cli cOpts []containerd.NewContainerOpts spec containerd.NewContainerOpts ) + + if context.IsSet("config") { + opts = append(opts, oci.WithSpecFromFile(context.String("config"))) + } else { + opts = append(opts, oci.WithDefaultSpec()) + } + opts = append(opts, oci.WithEnv(context.StringSlice("env"))) opts = append(opts, withMounts(context)) cOpts = append(cOpts, containerd.WithContainerLabels(commands.LabelArgs(context.StringSlice("label")))) @@ -117,15 +124,10 @@ func NewContainer(ctx gocontext.Context, client *containerd.Client, context *cli if context.IsSet("gpus") { opts = append(opts, nvidia.WithGPUs(nvidia.WithDevices(context.Int("gpus")), nvidia.WithAllCapabilities)) } - if context.IsSet("config") { - var s specs.Spec - if err := loadSpec(context.String("config"), &s); err != nil { - return nil, err - } - spec = containerd.WithSpec(&s, opts...) - } else { - spec = containerd.WithNewSpec(opts...) - } + + var s specs.Spec + spec = containerd.WithSpec(&s, opts...) + cOpts = append(cOpts, spec) // oci.WithImageConfig (WithUsername, WithUserID) depends on rootfs snapshot for resolving /etc/passwd. diff --git a/cmd/ctr/commands/run/run_windows.go b/cmd/ctr/commands/run/run_windows.go index 9e4a0ec06..d80d3b065 100644 --- a/cmd/ctr/commands/run/run_windows.go +++ b/cmd/ctr/commands/run/run_windows.go @@ -63,6 +63,13 @@ func NewContainer(ctx gocontext.Context, client *containerd.Client, context *cli cOpts []containerd.NewContainerOpts spec containerd.NewContainerOpts ) + + if context.IsSet("config") { + opts = append(opts, oci.WithSpecFromFile(context.String("config"))) + } else { + opts = append(opts, oci.WithDefaultSpec()) + } + opts = append(opts, oci.WithImageConfig(image)) opts = append(opts, oci.WithEnv(context.StringSlice("env"))) opts = append(opts, withMounts(context)) @@ -74,15 +81,8 @@ func NewContainer(ctx gocontext.Context, client *containerd.Client, context *cli opts = append(opts, oci.WithProcessCwd(cwd)) } - if context.IsSet("config") { - var s specs.Spec - if err := loadSpec(context.String("config"), &s); err != nil { - return nil, err - } - spec = containerd.WithSpec(&s, opts...) - } else { - spec = containerd.WithNewSpec(opts...) - } + var s specs.Spec + spec = containerd.WithSpec(&s, opts...) cOpts = append(cOpts, containerd.WithContainerLabels(commands.LabelArgs(context.StringSlice("label")))) cOpts = append(cOpts, containerd.WithImage(image)) diff --git a/container_opts.go b/container_opts.go index df12cafd5..580715fee 100644 --- a/container_opts.go +++ b/container_opts.go @@ -197,11 +197,10 @@ func WithNewSpec(opts ...oci.SpecOpts) NewContainerOpts { // WithSpec sets the provided spec on the container func WithSpec(s *oci.Spec, opts ...oci.SpecOpts) NewContainerOpts { return func(ctx context.Context, client *Client, c *containers.Container) error { - for _, o := range opts { - if err := o(ctx, client, c, s); err != nil { - return err - } + if err := oci.ApplyOpts(ctx, client, c, s, opts...); err != nil { + return err } + var err error c.Spec, err = typeurl.MarshalAny(s) return err diff --git a/oci/spec.go b/oci/spec.go index 23f9315a4..ffd0bffca 100644 --- a/oci/spec.go +++ b/oci/spec.go @@ -34,10 +34,23 @@ func GenerateSpec(ctx context.Context, client Client, c *containers.Container, o if err != nil { return nil, err } + + return s, ApplyOpts(ctx, client, c, s, opts...) +} + +// ApplyOpts applys the options to the given spec, injecting data from the +// context, client and container instance. +func ApplyOpts(ctx context.Context, client Client, c *containers.Container, s *Spec, opts ...SpecOpts) error { for _, o := range opts { if err := o(ctx, client, c, s); err != nil { - return nil, err + return err } } - return s, nil + + return nil +} + +func createDefaultSpec(ctx context.Context, id string) (*Spec, error) { + var s Spec + return &s, populateDefaultSpec(ctx, &s, id) } diff --git a/oci/spec_opts.go b/oci/spec_opts.go index 57d95306c..fd2cfb039 100644 --- a/oci/spec_opts.go +++ b/oci/spec_opts.go @@ -18,10 +18,13 @@ package oci import ( "context" + "encoding/json" + "io/ioutil" "strings" "github.com/containerd/containerd/containers" specs "github.com/opencontainers/runtime-spec/specs-go" + "github.com/pkg/errors" ) // SpecOpts sets spec specific information to a newly generated OCI spec @@ -46,6 +49,38 @@ func setProcess(s *Spec) { } } +// WithDefaultSpec returns a SpecOpts that will populate the spec with default +// values. +// +// Use as the first option to clear the spec, then apply options afterwards. +func WithDefaultSpec() SpecOpts { + return func(ctx context.Context, _ Client, c *containers.Container, s *Spec) error { + return populateDefaultSpec(ctx, s, c.ID) + } +} + +// WithSpecFromBytes loads the the spec from the provided byte slice. +func WithSpecFromBytes(p []byte) SpecOpts { + return func(_ context.Context, _ Client, _ *containers.Container, s *Spec) error { + *s = Spec{} // make sure spec is cleared. + if err := json.Unmarshal(p, s); err != nil { + return errors.Wrapf(err, "decoding spec config file failed, current supported OCI runtime-spec : v%s", specs.Version) + } + return nil + } +} + +// WithSpecFromFile loads the specification from the provided filename. +func WithSpecFromFile(filename string) SpecOpts { + return func(ctx context.Context, c Client, container *containers.Container, s *Spec) error { + p, err := ioutil.ReadFile(filename) + if err != nil { + return errors.Wrap(err, "cannot load spec config file") + } + return WithSpecFromBytes(p)(ctx, c, container, s) + } +} + // WithProcessArgs replaces the args on the generated spec func WithProcessArgs(args ...string) SpecOpts { return func(_ context.Context, _ Client, _ *containers.Container, s *Spec) error { diff --git a/oci/spec_opts_test.go b/oci/spec_opts_test.go index 44f928cdb..83ad32173 100644 --- a/oci/spec_opts_test.go +++ b/oci/spec_opts_test.go @@ -17,8 +17,16 @@ package oci import ( + "context" + "encoding/json" + "io/ioutil" + "log" + "os" + "reflect" "testing" + "github.com/containerd/containerd/containers" + "github.com/containerd/containerd/namespaces" specs "github.com/opencontainers/runtime-spec/specs-go" ) @@ -87,3 +95,75 @@ func TestWithMounts(t *testing.T) { t.Fatal("invaid mount") } } + +func TestWithDefaultSpec(t *testing.T) { + t.Parallel() + var ( + s Spec + c = containers.Container{ID: "TestWithDefaultSpec"} + ctx = namespaces.WithNamespace(context.Background(), "test") + ) + + if err := ApplyOpts(ctx, nil, &c, &s, WithDefaultSpec()); err != nil { + t.Fatal(err) + } + + expected, err := createDefaultSpec(ctx, c.ID) + if err != nil { + t.Fatal(err) + } + + if reflect.DeepEqual(s, Spec{}) { + t.Fatalf("spec should not be empty") + } + + if !reflect.DeepEqual(&s, expected) { + t.Fatalf("spec from option differs from default: \n%#v != \n%#v", &s, expected) + } +} + +func TestWithSpecFromFile(t *testing.T) { + t.Parallel() + var ( + s Spec + c = containers.Container{ID: "TestWithDefaultSpec"} + ctx = namespaces.WithNamespace(context.Background(), "test") + ) + + fp, err := ioutil.TempFile("", "testwithdefaultspec.json") + if err != nil { + t.Fatal(err) + } + defer fp.Close() + defer func() { + if err := os.Remove(fp.Name()); err != nil { + log.Printf("failed to remove tempfile %v: %v", fp.Name(), err) + } + }() + + expected, err := GenerateSpec(ctx, nil, &c) + if err != nil { + t.Fatal(err) + } + + p, err := json.Marshal(expected) + if err != nil { + t.Fatal(err) + } + + if _, err := fp.Write(p); err != nil { + t.Fatal(err) + } + + if err := ApplyOpts(ctx, nil, &c, &s, WithSpecFromFile(fp.Name())); err != nil { + t.Fatal(err) + } + + if reflect.DeepEqual(s, Spec{}) { + t.Fatalf("spec should not be empty") + } + + if !reflect.DeepEqual(&s, expected) { + t.Fatalf("spec from option differs from default: \n%#v != \n%#v", &s, expected) + } +} diff --git a/oci/spec_unix.go b/oci/spec_unix.go index f8d8524dd..cb69434cb 100644 --- a/oci/spec_unix.go +++ b/oci/spec_unix.go @@ -76,12 +76,13 @@ func defaultNamespaces() []specs.LinuxNamespace { } } -func createDefaultSpec(ctx context.Context, id string) (*Spec, error) { +func populateDefaultSpec(ctx context.Context, s *Spec, id string) error { ns, err := namespaces.NamespaceRequired(ctx) if err != nil { - return nil, err + return err } - s := &Spec{ + + *s = Spec{ Version: specs.Version, Root: &specs.Root{ Path: defaultRootfsPath, @@ -183,5 +184,5 @@ func createDefaultSpec(ctx context.Context, id string) (*Spec, error) { Namespaces: defaultNamespaces(), }, } - return s, nil + return nil } diff --git a/oci/spec_windows.go b/oci/spec_windows.go index 82d7ef158..d0236585d 100644 --- a/oci/spec_windows.go +++ b/oci/spec_windows.go @@ -22,8 +22,8 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" ) -func createDefaultSpec(ctx context.Context, id string) (*Spec, error) { - return &Spec{ +func populateDefaultSpec(ctx context.Context, s *Spec, id string) error { + *s = Spec{ Version: specs.Version, Root: &specs.Root{}, Process: &specs.Process{ @@ -39,5 +39,6 @@ func createDefaultSpec(ctx context.Context, id string) (*Spec, error) { AllowUnqualifiedDNSQuery: true, }, }, - }, nil + } + return nil }