diff --git a/oci/spec_opts.go b/oci/spec_opts.go index 5a952f616..8f74f62ed 100644 --- a/oci/spec_opts.go +++ b/oci/spec_opts.go @@ -1248,16 +1248,16 @@ var ErrNoShmMount = errors.New("no /dev/shm mount specified") // // The size value is specified in kb, kilobytes. func WithDevShmSize(kb int64) SpecOpts { - return func(ctx context.Context, _ Client, c *containers.Container, s *Spec) error { - for _, m := range s.Mounts { - if m.Source == "shm" && m.Type == "tmpfs" { - for i, o := range m.Options { - if strings.HasPrefix(o, "size=") { - m.Options[i] = fmt.Sprintf("size=%dk", kb) - return nil + return func(ctx context.Context, _ Client, _ *containers.Container, s *Spec) error { + for i, m := range s.Mounts { + if filepath.Clean(m.Destination) == "/dev/shm" && m.Source == "shm" && m.Type == "tmpfs" { + for i := 0; i < len(m.Options); i++ { + if strings.HasPrefix(m.Options[i], "size=") { + m.Options = append(m.Options[:i], m.Options[i+1:]...) + i-- } } - m.Options = append(m.Options, fmt.Sprintf("size=%dk", kb)) + s.Mounts[i].Options = append(m.Options, fmt.Sprintf("size=%dk", kb)) return nil } } diff --git a/oci/spec_opts_test.go b/oci/spec_opts_test.go index 7cc94b956..a7fbaa81b 100644 --- a/oci/spec_opts_test.go +++ b/oci/spec_opts_test.go @@ -552,55 +552,102 @@ func TestWithImageConfigArgs(t *testing.T) { func TestDevShmSize(t *testing.T) { t.Parallel() - var ( - s Spec - c = containers.Container{ID: t.Name()} - ctx = namespaces.WithNamespace(context.Background(), "test") - ) - err := populateDefaultUnixSpec(ctx, &s, c.ID) - if err != nil { - t.Fatal(err) + ss := []Spec{ + { + Mounts: []specs.Mount{ + { + Destination: "/dev/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777"}, + }, + }, + }, + { + Mounts: []specs.Mount{ + { + Destination: "/test/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777", "size=65536k"}, + }, + }, + }, + { + Mounts: []specs.Mount{ + { + Destination: "/test/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777", "size=65536k"}, + }, + { + Destination: "/dev/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777", "size=65536k", "size=131072k"}, + }, + }, + }, } expected := "1024k" - if err := WithDevShmSize(1024)(nil, nil, nil, &s); err != nil { - t.Fatal(err) - } - m := getShmMount(&s) - if m == nil { - t.Fatal("no shm mount found") - } - o := getShmSize(m.Options) - if o == "" { - t.Fatal("shm size not specified") - } - parts := strings.Split(o, "=") - if len(parts) != 2 { - t.Fatal("invalid size format") - } - size := parts[1] - if size != expected { - t.Fatalf("size %s not equal %s", size, expected) + for _, s := range ss { + if err := WithDevShmSize(1024)(nil, nil, nil, &s); err != nil { + if err != ErrNoShmMount { + t.Fatal(err) + } + + if getDevShmMount(&s) == nil { + continue + } + t.Fatal("excepted nil /dev/shm mount") + } + + m := getDevShmMount(&s) + if m == nil { + t.Fatal("no shm mount found") + } + size, err := getShmSize(m.Options) + if err != nil { + t.Fatal(err) + } + if size != expected { + t.Fatalf("size %s not equal %s", size, expected) + } } } -func getShmMount(s *Spec) *specs.Mount { +func getDevShmMount(s *Spec) *specs.Mount { for _, m := range s.Mounts { - if m.Source == "shm" && m.Type == "tmpfs" { + if filepath.Clean(m.Destination) == "/dev/shm" && m.Source == "shm" && m.Type == "tmpfs" { return &m } } return nil } -func getShmSize(opts []string) string { +func getShmSize(opts []string) (string, error) { + // linux will use the last size option + var so string for _, o := range opts { if strings.HasPrefix(o, "size=") { - return o + if so != "" { + return "", errors.New("contains multiple size options") + } + so = o } } - return "" + if so == "" { + return "", errors.New("shm size not specified") + } + + parts := strings.Split(so, "=") + if len(parts) != 2 { + return "", errors.New("invalid size format") + } + return parts[1], nil } func TestWithoutMounts(t *testing.T) {