diff --git a/cmd/containerd-shim-runc-v2/manager/manager_linux.go b/cmd/containerd-shim-runc-v2/manager/manager_linux.go index 8ab77220a..ccf7e0d65 100644 --- a/cmd/containerd-shim-runc-v2/manager/manager_linux.go +++ b/cmd/containerd-shim-runc-v2/manager/manager_linux.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "net" "os" "os/exec" "path/filepath" @@ -126,6 +127,59 @@ func (m manager) Name() string { return m.name } +type shimSocket struct { + addr string + s *net.UnixListener + f *os.File +} + +func (s *shimSocket) Close() { + if s.s != nil { + s.s.Close() + } + if s.f != nil { + s.f.Close() + } + _ = shim.RemoveSocket(s.addr) +} + +func newShimSocket(ctx context.Context, path, id string, debug bool) (*shimSocket, error) { + address, err := shim.SocketAddress(ctx, path, id, debug) + if err != nil { + return nil, err + } + socket, err := shim.NewSocket(address) + if err != nil { + // the only time where this would happen is if there is a bug and the socket + // was not cleaned up in the cleanup method of the shim or we are using the + // grouping functionality where the new process should be run with the same + // shim as an existing container + if !shim.SocketEaddrinuse(err) { + return nil, fmt.Errorf("create new shim socket: %w", err) + } + if !debug && shim.CanConnect(address) { + return &shimSocket{addr: address}, errdefs.ErrAlreadyExists + } + if err := shim.RemoveSocket(address); err != nil { + return nil, fmt.Errorf("remove pre-existing socket: %w", err) + } + if socket, err = shim.NewSocket(address); err != nil { + return nil, fmt.Errorf("try create new shim socket 2x: %w", err) + } + } + s := &shimSocket{ + addr: address, + s: socket, + } + f, err := socket.File() + if err != nil { + s.Close() + return nil, err + } + s.f = f + return s, nil +} + func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ shim.BootstrapParams, retErr error) { var params shim.BootstrapParams params.Version = 3 @@ -146,44 +200,35 @@ func (manager) 
Start(ctx context.Context, id string, opts shim.StartOpts) (_ shi break } } - address, err := shim.SocketAddress(ctx, opts.Address, grouping) - if err != nil { - return params, err - } - socket, err := shim.NewSocket(address) - if err != nil { - // the only time where this would happen is if there is a bug and the socket - // was not cleaned up in the cleanup method of the shim or we are using the - // grouping functionality where the new process should be run with the same - // shim as an existing container - if !shim.SocketEaddrinuse(err) { - return params, fmt.Errorf("create new shim socket: %w", err) - } - if shim.CanConnect(address) { - params.Address = address - return params, nil - } - if err := shim.RemoveSocket(address); err != nil { - return params, fmt.Errorf("remove pre-existing socket: %w", err) - } - if socket, err = shim.NewSocket(address); err != nil { - return params, fmt.Errorf("try create new shim socket 2x: %w", err) - } - } + var sockets []*shimSocket defer func() { if retErr != nil { - socket.Close() - _ = shim.RemoveSocket(address) + for _, s := range sockets { + s.Close() + } } }() - f, err := socket.File() + s, err := newShimSocket(ctx, opts.Address, grouping, false) if err != nil { + if errdefs.IsAlreadyExists(err) { + params.Address = s.addr + return params, nil + } return params, err } + sockets = append(sockets, s) + cmd.ExtraFiles = append(cmd.ExtraFiles, s.f) - cmd.ExtraFiles = append(cmd.ExtraFiles, f) + if opts.Debug { + s, err = newShimSocket(ctx, opts.Address, grouping, true) + if err != nil { + return params, err + } + sockets = append(sockets, s) + cmd.ExtraFiles = append(cmd.ExtraFiles, s.f) + } goruntime.LockOSThread() if os.Getenv("SCHED_CORE") != "" { @@ -193,7 +238,6 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ shi } if err := cmd.Start(); err != nil { - f.Close() return params, err } @@ -233,7 +277,7 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ shi 
return params, fmt.Errorf("failed to adjust OOM score for shim: %w", err) } - params.Address = address + params.Address = sockets[0].addr return params, nil } diff --git a/cmd/ctr/commands/pprof/pprof.go b/cmd/ctr/commands/pprof/pprof.go index a1d66b05f..726326f5a 100644 --- a/cmd/ctr/commands/pprof/pprof.go +++ b/cmd/ctr/commands/pprof/pprof.go @@ -65,16 +65,7 @@ var pprofGoroutinesCommand = &cli.Command{ }, }, Action: func(cliContext *cli.Context) error { - client := getPProfClient(cliContext) - - debug := cliContext.Uint("debug") - output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/goroutine?debug=%d", debug)) - if err != nil { - return err - } - defer output.Close() - _, err = io.Copy(os.Stdout, output) - return err + return GoroutineProfile(cliContext, getPProfClient) }, } @@ -89,16 +80,7 @@ var pprofHeapCommand = &cli.Command{ }, }, Action: func(cliContext *cli.Context) error { - client := getPProfClient(cliContext) - - debug := cliContext.Uint("debug") - output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/heap?debug=%d", debug)) - if err != nil { - return err - } - defer output.Close() - _, err = io.Copy(os.Stdout, output) - return err + return HeapProfile(cliContext, getPProfClient) }, } @@ -119,17 +101,7 @@ var pprofProfileCommand = &cli.Command{ }, }, Action: func(cliContext *cli.Context) error { - client := getPProfClient(cliContext) - - seconds := cliContext.Duration("seconds").Seconds() - debug := cliContext.Uint("debug") - output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/profile?seconds=%v&debug=%d", seconds, debug)) - if err != nil { - return err - } - defer output.Close() - _, err = io.Copy(os.Stdout, output) - return err + return CPUProfile(cliContext, getPProfClient) }, } @@ -150,18 +122,7 @@ var pprofTraceCommand = &cli.Command{ }, }, Action: func(cliContext *cli.Context) error { - client := getPProfClient(cliContext) - - seconds := cliContext.Duration("seconds").Seconds() - debug := 
cliContext.Uint("debug") - uri := fmt.Sprintf("/debug/pprof/trace?seconds=%v&debug=%d", seconds, debug) - output, err := httpGetRequest(client, uri) - if err != nil { - return err - } - defer output.Close() - _, err = io.Copy(os.Stdout, output) - return err + return TraceProfile(cliContext, getPProfClient) }, } @@ -176,16 +137,7 @@ var pprofBlockCommand = &cli.Command{ }, }, Action: func(cliContext *cli.Context) error { - client := getPProfClient(cliContext) - - debug := cliContext.Uint("debug") - output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/block?debug=%d", debug)) - if err != nil { - return err - } - defer output.Close() - _, err = io.Copy(os.Stdout, output) - return err + return BlockProfile(cliContext, getPProfClient) }, } @@ -200,27 +152,120 @@ var pprofThreadcreateCommand = &cli.Command{ }, }, Action: func(cliContext *cli.Context) error { - client := getPProfClient(cliContext) - - debug := cliContext.Uint("debug") - output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/threadcreate?debug=%d", debug)) - if err != nil { - return err - } - defer output.Close() - _, err = io.Copy(os.Stdout, output) - return err + return ThreadcreateProfile(cliContext, getPProfClient) }, } -func getPProfClient(cliContext *cli.Context) *http.Client { +// Client is a func that returns an HTTP client for a pprof server +type Client func(cliContext *cli.Context) (*http.Client, error) + +// GoroutineProfile writes a goroutine stack dump to stdout +func GoroutineProfile(cliContext *cli.Context, clientFunc Client) error { + client, err := clientFunc(cliContext) + if err != nil { + return err + } + debug := cliContext.Uint("debug") + output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/goroutine?debug=%d", debug)) + if err != nil { + return err + } + defer output.Close() + _, err = io.Copy(os.Stdout, output) + return err +} + +// HeapProfile dumps the heap profile +func HeapProfile(cliContext *cli.Context, clientFunc Client) error { + client, err :=
clientFunc(cliContext) + if err != nil { + return err + } + debug := cliContext.Uint("debug") + output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/heap?debug=%d", debug)) + if err != nil { + return err + } + defer output.Close() + _, err = io.Copy(os.Stdout, output) + return err +} + +// CPUProfile dumps the CPU profile +func CPUProfile(cliContext *cli.Context, clientFunc Client) error { + client, err := clientFunc(cliContext) + if err != nil { + return err + } + seconds := cliContext.Duration("seconds").Seconds() + debug := cliContext.Uint("debug") + output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/profile?seconds=%v&debug=%d", seconds, debug)) + if err != nil { + return err + } + defer output.Close() + _, err = io.Copy(os.Stdout, output) + return err +} + +// TraceProfile collects an execution trace +func TraceProfile(cliContext *cli.Context, clientFunc Client) error { + client, err := clientFunc(cliContext) + if err != nil { + return err + } + seconds := cliContext.Duration("seconds").Seconds() + debug := cliContext.Uint("debug") + uri := fmt.Sprintf("/debug/pprof/trace?seconds=%v&debug=%d", seconds, debug) + output, err := httpGetRequest(client, uri) + if err != nil { + return err + } + defer output.Close() + _, err = io.Copy(os.Stdout, output) + return err +} + +// BlockProfile collects the goroutine blocking profile +func BlockProfile(cliContext *cli.Context, clientFunc Client) error { + client, err := clientFunc(cliContext) + if err != nil { + return err + } + debug := cliContext.Uint("debug") + output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/block?debug=%d", debug)) + if err != nil { + return err + } + defer output.Close() + _, err = io.Copy(os.Stdout, output) + return err +} + +// ThreadcreateProfile collects the thread creation profile +func ThreadcreateProfile(cliContext *cli.Context, clientFunc Client) error { + client, err := clientFunc(cliContext) + if err != nil { + return err + } + debug :=
cliContext.Uint("debug") + output, err := httpGetRequest(client, fmt.Sprintf("/debug/pprof/threadcreate?debug=%d", debug)) + if err != nil { + return err + } + defer output.Close() + _, err = io.Copy(os.Stdout, output) + return err +} + +func getPProfClient(cliContext *cli.Context) (*http.Client, error) { dialer := getPProfDialer(cliContext.String("debug-socket")) tr := &http.Transport{ Dial: dialer.pprofDial, } client := &http.Client{Transport: tr} - return client + return client, nil } func httpGetRequest(client *http.Client, request string) (io.ReadCloser, error) { diff --git a/cmd/ctr/commands/shim/pprof.go b/cmd/ctr/commands/shim/pprof.go new file mode 100644 index 000000000..a1c0b0bfc --- /dev/null +++ b/cmd/ctr/commands/shim/pprof.go @@ -0,0 +1,165 @@ +//go:build !windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package shim + +import ( + "context" + "errors" + "net" + "net/http" + "strings" + "time" + + "github.com/containerd/containerd/v2/cmd/ctr/commands/pprof" + "github.com/containerd/containerd/v2/pkg/namespaces" + "github.com/containerd/containerd/v2/pkg/shim" + "github.com/urfave/cli/v2" +) + +var pprofCommand = &cli.Command{ + Name: "pprof", + Usage: "Provide golang pprof outputs for containerd-shim", + Subcommands: []*cli.Command{ + pprofBlockCommand, + pprofGoroutinesCommand, + pprofHeapCommand, + pprofProfileCommand, + pprofThreadcreateCommand, + pprofTraceCommand, + }, +} + +var pprofGoroutinesCommand = &cli.Command{ + Name: "goroutines", + Usage: "Print goroutine stack dump", + Flags: []cli.Flag{ + &cli.UintFlag{ + Name: "debug", + Usage: "Output format, value = 0: binary, value > 0: plaintext", + Value: 2, + }, + }, + Action: func(cliContext *cli.Context) error { + return pprof.GoroutineProfile(cliContext, getPProfClient) + }, +} + +var pprofHeapCommand = &cli.Command{ + Name: "heap", + Usage: "Dump heap profile", + Flags: []cli.Flag{ + &cli.UintFlag{ + Name: "debug", + Usage: "Output format, value = 0: binary, value > 0: plaintext", + Value: 0, + }, + }, + Action: func(cliContext *cli.Context) error { + return pprof.HeapProfile(cliContext, getPProfClient) + }, +} + +var pprofProfileCommand = &cli.Command{ + Name: "profile", + Usage: "CPU profile", + Flags: []cli.Flag{ + &cli.DurationFlag{ + Name: "seconds", + Aliases: []string{"s"}, + Usage: "Duration for collection (seconds)", + Value: 30 * time.Second, + }, + &cli.UintFlag{ + Name: "debug", + Usage: "Output format, value = 0: binary, value > 0: plaintext", + Value: 0, + }, + }, + Action: func(cliContext *cli.Context) error { + return pprof.CPUProfile(cliContext, getPProfClient) + }, +} + +var pprofTraceCommand = &cli.Command{ + Name: "trace", + Usage: "Collect execution trace", + Flags: []cli.Flag{ + &cli.DurationFlag{ + Name: "seconds", + Aliases: []string{"s"}, + Usage: "Trace time (seconds)", + 
Value: 5 * time.Second, + }, + &cli.UintFlag{ + Name: "debug", + Usage: "Output format, value = 0: binary, value > 0: plaintext", + Value: 0, + }, + }, + Action: func(cliContext *cli.Context) error { + return pprof.TraceProfile(cliContext, getPProfClient) + }, +} + +var pprofBlockCommand = &cli.Command{ + Name: "block", + Usage: "Goroutine blocking profile", + Flags: []cli.Flag{ + &cli.UintFlag{ + Name: "debug", + Usage: "Output format, value = 0: binary, value > 0: plaintext", + Value: 0, + }, + }, + Action: func(cliContext *cli.Context) error { + return pprof.BlockProfile(cliContext, getPProfClient) + }, +} + +var pprofThreadcreateCommand = &cli.Command{ + Name: "threadcreate", + Usage: "Goroutine thread creating profile", + Flags: []cli.Flag{ + &cli.UintFlag{ + Name: "debug", + Usage: "Output format, value = 0: binary, value > 0: plaintext", + Value: 0, + }, + }, + Action: func(cliContext *cli.Context) error { + return pprof.ThreadcreateProfile(cliContext, getPProfClient) + }, +} + +func getPProfClient(cliContext *cli.Context) (*http.Client, error) { + id := cliContext.String("id") + if id == "" { + return nil, errors.New("container id must be provided") + } + tr := &http.Transport{ + Dial: func(_, _ string) (net.Conn, error) { + ns := cliContext.String("namespace") + ctx := namespaces.WithNamespace(context.Background(), ns) + s, _ := shim.SocketAddress(ctx, cliContext.String("address"), id, true) + s = strings.TrimPrefix(s, "unix://") + return net.Dial("unix", s) + }, + } + return &http.Client{Transport: tr}, nil +} diff --git a/cmd/ctr/commands/shim/shim.go b/cmd/ctr/commands/shim/shim.go index 1dd458247..7ea60199a 100644 --- a/cmd/ctr/commands/shim/shim.go +++ b/cmd/ctr/commands/shim/shim.go @@ -75,6 +75,7 @@ var Command = &cli.Command{ execCommand, startCommand, stateCommand, + pprofCommand, }, } @@ -244,7 +245,7 @@ func getTaskService(cliContext *cli.Context) (task.TTRPCTaskService, error) { s1 := filepath.Join(string(filepath.Separator), "containerd-shim", 
ns, id, "shim.sock") // this should not error, ctr always get a default ns ctx := namespaces.WithNamespace(context.Background(), ns) - s2, _ := shim.SocketAddress(ctx, cliContext.String("address"), id) + s2, _ := shim.SocketAddress(ctx, cliContext.String("address"), id, false) s2 = strings.TrimPrefix(s2, "unix://") for _, socket := range []string{s2, "\x00" + s1} { diff --git a/integration/client/container_linux_test.go b/integration/client/container_linux_test.go index 55ae51443..bbb7bb91e 100644 --- a/integration/client/container_linux_test.go +++ b/integration/client/container_linux_test.go @@ -365,7 +365,7 @@ func TestShimDoesNotLeakSockets(t *testing.T) { t.Fatal(err) } - s, err := shim.SocketAddress(ctx, address, id) + s, err := shim.SocketAddress(ctx, address, id, false) if err != nil { t.Fatal(err) } diff --git a/integration/issue7496_linux_test.go b/integration/issue7496_linux_test.go index 93ad8d25e..0e42e85b1 100644 --- a/integration/issue7496_linux_test.go +++ b/integration/issue7496_linux_test.go @@ -157,7 +157,7 @@ func injectDelayToUmount2(ctx context.Context, t *testing.T, shimCli apitask.TTR } func connectToShim(ctx context.Context, t *testing.T, ctrdEndpoint string, version int, id string) shimcore.TaskServiceClient { - addr, err := shim.SocketAddress(ctx, ctrdEndpoint, id) + addr, err := shim.SocketAddress(ctx, ctrdEndpoint, id, false) require.NoError(t, err) addr = strings.TrimPrefix(addr, "unix://") diff --git a/pkg/shim/shim.go b/pkg/shim/shim.go index b8a28acd6..6cf345258 100644 --- a/pkg/shim/shim.go +++ b/pkg/shim/shim.go @@ -20,10 +20,13 @@ import ( "context" "encoding/json" "errors" + "expvar" "flag" "fmt" "io" "net" + "net/http" + "net/http/pprof" "os" "path/filepath" "runtime" @@ -121,6 +124,7 @@ var ( id string namespaceFlag string socketFlag string + debugSocketFlag string bundlePath string addressFlag string containerdBinaryFlag string @@ -143,6 +147,7 @@ func parseFlags() { flag.StringVar(&namespaceFlag, "namespace", "", "namespace 
that owns the shim") flag.StringVar(&id, "id", "", "id of the task") flag.StringVar(&socketFlag, "socket", "", "socket path to serve") + flag.StringVar(&debugSocketFlag, "debug-socket", "", "debug socket path to serve") flag.StringVar(&bundlePath, "bundle", "", "path to the bundle if not workdir") flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") @@ -435,7 +440,7 @@ func serve(ctx context.Context, server *ttrpc.Server, signals chan os.Signal, sh return err } - l, err := serveListener(socketFlag) + l, err := serveListener(socketFlag, 3) if err != nil { return err } @@ -445,6 +450,13 @@ func serve(ctx context.Context, server *ttrpc.Server, signals chan os.Signal, sh log.G(ctx).WithError(err).Fatal("containerd-shim: ttrpc server failure") } }() + + if debugFlag { + if err := serveDebug(ctx); err != nil { + return err + } + } + logger := log.G(ctx).WithFields(log.Fields{ "pid": os.Getpid(), "path": path, @@ -460,6 +472,31 @@ func serve(ctx context.Context, server *ttrpc.Server, signals chan os.Signal, sh return reap(ctx, logger, signals) } +func serveDebug(ctx context.Context) error { + l, err := serveListener(debugSocketFlag, 4) + if err != nil { + return err + } + go func() { + defer l.Close() + m := http.NewServeMux() + m.Handle("/debug/vars", expvar.Handler()) + m.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) + m.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + m.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) + m.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) + m.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) + srv := &http.Server{ + Handler: m, + ReadHeaderTimeout: 5 * time.Minute, + } + if err := srv.Serve(l); err != nil && !errors.Is(err, net.ErrClosed) { + log.G(ctx).WithError(err).Fatal("containerd-shim: pprof endpoint failure") + } + }() + return nil +} + func dumpStacks(logger *log.Entry) { var ( buf []byte diff --git a/pkg/shim/shim_unix.go 
b/pkg/shim/shim_unix.go index 1d3057ca3..20dd8a87e 100644 --- a/pkg/shim/shim_unix.go +++ b/pkg/shim/shim_unix.go @@ -49,13 +49,13 @@ func setupDumpStacks(dump chan<- os.Signal) { signal.Notify(dump, syscall.SIGUSR1) } -func serveListener(path string) (net.Listener, error) { +func serveListener(path string, fd uintptr) (net.Listener, error) { var ( l net.Listener err error ) if path == "" { - l, err = net.FileListener(os.NewFile(3, "socket")) + l, err = net.FileListener(os.NewFile(fd, "socket")) path = "[inherited from parent]" } else { if len(path) > socketPathLimit { diff --git a/pkg/shim/shim_windows.go b/pkg/shim/shim_windows.go index 9ece51aca..a4296958e 100644 --- a/pkg/shim/shim_windows.go +++ b/pkg/shim/shim_windows.go @@ -42,7 +42,7 @@ func subreaper() error { func setupDumpStacks(dump chan<- os.Signal) { } -func serveListener(path string) (net.Listener, error) { +func serveListener(path string, fd uintptr) (net.Listener, error) { return nil, errdefs.ErrNotImplemented } diff --git a/pkg/shim/util_unix.go b/pkg/shim/util_unix.go index 6ed445ae4..61a8353fb 100644 --- a/pkg/shim/util_unix.go +++ b/pkg/shim/util_unix.go @@ -76,12 +76,16 @@ func AdjustOOMScore(pid int) error { const socketRoot = defaults.DefaultStateDir // SocketAddress returns a socket address -func SocketAddress(ctx context.Context, socketPath, id string) (string, error) { +func SocketAddress(ctx context.Context, socketPath, id string, debug bool) (string, error) { ns, err := namespaces.NamespaceRequired(ctx) if err != nil { return "", err } - d := sha256.Sum256([]byte(filepath.Join(socketPath, ns, id))) + path := filepath.Join(socketPath, ns, id) + if debug { + path = filepath.Join(path, "debug") + } + d := sha256.Sum256([]byte(path)) return fmt.Sprintf("unix://%s/%x", filepath.Join(socketRoot, "s"), d), nil } @@ -286,7 +290,12 @@ func cleanupSockets(ctx context.Context) { } if len(socketFlag) > 0 { _ = RemoveSocket("unix://" + socketFlag) - } else if address, err := SocketAddress(ctx, 
addressFlag, id); err == nil { + } else if address, err := SocketAddress(ctx, addressFlag, id, false); err == nil { + _ = RemoveSocket(address) + } + if len(debugSocketFlag) > 0 { + _ = RemoveSocket("unix://" + debugSocketFlag) + } else if address, err := SocketAddress(ctx, addressFlag, id, true); err == nil { _ = RemoveSocket(address) } }