diff --git a/pkg/server/sandbox_portforward.go b/pkg/server/sandbox_portforward.go index 64df5065c..da11b73be 100644 --- a/pkg/server/sandbox_portforward.go +++ b/pkg/server/sandbox_portforward.go @@ -23,12 +23,12 @@ import ( "os/exec" "strings" + "github.com/containernetworking/plugins/pkg/ns" "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/net/context" runtime "k8s.io/kubernetes/pkg/kubelet/apis/cri/runtime/v1alpha2" - ctrdutil "github.com/containerd/cri/pkg/containerd/util" sandboxstore "github.com/containerd/cri/pkg/store/sandbox" ) @@ -46,69 +46,50 @@ func (c *criService) PortForward(ctx context.Context, r *runtime.PortForwardRequ return c.streamServer.GetPortForward(r) } -// portForward requires `nsenter` and `socat` on the node, it uses `nsenter` to enter the -// sandbox namespace, and run `socat` inside the namespace to forward stream for a specific -// port. The `socat` command keeps running until it exits or client disconnect. +// portForward requires it uses netns to enter the sandbox namespace, +// and forward stream for a specific port. func (c *criService) portForward(id string, port int32, stream io.ReadWriteCloser) error { s, err := c.sandboxStore.Get(id) if err != nil { return errors.Wrapf(err, "failed to find sandbox %q in store", id) } - t, err := s.Container.Task(ctrdutil.NamespacedContext(), nil) - if err != nil { - return errors.Wrap(err, "failed to get sandbox container task") - } - pid := t.Pid() - - socat, err := exec.LookPath("socat") - if err != nil { - return errors.Wrap(err, "failed to find socat") + if s.NetNS == nil { + return errors.Errorf("failed to find network namespace fo sandbox %q in store", id) } - // Check following links for meaning of the options: - // * socat: https://linux.die.net/man/1/socat - // * nsenter: http://man7.org/linux/man-pages/man1/nsenter.1.html - args := []string{"-t", fmt.Sprintf("%d", pid), "-n", socat, - "-", fmt.Sprintf("TCP4:localhost:%d", port)} - - nsenter, err := exec.LookPath("nsenter") - if err != nil { - return errors.Wrap(err, "failed to find nsenter") - } - - logrus.Infof("Executing port forwarding command: %s %s", nsenter, strings.Join(args, " ")) - - cmd := exec.Command(nsenter, args...) - cmd.Stdout = stream - - stderr := new(bytes.Buffer) - cmd.Stderr = stderr - - // If we use Stdin, command.Run() won't return until the goroutine that's copying - // from stream finishes. Unfortunately, if you have a client like telnet connected - // via port forwarding, as long as the user's telnet client is connected to the user's - // local listener that port forwarding sets up, the telnet session never exits. This - // means that even if socat has finished running, command.Run() won't ever return - // (because the client still has the connection and stream open). - // - // The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe - // when the command (socat) exits. - in, err := cmd.StdinPipe() - if err != nil { - return errors.Wrap(err, "failed to create stdin pipe") - } - go func() { - if _, err := io.Copy(in, stream); err != nil { - logrus.WithError(err).Errorf("Failed to copy port forward input for %q port %d", id, port) + err = s.NetNS.GetNs().Do(func(_ ns.NetNS) error { + var wg sync.WaitGroup + client, err := net.Dial("tcp4", fmt.Sprintf("localhost:%d", port)) + if err != nil { + return errors.Wrap(err, "failed to dial") } - in.Close() - logrus.Debugf("Finish copy port forward input for %q port %d", id, port) - }() + defer client.Close() + defer stream.Close() - if err := cmd.Run(); err != nil { - return errors.Errorf("nsenter command returns error: %v, stderr: %q", err, stderr.String()) + wg.Add(1) + go func() { + if _, err := io.Copy(client, stream); err != nil { + logrus.WithError(err).Errorf("Failed to copy port forward input from %q port %d", id, port) + } + logrus.Infof("Finish copy port forward input for %q port %d: %v", id, port) + wg.Done() + }() + wg.Add(1) + go func() { + if _, err := io.Copy(stream, client); err != nil { + logrus.WithError(err).Errorf("Failed to copy port forward output for %q port %d", id, port) + } + logrus.Infof("Finish copy port forward output for %q port %d: %v", id, port) + + wg.Done() + }() + wg.Wait() + + return nil + }) + if err != nil { + return errors.Wrapf(err, "failed to execute portforward in network namespace %s", s.NetNS.GetPath()) } - logrus.Infof("Finish port forwarding for %q port %d", id, port) return nil diff --git a/pkg/store/sandbox/netns.go b/pkg/store/sandbox/netns.go index 8ec4c1de5..5d56d9222 100644 --- a/pkg/store/sandbox/netns.go +++ b/pkg/store/sandbox/netns.go @@ -117,3 +117,10 @@ func (n *NetNS) GetPath() string { defer n.Unlock() return n.ns.Path() } + +// GetNs returns the network namespace handle +func (n *NetNS) GetNs() cnins.NetNS { + n.Lock() + defer n.Unlock() + return n.ns +}