diff --git a/internal/cri/server/sandbox_portforward_windows.go b/internal/cri/server/sandbox_portforward_windows.go index d6c4316f5..c353243ae 100644 --- a/internal/cri/server/sandbox_portforward_windows.go +++ b/internal/cri/server/sandbox_portforward_windows.go @@ -17,61 +17,103 @@ package server import ( - "bytes" "context" "fmt" "io" + "net" + "time" - "k8s.io/utils/exec" + "github.com/containerd/log" - sandboxstore "github.com/containerd/containerd/v2/internal/cri/store/sandbox" - cioutil "github.com/containerd/containerd/v2/pkg/ioutil" + netutils "k8s.io/utils/net" ) func (c *criService) portForward(ctx context.Context, id string, port int32, stream io.ReadWriter) error { - stdout := cioutil.NewNopWriteCloser(stream) - stderrBuffer := new(bytes.Buffer) - stderr := cioutil.NewNopWriteCloser(stderrBuffer) - // localhost is resolved to 127.0.0.1 in ipv4, and ::1 in ipv6. - // Explicitly using ipv4 IP address in here to avoid flakiness. - cmd := []string{"wincat.exe", "127.0.0.1", fmt.Sprint(port)} - err := c.execInSandbox(ctx, id, cmd, stream, stdout, stderr) + sandbox, err := c.sandboxStore.Get(id) if err != nil { - return fmt.Errorf("failed to execute port forward in sandbox: %s: %w", stderrBuffer.String(), err) + return fmt.Errorf("failed to find sandbox %q in store: %w", id, err) } + + var podIP string + if !hostNetwork(sandbox.Config) { + // get ip address of the sandbox + podIP, _, err = c.getIPs(sandbox) + if err != nil { + return fmt.Errorf("failed to get sandbox ip: %w", err) + } + } else { + // HPCs use the host networking namespace. + // Therefore, dial to localhost. + podIP = "127.0.0.1" + } + + err = func() error { + var conn net.Conn + if netutils.IsIPv4String(podIP) { + conn, err = net.Dial("tcp4", fmt.Sprintf("%s:%d", podIP, port)) + if err != nil { + return fmt.Errorf("failed to connect to %s:%d for pod %q: %v", podIP, port, id, err) + } + } else { + conn, err = net.Dial("tcp6", fmt.Sprintf("%s:%d", podIP, port)) + if err != nil { + return fmt.Errorf("failed to connect to %s:%d for pod %q: %v", podIP, port, id, err) + } + } + log.G(ctx).Debugf("Connection to ip %s and port %d was successful", podIP, port) + + defer conn.Close() + + // copy stream + errCh := make(chan error, 2) + // Copy from the namespace port connection to the client stream + go func() { + log.G(ctx).Debugf("PortForward copying data from namespace %q port %d to the client stream", id, port) + _, err := io.Copy(stream, conn) + errCh <- err + }() + + // Copy from the client stream to the namespace port connection + go func() { + log.G(ctx).Debugf("PortForward copying data from client stream to namespace %q port %d", id, port) + _, err := io.Copy(conn, stream) + errCh <- err + }() + + // Wait until the first error is returned by one of the connections + // we use errFwd to store the result of the port forwarding operation + // if the context is cancelled close everything and return + var errFwd error + select { + case errFwd = <-errCh: + log.G(ctx).Debugf("PortForward stop forwarding in one direction in network namespace %q port %d: %v", id, port, errFwd) + case <-ctx.Done(): + log.G(ctx).Debugf("PortForward cancelled in network namespace %q port %d: %v", id, port, ctx.Err()) + return ctx.Err() + } + // give a chance to terminate gracefully or timeout + // after 1s + const timeout = time.Second + select { + case e := <-errCh: + if errFwd == nil { + errFwd = e + } + log.G(ctx).Debugf("PortForward stopped forwarding in both directions in network namespace %q port %d: %v", id, port, e) + case <-time.After(timeout): + log.G(ctx).Debugf("PortForward timed out waiting to close the connection in network namespace %q port %d", id, port) + case <-ctx.Done(): + log.G(ctx).Debugf("PortForward cancelled in network namespace %q port %d: %v", id, port, ctx.Err()) + errFwd = ctx.Err() + } + + return errFwd + }() + + if err != nil { + return fmt.Errorf("failed to execute portforward for podId %v, podIp %v, err: %w", id, podIP, err) + } + log.G(ctx).Debugf("Finish port forwarding for windows %q port %d", id, port) + return nil } - -func (c *criService) execInSandbox(ctx context.Context, sandboxID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser) error { - // Get sandbox from our sandbox store. - sb, err := c.sandboxStore.Get(sandboxID) - if err != nil { - return fmt.Errorf("failed to find sandbox %q in store: %w", sandboxID, err) - } - - // Check the sandbox state - state := sb.Status.Get().State - if state != sandboxstore.StateReady { - return fmt.Errorf("sandbox is in %s state", fmt.Sprint(state)) - } - - opts := execOptions{ - cmd: cmd, - stdin: stdin, - stdout: stdout, - stderr: stderr, - tty: false, - resize: nil, - } - exitCode, err := c.execInternal(ctx, sb.Container, sandboxID, opts) - if err != nil { - return fmt.Errorf("failed to exec in sandbox: %w", err) - } - if *exitCode == 0 { - return nil - } - return &exec.CodeExitError{ - Err: fmt.Errorf("error executing command %v, exit code %d", cmd, *exitCode), - Code: int(*exitCode), - } -}