diff --git a/cmd/containerd/command/main.go b/cmd/containerd/command/main.go index 0c5d4adae..31a21f6c6 100644 --- a/cmd/containerd/command/main.go +++ b/cmd/containerd/command/main.go @@ -88,6 +88,7 @@ func App() *cli.App { app.Commands = []cli.Command{ configCommand, publishCommand, + ociHook, } app.Action = func(context *cli.Context) error { var ( diff --git a/cmd/containerd/command/oci-hook.go b/cmd/containerd/command/oci-hook.go new file mode 100644 index 000000000..1ddc2c7c0 --- /dev/null +++ b/cmd/containerd/command/oci-hook.go @@ -0,0 +1,136 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package command + +import ( + "bytes" + "encoding/json" + "io" + "os" + "path/filepath" + "syscall" + "text/template" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "github.com/urfave/cli" +) + +var ociHook = cli.Command{ + Name: "oci-hook", + Usage: "provides a base for OCI runtime hooks to allow arguments to be injected.", + Action: func(context *cli.Context) error { + state, err := loadHookState(os.Stdin) + if err != nil { + return err + } + var ( + ctx = newTemplateContext(state) + args = []string(context.Args()) + env = os.Environ() + ) + if err := newList(&args).render(ctx); err != nil { + return err + } + if err := newList(&env).render(ctx); err != nil { + return err + } + return syscall.Exec(args[0], args, env) + }, +} + +func loadHookState(r io.Reader) (*specs.State, error) { + var s specs.State + if err := json.NewDecoder(r).Decode(&s); err != nil { + return nil, err + } + return &s, nil +} + +func newTemplateContext(state *specs.State) *templateContext { + t := &templateContext{ + state: state, + } + t.funcs = template.FuncMap{ + "id": t.id, + "bundle": t.bundle, + "rootfs": t.rootfs, + "pid": t.pid, + "annotation": t.annotation, + "status": t.status, + } + return t +} + +type templateContext struct { + state *specs.State + funcs template.FuncMap +} + +func (t *templateContext) id() string { + return t.state.ID +} + +func (t *templateContext) bundle() string { + return t.state.Bundle +} + +func (t *templateContext) rootfs() string { + return filepath.Join(t.state.Bundle, "rootfs") +} + +func (t *templateContext) pid() int { + return t.state.Pid +} + +func (t *templateContext) annotation(k string) string { + return t.state.Annotations[k] +} + +func (t *templateContext) status() string { + return t.state.Status +} + +func render(ctx *templateContext, source string, out io.Writer) error { + t, err := template.New("oci-hook").Funcs(ctx.funcs).Parse(source) + if err != nil { + return err + } + return t.Execute(out, ctx) +} + +func newList(l *[]string) *templateList { + return &templateList{ + l: l, + } +} + +type templateList struct { + l *[]string +} + +func (l *templateList) render(ctx *templateContext) error { + buf := bytes.NewBuffer(nil) + for i, s := range *l.l { + buf.Reset() + if err := render(ctx, s, buf); err != nil { + return err + } + (*l.l)[i] = buf.String() + } + buf.Reset() + return nil +} diff --git a/cmd/ctr/commands/run/run.go b/cmd/ctr/commands/run/run.go index 1daffd3a7..b5a24a618 100644 --- a/cmd/ctr/commands/run/run.go +++ b/cmd/ctr/commands/run/run.go @@ -93,6 +93,10 @@ var ContainerFlags = []cli.Flag{ Name: "pid-file", Usage: "file path to write the task's pid", }, + cli.IntFlag{ + Name: "gpus", + Usage: "add gpus to the container", + }, } func loadSpec(path string, s *specs.Spec) error { diff --git a/cmd/ctr/commands/run/run_unix.go b/cmd/ctr/commands/run/run_unix.go index 0a3a5db68..991ee3350 100644 --- a/cmd/ctr/commands/run/run_unix.go +++ b/cmd/ctr/commands/run/run_unix.go @@ -24,6 +24,7 @@ import ( "github.com/containerd/containerd" "github.com/containerd/containerd/cmd/ctr/commands" + "github.com/containerd/containerd/contrib/nvidia" "github.com/containerd/containerd/oci" specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" @@ -123,6 +124,9 @@ func NewContainer(ctx gocontext.Context, client *containerd.Client, context *cli Path: parts[1], })) } + if context.IsSet("gpus") { + opts = append(opts, nvidia.WithGPUs(nvidia.WithDevices(context.Int("gpus")), nvidia.WithAllCapabilities)) + } if context.IsSet("config") { var s specs.Spec if err := loadSpec(context.String("config"), &s); err != nil { diff --git a/container_test.go b/container_test.go index 726ff0daf..dd88e627d 100644 --- a/container_test.go +++ b/container_test.go @@ -22,6 +22,7 @@ import ( "io" "io/ioutil" "os" + "os/exec" "runtime" "strings" "syscall" @@ -34,6 +35,7 @@ import ( "github.com/containerd/containerd/oci" _ "github.com/containerd/containerd/runtime" "github.com/containerd/typeurl" + specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/windows/hcsshimtypes" @@ -1469,3 +1471,61 @@ func TestContainerLabels(t *testing.T) { t.Fatalf("expected label \"test\" to be \"no\"") } } + +func TestContainerHook(t *testing.T) { + t.Parallel() + + client, err := newClient(t, address) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + var ( + image Image + ctx, cancel = testContext() + id = t.Name() + ) + defer cancel() + + image, err = client.GetImage(ctx, testImage) + if err != nil { + t.Fatal(err) + } + hook := func(_ context.Context, _ oci.Client, _ *containers.Container, s *specs.Spec) error { + if s.Hooks == nil { + s.Hooks = &specs.Hooks{} + } + path, err := exec.LookPath("containerd") + if err != nil { + return err + } + psPath, err := exec.LookPath("ps") + if err != nil { + return err + } + s.Hooks.Prestart = []specs.Hook{ + { + Path: path, + Args: []string{ + "containerd", + "oci-hook", "--", + psPath, "--pid", "{{pid}}", + }, + Env: os.Environ(), + }, + } + return nil + } + container, err := client.NewContainer(ctx, id, WithNewSpec(oci.WithImageConfig(image), hook), WithNewSnapshot(id, image)) + if err != nil { + t.Fatal(err) + } + defer container.Delete(ctx, WithSnapshotCleanup) + + task, err := container.NewTask(ctx, empty()) + if err != nil { + t.Fatal(err) + } + defer task.Delete(ctx, WithProcessKill) +} diff --git a/contrib/nvidia/nvidia.go b/contrib/nvidia/nvidia.go new file mode 100644 index 000000000..296d85b64 --- /dev/null +++ b/contrib/nvidia/nvidia.go @@ -0,0 +1,185 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package nvidia + +import ( + "context" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + + "github.com/containerd/containerd/containers" + "github.com/containerd/containerd/oci" + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +const nvidiaCLI = "nvidia-container-cli" + +// Capability specifies capabilities for the gpu inside the container +// Detailed explaination of options can be found: +// https://github.com/nvidia/nvidia-container-runtime#supported-driver-capabilities +type Capability int + +const ( + // Compute capability + Compute Capability = iota + 1 + // Compat32 capability + Compat32 + // Graphics capability + Graphics + // Utility capability + Utility + // Video capability + Video + // Display capability + Display +) + +// WithGPUs adds NVIDIA gpu support to a container +func WithGPUs(opts ...Opts) oci.SpecOpts { + return func(_ context.Context, _ oci.Client, _ *containers.Container, s *specs.Spec) error { + c := &config{} + for _, o := range opts { + if err := o(c); err != nil { + return err + } + } + path, err := exec.LookPath("containerd") + if err != nil { + return err + } + nvidiaPath, err := exec.LookPath(nvidiaCLI) + if err != nil { + return err + } + if s.Hooks == nil { + s.Hooks = &specs.Hooks{} + } + s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ + Path: path, + Args: append([]string{ + "containerd", + "oci-hook", + "--", + nvidiaPath, + }, c.args()...), + Env: os.Environ(), + }) + return nil + } +} + +type config struct { + Devices []int + DeviceUUID string + Capabilities []Capability + LoadKmods bool + LDCache string + LDConfig string + Requirements []string +} + +func (c *config) args() []string { + var args []string + + if c.LoadKmods { + args = append(args, "--load-kmods") + } + if c.LDCache != "" { + args = append(args, fmt.Sprintf("--ldcache=%s", c.LDCache)) + } + args = append(args, + "configure", + ) + if len(c.Devices) > 0 { + args = append(args, fmt.Sprintf("--device=%s", strings.Join(toStrings(c.Devices), ","))) + } + if c.DeviceUUID != "" { + args = append(args, fmt.Sprintf("--device=%s", c.DeviceUUID)) + } + for _, c := range c.Capabilities { + args = append(args, fmt.Sprintf("--%s", capFlags[c])) + } + if c.LDConfig != "" { + args = append(args, fmt.Sprintf("--ldconfig=%s", c.LDConfig)) + } + for _, r := range c.Requirements { + args = append(args, fmt.Sprintf("--require=%s", r)) + } + args = append(args, "--pid={{pid}}", "{{rootfs}}") + return args +} + +var capFlags = map[Capability]string{ + Compute: "compute", + Compat32: "compat32", + Graphics: "graphics", + Utility: "utility", + Video: "video", + Display: "display", +} + +func toStrings(ints []int) []string { + var s []string + for _, i := range ints { + s = append(s, strconv.Itoa(i)) + } + return s +} + +// Opts are options for configuring gpu support +type Opts func(*config) error + +// WithDevices adds the provided device indexes to the container +func WithDevices(ids ...int) Opts { + return func(c *config) error { + c.Devices = ids + return nil + } +} + +// WithDeviceUUID adds the specific device UUID to the container +func WithDeviceUUID(guid string) Opts { + return func(c *config) error { + c.DeviceUUID = guid + return nil + } +} + +// WithAllDevices adds all gpus to the container +func WithAllDevices(c *config) error { + c.DeviceUUID = "all" + return nil +} + +// WithAllCapabilities adds all capabilities to the container for the gpus +func WithAllCapabilities(c *config) error { + for k := range capFlags { + c.Capabilities = append(c.Capabilities, k) + } + return nil +} + +// WithRequiredCUDAVersion sets the required cuda version +func WithRequiredCUDAVersion(major, minor int) Opts { + return func(c *config) error { + c.Requirements = append(c.Requirements, fmt.Sprintf("cuda>=%d.%d", major, minor)) + return nil + } +}