diff --git a/oci/spec_opts.go b/oci/spec_opts.go index 14a2a439c..7156338c8 100644 --- a/oci/spec_opts.go +++ b/oci/spec_opts.go @@ -1228,16 +1228,16 @@ var ErrNoShmMount = errors.New("no /dev/shm mount specified") // // The size value is specified in kb, kilobytes. func WithDevShmSize(kb int64) SpecOpts { - return func(ctx context.Context, _ Client, c *containers.Container, s *Spec) error { - for _, m := range s.Mounts { - if m.Source == "shm" && m.Type == "tmpfs" { - for i, o := range m.Options { - if strings.HasPrefix(o, "size=") { - m.Options[i] = fmt.Sprintf("size=%dk", kb) - return nil + return func(ctx context.Context, _ Client, _ *containers.Container, s *Spec) error { + for i, m := range s.Mounts { + if filepath.Clean(m.Destination) == "/dev/shm" && m.Source == "shm" && m.Type == "tmpfs" { + for i := 0; i < len(m.Options); i++ { + if strings.HasPrefix(m.Options[i], "size=") { + m.Options = append(m.Options[:i], m.Options[i+1:]...) + i-- } } - m.Options = append(m.Options, fmt.Sprintf("size=%dk", kb)) + s.Mounts[i].Options = append(m.Options, fmt.Sprintf("size=%dk", kb)) return nil } } diff --git a/oci/spec_opts_test.go b/oci/spec_opts_test.go index d0d585a38..407cf20aa 100644 --- a/oci/spec_opts_test.go +++ b/oci/spec_opts_test.go @@ -25,6 +25,7 @@ import ( "io/ioutil" "log" "os" + "path/filepath" "reflect" "runtime" "strings" @@ -551,53 +552,100 @@ func TestWithImageConfigArgs(t *testing.T) { func TestDevShmSize(t *testing.T) { t.Parallel() - var ( - s Spec - c = containers.Container{ID: t.Name()} - ctx = namespaces.WithNamespace(context.Background(), "test") - ) - err := populateDefaultUnixSpec(ctx, &s, c.ID) - if err != nil { - t.Fatal(err) + ss := []Spec{ + { + Mounts: []specs.Mount{ + { + Destination: "/dev/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777"}, + }, + }, + }, + { + Mounts: []specs.Mount{ + { + Destination: "/test/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777", "size=65536k"}, + }, + }, + }, + { + Mounts: []specs.Mount{ + { + Destination: "/test/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777", "size=65536k"}, + }, + { + Destination: "/dev/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777", "size=65536k", "size=131072k"}, + }, + }, + }, } expected := "1024k" - if err := WithDevShmSize(1024)(nil, nil, nil, &s); err != nil { - t.Fatal(err) - } - m := getShmMount(&s) - if m == nil { - t.Fatal("no shm mount found") - } - o := getShmSize(m.Options) - if o == "" { - t.Fatal("shm size not specified") - } - parts := strings.Split(o, "=") - if len(parts) != 2 { - t.Fatal("invalid size format") - } - size := parts[1] - if size != expected { - t.Fatalf("size %s not equal %s", size, expected) + for _, s := range ss { + if err := WithDevShmSize(1024)(nil, nil, nil, &s); err != nil { + if err != ErrNoShmMount { + t.Fatal(err) + } + + if getDevShmMount(&s) == nil { + continue + } + t.Fatal("excepted nil /dev/shm mount") + } + + m := getDevShmMount(&s) + if m == nil { + t.Fatal("no shm mount found") + } + size, err := getShmSize(m.Options) + if err != nil { + t.Fatal(err) + } + if size != expected { + t.Fatalf("size %s not equal %s", size, expected) + } } } -func getShmMount(s *Spec) *specs.Mount { +func getDevShmMount(s *Spec) *specs.Mount { for _, m := range s.Mounts { - if m.Source == "shm" && m.Type == "tmpfs" { + if filepath.Clean(m.Destination) == "/dev/shm" && m.Source == "shm" && m.Type == "tmpfs" { return &m } } return nil } -func getShmSize(opts []string) string { +func getShmSize(opts []string) (string, error) { + // linux will use the last size option + var so string for _, o := range opts { if strings.HasPrefix(o, "size=") { - return o + if so != "" { + return "", errors.New("contains multiple size options") + } + so = o } } - return "" + if so == "" { + return "", errors.New("shm size not specified") + } + + parts := strings.Split(so, "=") + if len(parts) != 2 { + return "", errors.New("invalid size format") + } + return parts[1], nil }