diff --git a/cmd/kubelet/app/options/container_runtime.go b/cmd/kubelet/app/options/container_runtime.go index 2f04036bffb..d07cfb699f4 100644 --- a/cmd/kubelet/app/options/container_runtime.go +++ b/cmd/kubelet/app/options/container_runtime.go @@ -45,12 +45,13 @@ func NewContainerRuntimeOptions() *config.ContainerRuntimeOptions { } return &config.ContainerRuntimeOptions{ - ContainerRuntime: kubetypes.DockerContainerRuntime, - DockerEndpoint: dockerEndpoint, - DockershimRootDirectory: "/var/lib/dockershim", - DockerDisableSharedPID: true, - PodSandboxImage: defaultPodSandboxImage, - ImagePullProgressDeadline: metav1.Duration{Duration: 1 * time.Minute}, - ExperimentalDockershim: false, + ContainerRuntime: kubetypes.DockerContainerRuntime, + RedirectContainerStreaming: false, + DockerEndpoint: dockerEndpoint, + DockershimRootDirectory: "/var/lib/dockershim", + DockerDisableSharedPID: true, + PodSandboxImage: defaultPodSandboxImage, + ImagePullProgressDeadline: metav1.Duration{Duration: 1 * time.Minute}, + ExperimentalDockershim: false, } } diff --git a/cmd/kubelet/app/server.go b/cmd/kubelet/app/server.go index cfa23f2f7f4..c98e262b321 100644 --- a/cmd/kubelet/app/server.go +++ b/cmd/kubelet/app/server.go @@ -1173,30 +1173,13 @@ func RunDockershim(f *options.KubeletFlags, c *kubeletconfiginternal.KubeletConf SupportedPortForwardProtocols: streaming.DefaultConfig.SupportedPortForwardProtocols, } + // Standalone dockershim will always start the local streaming server. ds, err := dockershim.NewDockerService(dockerClientConfig, r.PodSandboxImage, streamingConfig, &pluginSettings, - f.RuntimeCgroups, c.CgroupDriver, r.DockershimRootDirectory, r.DockerDisableSharedPID) + f.RuntimeCgroups, c.CgroupDriver, r.DockershimRootDirectory, r.DockerDisableSharedPID, true /*startLocalStreamingServer*/) if err != nil { return err } glog.V(2).Infof("Starting the GRPC server for the docker CRI shim.") server := dockerremote.NewDockerServer(f.RemoteRuntimeEndpoint, ds) - if err := server.Start(stopCh); err != nil { - return err - } - - streamingServer := &http.Server{ - Addr: net.JoinHostPort(c.Address, strconv.Itoa(int(c.Port))), - Handler: ds, - } - - go func() { - <-stopCh - streamingServer.Shutdown(context.Background()) - }() - - // Start the streaming server - if err := streamingServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - return err - } - return nil + return server.Start(stopCh) } diff --git a/pkg/kubelet/BUILD b/pkg/kubelet/BUILD index 710af32b7a3..aa6cba39bf6 100644 --- a/pkg/kubelet/BUILD +++ b/pkg/kubelet/BUILD @@ -132,7 +132,6 @@ go_library( "//vendor/k8s.io/client-go/listers/core/v1:go_default_library", "//vendor/k8s.io/client-go/tools/cache:go_default_library", "//vendor/k8s.io/client-go/tools/record:go_default_library", - "//vendor/k8s.io/client-go/tools/remotecommand:go_default_library", "//vendor/k8s.io/client-go/util/certificate:go_default_library", "//vendor/k8s.io/client-go/util/flowcontrol:go_default_library", "//vendor/k8s.io/client-go/util/integer:go_default_library", diff --git a/pkg/kubelet/config/flags.go b/pkg/kubelet/config/flags.go index c51b4de8444..377d64e4574 100644 --- a/pkg/kubelet/config/flags.go +++ b/pkg/kubelet/config/flags.go @@ -31,6 +31,15 @@ type ContainerRuntimeOptions struct { ContainerRuntime string // RuntimeCgroups that container runtime is expected to be isolated in. RuntimeCgroups string + // RedirectContainerStreaming enables container streaming redirect. + // When RedirectContainerStreaming is false, kubelet will proxy container streaming data + // between apiserver and container runtime. This approach is more secure, but the proxy + // introduces some overhead. + // When RedirectContainerStreaming is true, kubelet will return an http redirect to apiserver, + // and apiserver will access container runtime directly. This approach is more performant, + // but less secure because the connection between apiserver and container runtime is not + // authenticated. + RedirectContainerStreaming bool // Docker-specific options. @@ -77,6 +86,7 @@ func (s *ContainerRuntimeOptions) AddFlags(fs *pflag.FlagSet) { // General settings. fs.StringVar(&s.ContainerRuntime, "container-runtime", s.ContainerRuntime, "The container runtime to use. Possible values: 'docker', 'remote', 'rkt (deprecated)'.") fs.StringVar(&s.RuntimeCgroups, "runtime-cgroups", s.RuntimeCgroups, "Optional absolute name of cgroups to create and run the runtime in.") + fs.BoolVar(&s.RedirectContainerStreaming, "redirect-container-streaming", s.RedirectContainerStreaming, "Enables container streaming redirect. If false, kubelet will proxy container streaming data between apiserver and container runtime; if true, kubelet will return an http redirect to apiserver, and apiserver will access container runtime directly. The proxy approach is more secure, but introduces some overhead. The redirect approach is more performant, but less secure because the connection between apiserver and container runtime is not authenticated.") // Docker-specific settings. fs.BoolVar(&s.ExperimentalDockershim, "experimental-dockershim", s.ExperimentalDockershim, "Enable dockershim only mode. In this mode, kubelet will only start dockershim without any other functionalities. This flag only serves test purpose, please do not use it unless you are conscious of what you are doing. [default=false]") diff --git a/pkg/kubelet/container/BUILD b/pkg/kubelet/container/BUILD index 0950b53b3da..287cd501367 100644 --- a/pkg/kubelet/container/BUILD +++ b/pkg/kubelet/container/BUILD @@ -21,7 +21,6 @@ go_library( "//pkg/api/legacyscheme:go_default_library", "//pkg/kubelet/apis/cri/runtime/v1alpha2:go_default_library", "//pkg/kubelet/util/format:go_default_library", - "//pkg/kubelet/util/ioutils:go_default_library", "//pkg/util/hash:go_default_library", "//pkg/volume:go_default_library", "//third_party/forked/golang/expansion:go_default_library", diff --git a/pkg/kubelet/container/helpers.go b/pkg/kubelet/container/helpers.go index b1299942f98..01fe7129d38 100644 --- a/pkg/kubelet/container/helpers.go +++ b/pkg/kubelet/container/helpers.go @@ -17,11 +17,9 @@ limitations under the License. package container import ( - "bytes" "fmt" "hash/fnv" "strings" - "time" "github.com/golang/glog" @@ -32,7 +30,6 @@ import ( "k8s.io/client-go/tools/record" runtimeapi "k8s.io/kubernetes/pkg/kubelet/apis/cri/runtime/v1alpha2" "k8s.io/kubernetes/pkg/kubelet/util/format" - "k8s.io/kubernetes/pkg/kubelet/util/ioutils" hashutil "k8s.io/kubernetes/pkg/util/hash" "k8s.io/kubernetes/third_party/forked/golang/expansion" ) @@ -265,22 +262,6 @@ func FormatPod(pod *Pod) string { return fmt.Sprintf("%s_%s(%s)", pod.Name, pod.Namespace, pod.ID) } -type containerCommandRunnerWrapper struct { - DirectStreamingRuntime -} - -var _ ContainerCommandRunner = &containerCommandRunnerWrapper{} - -func (r *containerCommandRunnerWrapper) RunInContainer(id ContainerID, cmd []string, timeout time.Duration) ([]byte, error) { - var buffer bytes.Buffer - output := ioutils.WriteCloserWrapper(&buffer) - err := r.ExecInContainer(id, cmd, nil, output, output, false, nil, timeout) - // Even if err is non-nil, there still may be output (e.g. the exec wrote to stdout or stderr but - // the command returned a nonzero exit code). Therefore, always return the output along with the - // error. - return buffer.Bytes(), err -} - // GetContainerSpec gets the container spec by containerName. func GetContainerSpec(pod *v1.Pod, containerName string) *v1.Container { for i, c := range pod.Spec.Containers { diff --git a/pkg/kubelet/container/runtime.go b/pkg/kubelet/container/runtime.go index 29852d435ed..70b72024c9c 100644 --- a/pkg/kubelet/container/runtime.go +++ b/pkg/kubelet/container/runtime.go @@ -124,22 +124,10 @@ type Runtime interface { UpdatePodCIDR(podCIDR string) error } -// DirectStreamingRuntime is the interface implemented by runtimes for which the streaming calls -// (exec/attach/port-forward) should be served directly by the Kubelet. -type DirectStreamingRuntime interface { - // Runs the command in the container of the specified pod. Attaches - // the processes stdin, stdout, and stderr. Optionally uses a tty. - ExecInContainer(containerID ContainerID, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error - // Forward the specified port from the specified pod to the stream. - PortForward(pod *Pod, port int32, stream io.ReadWriteCloser) error - // ContainerAttach encapsulates the attaching to containers for testability - ContainerAttacher -} - -// IndirectStreamingRuntime is the interface implemented by runtimes that handle the serving of the +// StreamingRuntime is the interface implemented by runtimes that handle the serving of the // streaming calls (exec/attach/port-forward) themselves. In this case, Kubelet should redirect to // the runtime server. -type IndirectStreamingRuntime interface { +type StreamingRuntime interface { GetExec(id ContainerID, cmd []string, stdin, stdout, stderr, tty bool) (*url.URL, error) GetAttach(id ContainerID, stdin, stdout, stderr, tty bool) (*url.URL, error) GetPortForward(podName, podNamespace string, podUID types.UID, ports []int32) (*url.URL, error) diff --git a/pkg/kubelet/container/testing/fake_runtime.go b/pkg/kubelet/container/testing/fake_runtime.go index 3019d30094e..707ee1ac456 100644 --- a/pkg/kubelet/container/testing/fake_runtime.go +++ b/pkg/kubelet/container/testing/fake_runtime.go @@ -26,7 +26,6 @@ import ( "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" - "k8s.io/client-go/tools/remotecommand" "k8s.io/client-go/util/flowcontrol" . "k8s.io/kubernetes/pkg/kubelet/container" "k8s.io/kubernetes/pkg/volume" @@ -59,34 +58,13 @@ type FakeRuntime struct { StatusErr error } -type FakeDirectStreamingRuntime struct { - *FakeRuntime - - // Arguments to streaming method calls. - Args struct { - // Attach / Exec args - ContainerID ContainerID - Cmd []string - Stdin io.Reader - Stdout io.WriteCloser - Stderr io.WriteCloser - TTY bool - // Port-forward args - Pod *Pod - Port int32 - Stream io.ReadWriteCloser - } -} - -var _ DirectStreamingRuntime = &FakeDirectStreamingRuntime{} - const FakeHost = "localhost:12345" -type FakeIndirectStreamingRuntime struct { +type FakeStreamingRuntime struct { *FakeRuntime } -var _ IndirectStreamingRuntime = &FakeIndirectStreamingRuntime{} +var _ StreamingRuntime = &FakeStreamingRuntime{} // FakeRuntime should implement Runtime. var _ Runtime = &FakeRuntime{} @@ -311,35 +289,6 @@ func (f *FakeRuntime) GetPodStatus(uid types.UID, name, namespace string) (*PodS return &status, f.Err } -func (f *FakeDirectStreamingRuntime) ExecInContainer(containerID ContainerID, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error { - f.Lock() - defer f.Unlock() - - f.CalledFunctions = append(f.CalledFunctions, "ExecInContainer") - f.Args.ContainerID = containerID - f.Args.Cmd = cmd - f.Args.Stdin = stdin - f.Args.Stdout = stdout - f.Args.Stderr = stderr - f.Args.TTY = tty - - return f.Err -} - -func (f *FakeDirectStreamingRuntime) AttachContainer(containerID ContainerID, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { - f.Lock() - defer f.Unlock() - - f.CalledFunctions = append(f.CalledFunctions, "AttachContainer") - f.Args.ContainerID = containerID - f.Args.Stdin = stdin - f.Args.Stdout = stdout - f.Args.Stderr = stderr - f.Args.TTY = tty - - return f.Err -} - func (f *FakeRuntime) GetContainerLogs(pod *v1.Pod, containerID ContainerID, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) (err error) { f.Lock() defer f.Unlock() @@ -394,18 +343,6 @@ func (f *FakeRuntime) RemoveImage(image ImageSpec) error { return f.Err } -func (f *FakeDirectStreamingRuntime) PortForward(pod *Pod, port int32, stream io.ReadWriteCloser) error { - f.Lock() - defer f.Unlock() - - f.CalledFunctions = append(f.CalledFunctions, "PortForward") - f.Args.Pod = pod - f.Args.Port = port - f.Args.Stream = stream - - return f.Err -} - func (f *FakeRuntime) GetNetNS(containerID ContainerID) (string, error) { f.Lock() defer f.Unlock() @@ -455,7 +392,7 @@ func (f *FakeRuntime) ImageStats() (*ImageStats, error) { return nil, f.Err } -func (f *FakeIndirectStreamingRuntime) GetExec(id ContainerID, cmd []string, stdin, stdout, stderr, tty bool) (*url.URL, error) { +func (f *FakeStreamingRuntime) GetExec(id ContainerID, cmd []string, stdin, stdout, stderr, tty bool) (*url.URL, error) { f.Lock() defer f.Unlock() @@ -463,7 +400,7 @@ func (f *FakeIndirectStreamingRuntime) GetExec(id ContainerID, cmd []string, std return &url.URL{Host: FakeHost}, f.Err } -func (f *FakeIndirectStreamingRuntime) GetAttach(id ContainerID, stdin, stdout, stderr, tty bool) (*url.URL, error) { +func (f *FakeStreamingRuntime) GetAttach(id ContainerID, stdin, stdout, stderr, tty bool) (*url.URL, error) { f.Lock() defer f.Unlock() @@ -471,7 +408,7 @@ func (f *FakeIndirectStreamingRuntime) GetAttach(id ContainerID, stdin, stdout, return &url.URL{Host: FakeHost}, f.Err } -func (f *FakeIndirectStreamingRuntime) GetPortForward(podName, podNamespace string, podUID types.UID, ports []int32) (*url.URL, error) { +func (f *FakeStreamingRuntime) GetPortForward(podName, podNamespace string, podUID types.UID, ports []int32) (*url.URL, error) { f.Lock() defer f.Unlock() diff --git a/pkg/kubelet/dockershim/docker_service.go b/pkg/kubelet/dockershim/docker_service.go index 5399094cc3c..2a7cce57f0d 100644 --- a/pkg/kubelet/dockershim/docker_service.go +++ b/pkg/kubelet/dockershim/docker_service.go @@ -85,7 +85,7 @@ const ( type CRIService interface { runtimeapi.RuntimeServiceServer runtimeapi.ImageServiceServer - Start() error + Start(<-chan struct{}) error } // DockerService is an interface that embeds the new RuntimeService and @@ -188,7 +188,8 @@ func NewDockerClientFromConfig(config *ClientConfig) libdocker.Interface { // NOTE: Anything passed to DockerService should be eventually handled in another way when we switch to running the shim as a different process. func NewDockerService(config *ClientConfig, podSandboxImage string, streamingConfig *streaming.Config, - pluginSettings *NetworkPluginSettings, cgroupsName string, kubeCgroupDriver string, dockershimRootDir string, disableSharedPID bool) (DockerService, error) { + pluginSettings *NetworkPluginSettings, cgroupsName string, kubeCgroupDriver string, dockershimRootDir string, + disableSharedPID, startLocalStreamingServer bool) (DockerService, error) { client := NewDockerClientFromConfig(config) @@ -207,10 +208,11 @@ func NewDockerService(config *ClientConfig, podSandboxImage string, streamingCon client: client, execHandler: &NativeExecHandler{}, }, - containerManager: cm.NewContainerManager(cgroupsName, client), - checkpointManager: checkpointManager, - disableSharedPID: disableSharedPID, - networkReady: make(map[string]bool), + containerManager: cm.NewContainerManager(cgroupsName, client), + checkpointManager: checkpointManager, + disableSharedPID: disableSharedPID, + startLocalStreamingServer: startLocalStreamingServer, + networkReady: make(map[string]bool), } // check docker version compatibility. @@ -307,6 +309,9 @@ type dockerService struct { // See proposals/pod-pid-namespace.md for details. // TODO: Remove once the escape hatch is no longer used (https://issues.k8s.io/41938) disableSharedPID bool + // startLocalStreamingServer indicates whether dockershim should start a + // streaming server on localhost. + startLocalStreamingServer bool } // TODO: handle context. @@ -395,13 +400,25 @@ func (ds *dockerService) GetPodPortMappings(podSandboxID string) ([]*hostport.Po } // Start initializes and starts components in dockerService. -func (ds *dockerService) Start() error { +func (ds *dockerService) Start(stopCh <-chan struct{}) error { // Initialize the legacy cleanup flag. + if ds.startLocalStreamingServer { + go func() { + <-stopCh + if err := ds.streamingServer.Stop(); err != nil { + glog.Errorf("Failed to stop streaming server: %v", err) + } + }() + go func() { + if err := ds.streamingServer.Start(true); err != nil && err != http.ErrServerClosed { + glog.Fatalf("Failed to start streaming server: %v", err) + } + }() + } return ds.containerManager.Start() } // Status returns the status of the runtime. -// TODO(random-liu): Set network condition accordingly here. func (ds *dockerService) Status(_ context.Context, r *runtimeapi.StatusRequest) (*runtimeapi.StatusResponse, error) { runtimeReady := &runtimeapi.RuntimeCondition{ Type: runtimeapi.RuntimeReady, diff --git a/pkg/kubelet/dockershim/remote/docker_server.go b/pkg/kubelet/dockershim/remote/docker_server.go index 1ac7560d41b..5e8967a8d5f 100644 --- a/pkg/kubelet/dockershim/remote/docker_server.go +++ b/pkg/kubelet/dockershim/remote/docker_server.go @@ -51,7 +51,7 @@ func NewDockerServer(endpoint string, s dockershim.CRIService) *DockerServer { // Start starts the dockershim grpc server. func (s *DockerServer) Start(stopCh <-chan struct{}) error { // Start the internal service. - if err := s.service.Start(); err != nil { + if err := s.service.Start(stopCh); err != nil { glog.Errorf("Unable to start docker service") return err } diff --git a/pkg/kubelet/kubelet.go b/pkg/kubelet/kubelet.go index a5459480eeb..8567f34e9d4 100644 --- a/pkg/kubelet/kubelet.go +++ b/pkg/kubelet/kubelet.go @@ -516,21 +516,22 @@ func NewMainKubelet(kubeCfg *kubeletconfiginternal.KubeletConfiguration, nodeRef: nodeRef, nodeLabels: nodeLabels, nodeStatusUpdateFrequency: kubeCfg.NodeStatusUpdateFrequency.Duration, - os: kubeDeps.OSInterface, - oomWatcher: oomWatcher, - cgroupsPerQOS: kubeCfg.CgroupsPerQOS, - cgroupRoot: kubeCfg.CgroupRoot, - mounter: kubeDeps.Mounter, - writer: kubeDeps.Writer, - maxPods: int(kubeCfg.MaxPods), - podsPerCore: int(kubeCfg.PodsPerCore), - syncLoopMonitor: atomic.Value{}, - daemonEndpoints: daemonEndpoints, - containerManager: kubeDeps.ContainerManager, - containerRuntimeName: containerRuntime, - nodeIP: parsedNodeIP, - nodeIPValidator: validateNodeIP, - clock: clock.RealClock{}, + os: kubeDeps.OSInterface, + oomWatcher: oomWatcher, + cgroupsPerQOS: kubeCfg.CgroupsPerQOS, + cgroupRoot: kubeCfg.CgroupRoot, + mounter: kubeDeps.Mounter, + writer: kubeDeps.Writer, + maxPods: int(kubeCfg.MaxPods), + podsPerCore: int(kubeCfg.PodsPerCore), + syncLoopMonitor: atomic.Value{}, + daemonEndpoints: daemonEndpoints, + containerManager: kubeDeps.ContainerManager, + containerRuntimeName: containerRuntime, + redirectContainerStreaming: crOptions.RedirectContainerStreaming, + nodeIP: parsedNodeIP, + nodeIPValidator: validateNodeIP, + clock: clock.RealClock{}, enableControllerAttachDetach: kubeCfg.EnableControllerAttachDetach, iptClient: utilipt.New(utilexec.New(), utildbus.New(), utilipt.ProtocolIpv4), makeIPTablesUtilChains: kubeCfg.MakeIPTablesUtilChains, @@ -610,16 +611,16 @@ func NewMainKubelet(kubeCfg *kubeletconfiginternal.KubeletConfiguration, switch containerRuntime { case kubetypes.DockerContainerRuntime: // Create and start the CRI shim running as a grpc server. - streamingConfig := getStreamingConfig(kubeCfg, kubeDeps) + streamingConfig := getStreamingConfig(kubeCfg, kubeDeps, crOptions) ds, err := dockershim.NewDockerService(kubeDeps.DockerClientConfig, crOptions.PodSandboxImage, streamingConfig, &pluginSettings, runtimeCgroups, kubeCfg.CgroupDriver, crOptions.DockershimRootDirectory, - crOptions.DockerDisableSharedPID) + crOptions.DockerDisableSharedPID, !crOptions.RedirectContainerStreaming) if err != nil { return nil, err } - // For now, the CRI shim redirects the streaming requests to the - // kubelet, which handles the requests using DockerService.. - klet.criHandler = ds + if crOptions.RedirectContainerStreaming { + klet.criHandler = ds + } // The unix socket for kubelet <-> dockershim communication. glog.V(5).Infof("RemoteRuntimeEndpoint: %q, RemoteImageEndpoint: %q", @@ -675,6 +676,7 @@ func NewMainKubelet(kubeCfg *kubeletconfiginternal.KubeletConfiguration, return nil, err } klet.containerRuntime = runtime + klet.streamingRuntime = runtime klet.runner = runtime if cadvisor.UsingLegacyCadvisorStats(containerRuntime, remoteRuntimeEndpoint) { @@ -1005,9 +1007,15 @@ type Kubelet struct { // The name of the container runtime containerRuntimeName string + // redirectContainerStreaming enables container streaming redirect. + redirectContainerStreaming bool + // Container runtime. containerRuntime kubecontainer.Runtime + // Streaming runtime handles container streaming. + streamingRuntime kubecontainer.StreamingRuntime + // Container runtime service (needed by container runtime Start()). // TODO(CD): try to make this available without holding a reference in this // struct. For example, by adding a getter to generic runtime. @@ -2112,11 +2120,6 @@ func (kl *Kubelet) BirthCry() { kl.recorder.Eventf(kl.nodeRef, v1.EventTypeNormal, events.StartingKubelet, "Starting kubelet.") } -// StreamingConnectionIdleTimeout returns the timeout for streaming connections to the HTTP server. -func (kl *Kubelet) StreamingConnectionIdleTimeout() time.Duration { - return kl.streamingConnectionIdleTimeout -} - // ResyncInterval returns the interval used for periodic syncs. func (kl *Kubelet) ResyncInterval() time.Duration { return kl.resyncInterval @@ -2124,12 +2127,12 @@ func (kl *Kubelet) ResyncInterval() time.Duration { // ListenAndServe runs the kubelet HTTP server. func (kl *Kubelet) ListenAndServe(address net.IP, port uint, tlsOptions *server.TLSOptions, auth server.AuthInterface, enableDebuggingHandlers, enableContentionProfiling bool) { - server.ListenAndServeKubeletServer(kl, kl.resourceAnalyzer, address, port, tlsOptions, auth, enableDebuggingHandlers, enableContentionProfiling, kl.containerRuntime, kl.criHandler) + server.ListenAndServeKubeletServer(kl, kl.resourceAnalyzer, address, port, tlsOptions, auth, enableDebuggingHandlers, enableContentionProfiling, kl.redirectContainerStreaming, kl.criHandler) } // ListenAndServeReadOnly runs the kubelet HTTP server in read-only mode. func (kl *Kubelet) ListenAndServeReadOnly(address net.IP, port uint) { - server.ListenAndServeKubeletReadOnlyServer(kl, kl.resourceAnalyzer, address, port, kl.containerRuntime) + server.ListenAndServeKubeletReadOnlyServer(kl, kl.resourceAnalyzer, address, port) } // Delete the eligible dead container instances in a pod. Depending on the configuration, the latest dead containers may be kept around. @@ -2153,19 +2156,23 @@ func isSyncPodWorthy(event *pleg.PodLifecycleEvent) bool { } // Gets the streaming server configuration to use with in-process CRI shims. -func getStreamingConfig(kubeCfg *kubeletconfiginternal.KubeletConfiguration, kubeDeps *Dependencies) *streaming.Config { +func getStreamingConfig(kubeCfg *kubeletconfiginternal.KubeletConfiguration, kubeDeps *Dependencies, crOptions *config.ContainerRuntimeOptions) *streaming.Config { config := &streaming.Config{ - // Use a relative redirect (no scheme or host). - BaseURL: &url.URL{ - Path: "/cri/", - }, StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration, StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout, SupportedRemoteCommandProtocols: streaming.DefaultConfig.SupportedRemoteCommandProtocols, SupportedPortForwardProtocols: streaming.DefaultConfig.SupportedPortForwardProtocols, } - if kubeDeps.TLSOptions != nil { - config.TLSConfig = kubeDeps.TLSOptions.Config + if !crOptions.RedirectContainerStreaming { + config.Addr = net.JoinHostPort("localhost", "0") + } else { + // Use a relative redirect (no scheme or host). + config.BaseURL = &url.URL{ + Path: "/cri/", + } + if kubeDeps.TLSOptions != nil { + config.TLSConfig = kubeDeps.TLSOptions.Config + } } return config } diff --git a/pkg/kubelet/kubelet_pods.go b/pkg/kubelet/kubelet_pods.go index 50c28c0ebec..d0307986696 100644 --- a/pkg/kubelet/kubelet_pods.go +++ b/pkg/kubelet/kubelet_pods.go @@ -30,7 +30,6 @@ import ( "sort" "strings" "sync" - "time" "github.com/golang/glog" "k8s.io/api/core/v1" @@ -41,7 +40,6 @@ import ( "k8s.io/apimachinery/pkg/util/sets" utilvalidation "k8s.io/apimachinery/pkg/util/validation" utilfeature "k8s.io/apiserver/pkg/util/feature" - "k8s.io/client-go/tools/remotecommand" podutil "k8s.io/kubernetes/pkg/api/v1/pod" "k8s.io/kubernetes/pkg/api/v1/resource" podshelper "k8s.io/kubernetes/pkg/apis/core/pods" @@ -1595,142 +1593,60 @@ func (kl *Kubelet) RunInContainer(podFullName string, podUID types.UID, containe return kl.runner.RunInContainer(container.ID, cmd, 0) } -// ExecInContainer executes a command in a container, connecting the supplied -// stdin/stdout/stderr to the command's IO streams. -func (kl *Kubelet) ExecInContainer(podFullName string, podUID types.UID, containerName string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error { - streamingRuntime, ok := kl.containerRuntime.(kubecontainer.DirectStreamingRuntime) - if !ok { - return fmt.Errorf("streaming methods not supported by runtime") - } - - container, err := kl.findContainer(podFullName, podUID, containerName) - if err != nil { - return err - } - if container == nil { - return fmt.Errorf("container not found (%q)", containerName) - } - return streamingRuntime.ExecInContainer(container.ID, cmd, stdin, stdout, stderr, tty, resize, timeout) -} - -// AttachContainer uses the container runtime to attach the given streams to -// the given container. -func (kl *Kubelet) AttachContainer(podFullName string, podUID types.UID, containerName string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { - streamingRuntime, ok := kl.containerRuntime.(kubecontainer.DirectStreamingRuntime) - if !ok { - return fmt.Errorf("streaming methods not supported by runtime") - } - - container, err := kl.findContainer(podFullName, podUID, containerName) - if err != nil { - return err - } - if container == nil { - return fmt.Errorf("container not found (%q)", containerName) - } - return streamingRuntime.AttachContainer(container.ID, stdin, stdout, stderr, tty, resize) -} - -// PortForward connects to the pod's port and copies data between the port -// and the stream. -func (kl *Kubelet) PortForward(podFullName string, podUID types.UID, port int32, stream io.ReadWriteCloser) error { - streamingRuntime, ok := kl.containerRuntime.(kubecontainer.DirectStreamingRuntime) - if !ok { - return fmt.Errorf("streaming methods not supported by runtime") - } - - pods, err := kl.containerRuntime.GetPods(false) - if err != nil { - return err - } - // Resolve and type convert back again. - // We need the static pod UID but the kubecontainer API works with types.UID. - podUID = types.UID(kl.podManager.TranslatePodUID(podUID)) - pod := kubecontainer.Pods(pods).FindPod(podFullName, podUID) - if pod.IsEmpty() { - return fmt.Errorf("pod not found (%q)", podFullName) - } - return streamingRuntime.PortForward(&pod, port, stream) -} - // GetExec gets the URL the exec will be served from, or nil if the Kubelet will serve it. func (kl *Kubelet) GetExec(podFullName string, podUID types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) (*url.URL, error) { - switch streamingRuntime := kl.containerRuntime.(type) { - case kubecontainer.DirectStreamingRuntime: - // Kubelet will serve the exec directly. - return nil, nil - case kubecontainer.IndirectStreamingRuntime: - container, err := kl.findContainer(podFullName, podUID, containerName) - if err != nil { - return nil, err - } - if container == nil { - return nil, fmt.Errorf("container not found (%q)", containerName) - } - return streamingRuntime.GetExec(container.ID, cmd, streamOpts.Stdin, streamOpts.Stdout, streamOpts.Stderr, streamOpts.TTY) - default: - return nil, fmt.Errorf("container runtime does not support exec") + container, err := kl.findContainer(podFullName, podUID, containerName) + if err != nil { + return nil, err } + if container == nil { + return nil, fmt.Errorf("container not found (%q)", containerName) + } + return kl.streamingRuntime.GetExec(container.ID, cmd, streamOpts.Stdin, streamOpts.Stdout, streamOpts.Stderr, streamOpts.TTY) } // GetAttach gets the URL the attach will be served from, or nil if the Kubelet will serve it. func (kl *Kubelet) GetAttach(podFullName string, podUID types.UID, containerName string, streamOpts remotecommandserver.Options) (*url.URL, error) { - switch streamingRuntime := kl.containerRuntime.(type) { - case kubecontainer.DirectStreamingRuntime: - // Kubelet will serve the attach directly. - return nil, nil - case kubecontainer.IndirectStreamingRuntime: - container, err := kl.findContainer(podFullName, podUID, containerName) - if err != nil { - return nil, err - } - if container == nil { - return nil, fmt.Errorf("container %s not found in pod %s", containerName, podFullName) - } - - // The TTY setting for attach must match the TTY setting in the initial container configuration, - // since whether the process is running in a TTY cannot be changed after it has started. We - // need the api.Pod to get the TTY status. - pod, found := kl.GetPodByFullName(podFullName) - if !found || (string(podUID) != "" && pod.UID != podUID) { - return nil, fmt.Errorf("pod %s not found", podFullName) - } - containerSpec := kubecontainer.GetContainerSpec(pod, containerName) - if containerSpec == nil { - return nil, fmt.Errorf("container %s not found in pod %s", containerName, podFullName) - } - tty := containerSpec.TTY - - return streamingRuntime.GetAttach(container.ID, streamOpts.Stdin, streamOpts.Stdout, streamOpts.Stderr, tty) - default: - return nil, fmt.Errorf("container runtime does not support attach") + container, err := kl.findContainer(podFullName, podUID, containerName) + if err != nil { + return nil, err } + if container == nil { + return nil, fmt.Errorf("container %s not found in pod %s", containerName, podFullName) + } + + // The TTY setting for attach must match the TTY setting in the initial container configuration, + // since whether the process is running in a TTY cannot be changed after it has started. We + // need the api.Pod to get the TTY status. + pod, found := kl.GetPodByFullName(podFullName) + if !found || (string(podUID) != "" && pod.UID != podUID) { + return nil, fmt.Errorf("pod %s not found", podFullName) + } + containerSpec := kubecontainer.GetContainerSpec(pod, containerName) + if containerSpec == nil { + return nil, fmt.Errorf("container %s not found in pod %s", containerName, podFullName) + } + tty := containerSpec.TTY + + return kl.streamingRuntime.GetAttach(container.ID, streamOpts.Stdin, streamOpts.Stdout, streamOpts.Stderr, tty) } // GetPortForward gets the URL the port-forward will be served from, or nil if the Kubelet will serve it. func (kl *Kubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) { - switch streamingRuntime := kl.containerRuntime.(type) { - case kubecontainer.DirectStreamingRuntime: - // Kubelet will serve the attach directly. - return nil, nil - case kubecontainer.IndirectStreamingRuntime: - pods, err := kl.containerRuntime.GetPods(false) - if err != nil { - return nil, err - } - // Resolve and type convert back again. - // We need the static pod UID but the kubecontainer API works with types.UID. - podUID = types.UID(kl.podManager.TranslatePodUID(podUID)) - podFullName := kubecontainer.BuildPodFullName(podName, podNamespace) - pod := kubecontainer.Pods(pods).FindPod(podFullName, podUID) - if pod.IsEmpty() { - return nil, fmt.Errorf("pod not found (%q)", podFullName) - } - - return streamingRuntime.GetPortForward(podName, podNamespace, podUID, portForwardOpts.Ports) - default: - return nil, fmt.Errorf("container runtime does not support port-forward") + pods, err := kl.containerRuntime.GetPods(false) + if err != nil { + return nil, err } + // Resolve and type convert back again. + // We need the static pod UID but the kubecontainer API works with types.UID. + podUID = types.UID(kl.podManager.TranslatePodUID(podUID)) + podFullName := kubecontainer.BuildPodFullName(podName, podNamespace) + pod := kubecontainer.Pods(pods).FindPod(podFullName, podUID) + if pod.IsEmpty() { + return nil, fmt.Errorf("pod not found (%q)", podFullName) + } + + return kl.streamingRuntime.GetPortForward(podName, podNamespace, podUID, portForwardOpts.Ports) } // cleanupOrphanedPodCgroups removes cgroups that should no longer exist. diff --git a/pkg/kubelet/kubelet_pods_test.go b/pkg/kubelet/kubelet_pods_test.go index 6a26e20550e..f4fc43b5fe1 100644 --- a/pkg/kubelet/kubelet_pods_test.go +++ b/pkg/kubelet/kubelet_pods_test.go @@ -17,7 +17,6 @@ limitations under the License. package kubelet import ( - "bytes" "errors" "fmt" "io/ioutil" @@ -2095,7 +2094,7 @@ func (f *fakeReadWriteCloser) Close() error { return nil } -func TestExec(t *testing.T) { +func TestGetExec(t *testing.T) { const ( podName = "podFoo" podNamespace = "nsFoo" @@ -2106,9 +2105,6 @@ func TestExec(t *testing.T) { var ( podFullName = kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, podName, podNamespace)) command = []string{"ls"} - stdin = &bytes.Buffer{} - stdout = &fakeReadWriteCloser{} - stderr = &fakeReadWriteCloser{} ) testcases := []struct { @@ -2149,66 +2145,28 @@ func TestExec(t *testing.T) { }}, } - { // No streaming case - description := "no streaming - " + tc.description - redirect, err := kubelet.GetExec(tc.podFullName, podUID, tc.container, command, remotecommand.Options{}) - assert.Error(t, err, description) - assert.Nil(t, redirect, description) + description := "streaming - " + tc.description + fakeRuntime := &containertest.FakeStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} + kubelet.containerRuntime = fakeRuntime + kubelet.streamingRuntime = fakeRuntime - err = kubelet.ExecInContainer(tc.podFullName, podUID, tc.container, command, stdin, stdout, stderr, tty, nil, 0) + redirect, err := kubelet.GetExec(tc.podFullName, podUID, tc.container, command, remotecommand.Options{}) + if tc.expectError { assert.Error(t, err, description) - } - { // Direct streaming case - description := "direct streaming - " + tc.description - fakeRuntime := &containertest.FakeDirectStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} - kubelet.containerRuntime = fakeRuntime - - redirect, err := kubelet.GetExec(tc.podFullName, podUID, tc.container, command, remotecommand.Options{}) + } else { assert.NoError(t, err, description) - assert.Nil(t, redirect, description) - - err = kubelet.ExecInContainer(tc.podFullName, podUID, tc.container, command, stdin, stdout, stderr, tty, nil, 0) - if tc.expectError { - assert.Error(t, err, description) - } else { - assert.NoError(t, err, description) - assert.Equal(t, fakeRuntime.Args.ContainerID.ID, containerID, description+": ID") - assert.Equal(t, fakeRuntime.Args.Cmd, command, description+": Command") - assert.Equal(t, fakeRuntime.Args.Stdin, stdin, description+": Stdin") - assert.Equal(t, fakeRuntime.Args.Stdout, stdout, description+": Stdout") - assert.Equal(t, fakeRuntime.Args.Stderr, stderr, description+": Stderr") - assert.Equal(t, fakeRuntime.Args.TTY, tty, description+": TTY") - } - } - { // Indirect streaming case - description := "indirect streaming - " + tc.description - fakeRuntime := &containertest.FakeIndirectStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} - kubelet.containerRuntime = fakeRuntime - - redirect, err := kubelet.GetExec(tc.podFullName, podUID, tc.container, command, remotecommand.Options{}) - if tc.expectError { - assert.Error(t, err, description) - } else { - assert.NoError(t, err, description) - assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect") - } - - err = kubelet.ExecInContainer(tc.podFullName, podUID, tc.container, command, stdin, stdout, stderr, tty, nil, 0) - assert.Error(t, err, description) + assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect") } } } -func TestPortForward(t *testing.T) { +func TestGetPortForward(t *testing.T) { const ( podName = "podFoo" podNamespace = "nsFoo" podUID types.UID = "12345678" port int32 = 5000 ) - var ( - stream = &fakeReadWriteCloser{} - ) testcases := []struct { description string @@ -2240,50 +2198,17 @@ func TestPortForward(t *testing.T) { }}, } - podFullName := kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, tc.podName, podNamespace)) - { // No streaming case - description := "no streaming - " + tc.description - redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{}) - assert.Error(t, err, description) - assert.Nil(t, redirect, description) + description := "streaming - " + tc.description + fakeRuntime := &containertest.FakeStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} + kubelet.containerRuntime = fakeRuntime + kubelet.streamingRuntime = fakeRuntime - err = kubelet.PortForward(podFullName, podUID, port, stream) + redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{}) + if tc.expectError { assert.Error(t, err, description) - } - { // Direct streaming case - description := "direct streaming - " + tc.description - fakeRuntime := &containertest.FakeDirectStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} - kubelet.containerRuntime = fakeRuntime - - redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{}) + } else { assert.NoError(t, err, description) - assert.Nil(t, redirect, description) - - err = kubelet.PortForward(podFullName, podUID, port, stream) - if tc.expectError { - assert.Error(t, err, description) - } else { - assert.NoError(t, err, description) - require.Equal(t, fakeRuntime.Args.Pod.ID, podUID, description+": Pod UID") - require.Equal(t, fakeRuntime.Args.Port, port, description+": Port") - require.Equal(t, fakeRuntime.Args.Stream, stream, description+": stream") - } - } - { // Indirect streaming case - description := "indirect streaming - " + tc.description - fakeRuntime := &containertest.FakeIndirectStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} - kubelet.containerRuntime = fakeRuntime - - redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{}) - if tc.expectError { - assert.Error(t, err, description) - } else { - assert.NoError(t, err, description) - assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect") - } - - err = kubelet.PortForward(podFullName, podUID, port, stream) - assert.Error(t, err, description) + assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect") } } } diff --git a/pkg/kubelet/kuberuntime/kuberuntime_manager.go b/pkg/kubelet/kuberuntime/kuberuntime_manager.go index c34136b569a..df207fb4352 100644 --- a/pkg/kubelet/kuberuntime/kuberuntime_manager.go +++ b/pkg/kubelet/kuberuntime/kuberuntime_manager.go @@ -120,7 +120,7 @@ type kubeGenericRuntimeManager struct { type KubeGenericRuntime interface { kubecontainer.Runtime - kubecontainer.IndirectStreamingRuntime + kubecontainer.StreamingRuntime kubecontainer.ContainerCommandRunner } diff --git a/pkg/kubelet/server/BUILD b/pkg/kubelet/server/BUILD index c2e338a1ad8..69a71140f0f 100644 --- a/pkg/kubelet/server/BUILD +++ b/pkg/kubelet/server/BUILD @@ -37,7 +37,7 @@ go_library( "//vendor/k8s.io/apimachinery/pkg/runtime:go_default_library", "//vendor/k8s.io/apimachinery/pkg/runtime/schema:go_default_library", "//vendor/k8s.io/apimachinery/pkg/types:go_default_library", - "//vendor/k8s.io/apimachinery/pkg/util/remotecommand:go_default_library", + "//vendor/k8s.io/apimachinery/pkg/util/proxy:go_default_library", "//vendor/k8s.io/apimachinery/pkg/util/runtime:go_default_library", "//vendor/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library", "//vendor/k8s.io/apiserver/pkg/authentication/user:go_default_library", @@ -45,7 +45,6 @@ go_library( "//vendor/k8s.io/apiserver/pkg/server/healthz:go_default_library", "//vendor/k8s.io/apiserver/pkg/server/httplog:go_default_library", "//vendor/k8s.io/apiserver/pkg/util/flushwriter:go_default_library", - "//vendor/k8s.io/client-go/tools/remotecommand:go_default_library", ], ) @@ -60,13 +59,14 @@ go_test( deps = [ "//pkg/apis/core:go_default_library", "//pkg/apis/core/install:go_default_library", + "//pkg/kubelet/apis/cri/runtime/v1alpha2:go_default_library", "//pkg/kubelet/apis/stats/v1alpha1:go_default_library", "//pkg/kubelet/cm:go_default_library", "//pkg/kubelet/container:go_default_library", - "//pkg/kubelet/container/testing:go_default_library", "//pkg/kubelet/server/portforward:go_default_library", "//pkg/kubelet/server/remotecommand:go_default_library", "//pkg/kubelet/server/stats:go_default_library", + "//pkg/kubelet/server/streaming:go_default_library", "//pkg/volume:go_default_library", "//vendor/github.com/google/cadvisor/info/v1:go_default_library", "//vendor/github.com/stretchr/testify/assert:go_default_library", diff --git a/pkg/kubelet/server/server.go b/pkg/kubelet/server/server.go index fdde1fee4e8..2e5bbde211f 100644 --- a/pkg/kubelet/server/server.go +++ b/pkg/kubelet/server/server.go @@ -42,14 +42,13 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" - remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/proxy" utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authorization/authorizer" "k8s.io/apiserver/pkg/server/healthz" "k8s.io/apiserver/pkg/server/httplog" "k8s.io/apiserver/pkg/util/flushwriter" - "k8s.io/client-go/tools/remotecommand" "k8s.io/kubernetes/pkg/api/legacyscheme" api "k8s.io/kubernetes/pkg/apis/core" "k8s.io/kubernetes/pkg/apis/core/v1/validation" @@ -74,11 +73,11 @@ const ( // Server is a http.Handler which exposes kubelet functionality over HTTP. type Server struct { - auth AuthInterface - host HostInterface - restfulCont containerInterface - resourceAnalyzer stats.ResourceAnalyzer - runtime kubecontainer.Runtime + auth AuthInterface + host HostInterface + restfulCont containerInterface + resourceAnalyzer stats.ResourceAnalyzer + redirectContainerStreaming bool } type TLSOptions struct { @@ -124,11 +123,11 @@ func ListenAndServeKubeletServer( tlsOptions *TLSOptions, auth AuthInterface, enableDebuggingHandlers, - enableContentionProfiling bool, - runtime kubecontainer.Runtime, + enableContentionProfiling, + redirectContainerStreaming bool, criHandler http.Handler) { glog.Infof("Starting to listen on %s:%d", address, port) - handler := NewServer(host, resourceAnalyzer, auth, enableDebuggingHandlers, enableContentionProfiling, runtime, criHandler) + handler := NewServer(host, resourceAnalyzer, auth, enableDebuggingHandlers, enableContentionProfiling, redirectContainerStreaming, criHandler) s := &http.Server{ Addr: net.JoinHostPort(address.String(), strconv.FormatUint(uint64(port), 10)), Handler: &handler, @@ -146,9 +145,9 @@ func ListenAndServeKubeletServer( } // ListenAndServeKubeletReadOnlyServer initializes a server to respond to HTTP network requests on the Kubelet. -func ListenAndServeKubeletReadOnlyServer(host HostInterface, resourceAnalyzer stats.ResourceAnalyzer, address net.IP, port uint, runtime kubecontainer.Runtime) { +func ListenAndServeKubeletReadOnlyServer(host HostInterface, resourceAnalyzer stats.ResourceAnalyzer, address net.IP, port uint) { glog.V(1).Infof("Starting to listen read-only on %s:%d", address, port) - s := NewServer(host, resourceAnalyzer, nil, false, false, runtime, nil) + s := NewServer(host, resourceAnalyzer, nil, false, false, false, nil) server := &http.Server{ Addr: net.JoinHostPort(address.String(), strconv.FormatUint(uint64(port), 10)), @@ -173,12 +172,8 @@ type HostInterface interface { GetCachedMachineInfo() (*cadvisorapi.MachineInfo, error) GetRunningPods() ([]*v1.Pod, error) RunInContainer(name string, uid types.UID, container string, cmd []string) ([]byte, error) - ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error - AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error GetKubeletContainerLogs(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error ServeLogs(w http.ResponseWriter, req *http.Request) - PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error - StreamingConnectionIdleTimeout() time.Duration ResyncInterval() time.Duration GetHostname() string LatestLoopEntryTime() time.Time @@ -193,15 +188,15 @@ func NewServer( resourceAnalyzer stats.ResourceAnalyzer, auth AuthInterface, enableDebuggingHandlers, - enableContentionProfiling bool, - runtime kubecontainer.Runtime, + enableContentionProfiling, + redirectContainerStreaming bool, criHandler http.Handler) Server { server := Server{ - host: host, - resourceAnalyzer: resourceAnalyzer, - auth: auth, - restfulCont: &filteringContainer{Container: restful.NewContainer()}, - runtime: runtime, + host: host, + resourceAnalyzer: resourceAnalyzer, + auth: auth, + restfulCont: &filteringContainer{Container: restful.NewContainer()}, + redirectContainerStreaming: redirectContainerStreaming, } if auth != nil { server.InstallAuthFilter() @@ -627,6 +622,15 @@ func getPortForwardRequestParams(req *restful.Request) portForwardRequestParams } } +type responder struct { + errorMessage string +} + +func (r *responder) Error(w http.ResponseWriter, req *http.Request, err error) { + glog.Errorf("Error while proxying request: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) +} + // getAttach handles requests to attach to a container. func (s *Server) getAttach(request *restful.Request, response *restful.Response) { params := getExecRequestParams(request) @@ -643,26 +647,18 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response) } podFullName := kubecontainer.GetPodFullName(pod) - redirect, err := s.host.GetAttach(podFullName, params.podUID, params.containerName, *streamOpts) + url, err := s.host.GetAttach(podFullName, params.podUID, params.containerName, *streamOpts) if err != nil { streaming.WriteError(err, response.ResponseWriter) return } - if redirect != nil { - http.Redirect(response.ResponseWriter, request.Request, redirect.String(), http.StatusFound) + + if s.redirectContainerStreaming { + http.Redirect(response.ResponseWriter, request.Request, url.String(), http.StatusFound) return } - - remotecommandserver.ServeAttach(response.ResponseWriter, - request.Request, - s.host, - podFullName, - params.podUID, - params.containerName, - streamOpts, - s.host.StreamingConnectionIdleTimeout(), - remotecommandconsts.DefaultStreamCreationTimeout, - remotecommandconsts.SupportedStreamingProtocols) + handler := proxy.NewUpgradeAwareHandler(url, nil /*transport*/, false /*wrapTransport*/, false /*upgradeRequired*/, &responder{}) + handler.ServeHTTP(response.ResponseWriter, request.Request) } // getExec handles requests to run a command inside a container. @@ -681,27 +677,17 @@ func (s *Server) getExec(request *restful.Request, response *restful.Response) { } podFullName := kubecontainer.GetPodFullName(pod) - redirect, err := s.host.GetExec(podFullName, params.podUID, params.containerName, params.cmd, *streamOpts) + url, err := s.host.GetExec(podFullName, params.podUID, params.containerName, params.cmd, *streamOpts) if err != nil { streaming.WriteError(err, response.ResponseWriter) return } - if redirect != nil { - http.Redirect(response.ResponseWriter, request.Request, redirect.String(), http.StatusFound) + if s.redirectContainerStreaming { + http.Redirect(response.ResponseWriter, request.Request, url.String(), http.StatusFound) return } - - remotecommandserver.ServeExec(response.ResponseWriter, - request.Request, - s.host, - podFullName, - params.podUID, - params.containerName, - params.cmd, - streamOpts, - s.host.StreamingConnectionIdleTimeout(), - remotecommandconsts.DefaultStreamCreationTimeout, - remotecommandconsts.SupportedStreamingProtocols) + handler := proxy.NewUpgradeAwareHandler(url, nil /*transport*/, false /*wrapTransport*/, false /*upgradeRequired*/, &responder{}) + handler.ServeHTTP(response.ResponseWriter, request.Request) } // getRun handles requests to run a command inside a container. @@ -758,25 +744,17 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp return } - redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID, *portForwardOptions) + url, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID, *portForwardOptions) if err != nil { streaming.WriteError(err, response.ResponseWriter) return } - if redirect != nil { - http.Redirect(response.ResponseWriter, request.Request, redirect.String(), http.StatusFound) + if s.redirectContainerStreaming { + http.Redirect(response.ResponseWriter, request.Request, url.String(), http.StatusFound) return } - - portforward.ServePortForward(response.ResponseWriter, - request.Request, - s.host, - kubecontainer.GetPodFullName(pod), - params.podUID, - portForwardOptions, - s.host.StreamingConnectionIdleTimeout(), - remotecommandconsts.DefaultStreamCreationTimeout, - portforward.SupportedProtocols) + handler := proxy.NewUpgradeAwareHandler(url, nil /*transport*/, false /*wrapTransport*/, false /*upgradeRequired*/, &responder{}) + handler.ServeHTTP(response.ResponseWriter, request.Request) } // ServeHTTP responds to HTTP requests on the Kubelet. diff --git a/pkg/kubelet/server/server_test.go b/pkg/kubelet/server/server_test.go index cdb978078cf..e84bec4d649 100644 --- a/pkg/kubelet/server/server_test.go +++ b/pkg/kubelet/server/server_test.go @@ -48,41 +48,44 @@ import ( "k8s.io/client-go/tools/remotecommand" utiltesting "k8s.io/client-go/util/testing" api "k8s.io/kubernetes/pkg/apis/core" + runtimeapi "k8s.io/kubernetes/pkg/kubelet/apis/cri/runtime/v1alpha2" statsapi "k8s.io/kubernetes/pkg/kubelet/apis/stats/v1alpha1" // Do some initialization to decode the query parameters correctly. _ "k8s.io/kubernetes/pkg/apis/core/install" "k8s.io/kubernetes/pkg/kubelet/cm" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" - kubecontainertesting "k8s.io/kubernetes/pkg/kubelet/container/testing" "k8s.io/kubernetes/pkg/kubelet/server/portforward" remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" "k8s.io/kubernetes/pkg/kubelet/server/stats" + "k8s.io/kubernetes/pkg/kubelet/server/streaming" "k8s.io/kubernetes/pkg/volume" ) const ( - testUID = "9b01b80f-8fb4-11e4-95ab-4200af06647" + testUID = "9b01b80f-8fb4-11e4-95ab-4200af06647" + testContainerID = "container789" + testPodSandboxID = "pod0987" ) type fakeKubelet struct { - podByNameFunc func(namespace, name string) (*v1.Pod, bool) - containerInfoFunc func(podFullName string, uid types.UID, containerName string, req *cadvisorapi.ContainerInfoRequest) (*cadvisorapi.ContainerInfo, error) - rawInfoFunc func(query *cadvisorapi.ContainerInfoRequest) (map[string]*cadvisorapi.ContainerInfo, error) - machineInfoFunc func() (*cadvisorapi.MachineInfo, error) - podsFunc func() []*v1.Pod - runningPodsFunc func() ([]*v1.Pod, error) - logFunc func(w http.ResponseWriter, req *http.Request) - runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error) - execFunc func(pod string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error - attachFunc func(pod string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error - portForwardFunc func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error - containerLogsFunc func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error - streamingConnectionIdleTimeoutFunc func() time.Duration - hostnameFunc func() string - resyncInterval time.Duration - loopEntryTime time.Time - plegHealth bool - redirectURL *url.URL + podByNameFunc func(namespace, name string) (*v1.Pod, bool) + containerInfoFunc func(podFullName string, uid types.UID, containerName string, req *cadvisorapi.ContainerInfoRequest) (*cadvisorapi.ContainerInfo, error) + rawInfoFunc func(query *cadvisorapi.ContainerInfoRequest) (map[string]*cadvisorapi.ContainerInfo, error) + machineInfoFunc func() (*cadvisorapi.MachineInfo, error) + podsFunc func() []*v1.Pod + runningPodsFunc func() ([]*v1.Pod, error) + logFunc func(w http.ResponseWriter, req *http.Request) + runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error) + getExecCheck func(string, types.UID, string, []string, remotecommandserver.Options) + getAttachCheck func(string, types.UID, string, remotecommandserver.Options) + getPortForwardCheck func(string, string, types.UID, portforward.V4Options) + + containerLogsFunc func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error + hostnameFunc func() string + resyncInterval time.Duration + loopEntryTime time.Time + plegHealth bool + streamingRuntime streaming.Server } func (fk *fakeKubelet) ResyncInterval() time.Duration { @@ -137,32 +140,109 @@ func (fk *fakeKubelet) RunInContainer(podFullName string, uid types.UID, contain return fk.runFunc(podFullName, uid, containerName, cmd) } -func (fk *fakeKubelet) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error { - return fk.execFunc(name, uid, container, cmd, in, out, err, tty) +type fakeRuntime struct { + execFunc func(string, []string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error + attachFunc func(string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error + portForwardFunc func(string, int32, io.ReadWriteCloser) error } -func (fk *fakeKubelet) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { - return fk.attachFunc(name, uid, container, in, out, err, tty) +func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { + return f.execFunc(containerID, cmd, stdin, stdout, stderr, tty, resize) } -func (fk *fakeKubelet) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { - return fk.portForwardFunc(name, uid, port, stream) +func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { + return f.attachFunc(containerID, stdin, stdout, stderr, tty, resize) +} + +func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + return f.portForwardFunc(podSandboxID, port, stream) +} + +type testStreamingServer struct { + streaming.Server + fakeRuntime *fakeRuntime + testHTTPServer *httptest.Server +} + +func newTestStreamingServer(streamIdleTimeout time.Duration) (s *testStreamingServer, err error) { + s = &testStreamingServer{} + s.testHTTPServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.ServeHTTP(w, r) + })) + defer func() { + if err != nil { + s.testHTTPServer.Close() + } + }() + + testURL, err := url.Parse(s.testHTTPServer.URL) + if err != nil { + return nil, err + } + + s.fakeRuntime = &fakeRuntime{} + config := streaming.DefaultConfig + config.BaseURL = testURL + if streamIdleTimeout != 0 { + config.StreamIdleTimeout = streamIdleTimeout + } + s.Server, err = streaming.NewServer(config, s.fakeRuntime) + if err != nil { + return nil, err + } + return s, nil } func (fk *fakeKubelet) GetExec(podFullName string, podUID types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) (*url.URL, error) { - return fk.redirectURL, nil + if fk.getExecCheck != nil { + fk.getExecCheck(podFullName, podUID, containerName, cmd, streamOpts) + } + // Always use testContainerID + resp, err := fk.streamingRuntime.GetExec(&runtimeapi.ExecRequest{ + ContainerId: testContainerID, + Cmd: cmd, + Tty: streamOpts.TTY, + Stdin: streamOpts.Stdin, + Stdout: streamOpts.Stdout, + Stderr: streamOpts.Stderr, + }) + if err != nil { + return nil, err + } + return url.Parse(resp.GetUrl()) } func (fk *fakeKubelet) GetAttach(podFullName string, podUID types.UID, containerName string, streamOpts remotecommandserver.Options) (*url.URL, error) { - return fk.redirectURL, nil + if fk.getAttachCheck != nil { + fk.getAttachCheck(podFullName, podUID, containerName, streamOpts) + } + // Always use testContainerID + resp, err := fk.streamingRuntime.GetAttach(&runtimeapi.AttachRequest{ + ContainerId: testContainerID, + Tty: streamOpts.TTY, + Stdin: streamOpts.Stdin, + Stdout: streamOpts.Stdout, + Stderr: streamOpts.Stderr, + }) + if err != nil { + return nil, err + } + return url.Parse(resp.GetUrl()) } func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) { - return fk.redirectURL, nil -} - -func (fk *fakeKubelet) StreamingConnectionIdleTimeout() time.Duration { - return fk.streamingConnectionIdleTimeoutFunc() + if fk.getPortForwardCheck != nil { + fk.getPortForwardCheck(podName, podNamespace, podUID, portForwardOpts) + } + // Always use testPodSandboxID + resp, err := fk.streamingRuntime.GetPortForward(&runtimeapi.PortForwardRequest{ + PodSandboxId: testPodSandboxID, + Port: portForwardOpts.Ports, + }) + if err != nil { + return nil, err + } + return url.Parse(resp.GetUrl()) } // Unused functions @@ -199,18 +279,20 @@ func (f *fakeAuth) Authorize(a authorizer.Attributes) (authorized authorizer.Dec } type serverTestFramework struct { - serverUnderTest *Server - fakeKubelet *fakeKubelet - fakeAuth *fakeAuth - testHTTPServer *httptest.Server - criHandler *utiltesting.FakeHandler + serverUnderTest *Server + fakeKubelet *fakeKubelet + fakeAuth *fakeAuth + testHTTPServer *httptest.Server + fakeRuntime *fakeRuntime + testStreamingHTTPServer *httptest.Server + criHandler *utiltesting.FakeHandler } func newServerTest() *serverTestFramework { - return newServerTestWithDebug(true) + return newServerTestWithDebug(true, false, nil) } -func newServerTestWithDebug(enableDebugging bool) *serverTestFramework { +func newServerTestWithDebug(enableDebugging, redirectContainerStreaming bool, streamingServer streaming.Server) *serverTestFramework { fw := &serverTestFramework{} fw.fakeKubelet = &fakeKubelet{ hostnameFunc: func() string { @@ -225,7 +307,8 @@ func newServerTestWithDebug(enableDebugging bool) *serverTestFramework { }, }, true }, - plegHealth: true, + plegHealth: true, + streamingRuntime: streamingServer, } fw.fakeAuth = &fakeAuth{ authenticateFunc: func(req *http.Request) (user.Info, bool, error) { @@ -247,7 +330,7 @@ func newServerTestWithDebug(enableDebugging bool) *serverTestFramework { fw.fakeAuth, enableDebugging, false, - &kubecontainertesting.Mock{}, + redirectContainerStreaming, fw.criHandler) fw.serverUnderTest = &server fw.testHTTPServer = httptest.NewServer(fw.serverUnderTest) @@ -1070,13 +1153,12 @@ func TestContainerLogsWithFollow(t *testing.T) { } func TestServeExecInContainerIdleTimeout(t *testing.T) { - fw := newServerTest() + ss, err := newTestStreamingServer(100 * time.Millisecond) + require.NoError(t, err) + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, false, ss) defer fw.testHTTPServer.Close() - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 100 * time.Millisecond - } - podNamespace := "other" podName := "foo" expectedContainerName := "baz" @@ -1108,280 +1190,221 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) { } func testExecAttach(t *testing.T, verb string) { - tests := []struct { + tests := map[string]struct { stdin bool stdout bool stderr bool tty bool responseStatusCode int uid bool - responseLocation string + redirect bool }{ - {responseStatusCode: http.StatusBadRequest}, - {stdin: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, responseStatusCode: http.StatusFound, responseLocation: "http://localhost:12345/" + verb}, + "no input or output": {responseStatusCode: http.StatusBadRequest}, + "stdin": {stdin: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdout": {stdout: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stderr": {stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdout and stderr": {stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdout stderr and tty": {stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdin stdout and stderr": {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdin stdout stderr with uid": {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols, uid: true}, + "stdout with redirect": {stdout: true, responseStatusCode: http.StatusFound, redirect: true}, } - for i, test := range tests { - fw := newServerTest() - defer fw.testHTTPServer.Close() - - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } - - if test.responseLocation != "" { - var err error - fw.fakeKubelet.redirectURL, err = url.Parse(test.responseLocation) + for desc, test := range tests { + test := test + t.Run(desc, func(t *testing.T) { + ss, err := newTestStreamingServer(0) require.NoError(t, err) - } + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, test.redirect, ss) + defer fw.testHTTPServer.Close() + fmt.Println(desc) - podNamespace := "other" - podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - expectedContainerName := "baz" - expectedCommand := "ls -a" - expectedStdin := "stdin" - expectedStdout := "stdout" - expectedStderr := "stderr" - done := make(chan struct{}) - clientStdoutReadDone := make(chan struct{}) - clientStderrReadDone := make(chan struct{}) - execInvoked := false - attachInvoked := false + podNamespace := "other" + podName := "foo" + expectedPodName := getPodName(podName, podNamespace) + expectedContainerName := "baz" + expectedCommand := "ls -a" + expectedStdin := "stdin" + expectedStdout := "stdout" + expectedStderr := "stderr" + done := make(chan struct{}) + clientStdoutReadDone := make(chan struct{}) + clientStderrReadDone := make(chan struct{}) + execInvoked := false + attachInvoked := false - testStreamFunc := func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error { - defer close(done) + checkStream := func(podFullName string, uid types.UID, containerName string, streamOpts remotecommandserver.Options) { + assert.Equal(t, expectedPodName, podFullName, "podFullName") + if test.uid { + assert.Equal(t, testUID, string(uid), "uid") + } + assert.Equal(t, expectedContainerName, containerName, "containerName") + assert.Equal(t, test.stdin, streamOpts.Stdin, "stdin") + assert.Equal(t, test.stdout, streamOpts.Stdout, "stdout") + assert.Equal(t, test.tty, streamOpts.TTY, "tty") + assert.Equal(t, !test.tty && test.stderr, streamOpts.Stderr, "stderr") + } - if podFullName != expectedPodName { - t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName) + fw.fakeKubelet.getExecCheck = func(podFullName string, uid types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) { + execInvoked = true + assert.Equal(t, expectedCommand, strings.Join(cmd, " "), "cmd") + checkStream(podFullName, uid, containerName, streamOpts) } - if test.uid && string(uid) != testUID { - t.Fatalf("%d: uid: expected %v, got %v", i, testUID, uid) + + fw.fakeKubelet.getAttachCheck = func(podFullName string, uid types.UID, containerName string, streamOpts remotecommandserver.Options) { + attachInvoked = true + checkStream(podFullName, uid, containerName, streamOpts) } - if containerName != expectedContainerName { - t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName) + + testStream := func(containerID string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error { + close(done) + assert.Equal(t, testContainerID, containerID, "containerID") + assert.Equal(t, test.tty, tty, "tty") + require.Equal(t, test.stdin, in != nil, "in") + require.Equal(t, test.stdout, out != nil, "out") + require.Equal(t, !test.tty && test.stderr, stderr != nil, "err") + + if test.stdin { + b := make([]byte, 10) + n, err := in.Read(b) + assert.NoError(t, err, "reading from stdin") + assert.Equal(t, expectedStdin, string(b[0:n]), "content from stdin") + } + + if test.stdout { + _, err := out.Write([]byte(expectedStdout)) + assert.NoError(t, err, "writing to stdout") + out.Close() + <-clientStdoutReadDone + } + + if !test.tty && test.stderr { + _, err := stderr.Write([]byte(expectedStderr)) + assert.NoError(t, err, "writing to stderr") + stderr.Close() + <-clientStderrReadDone + } + return nil } + ss.fakeRuntime.execFunc = func(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { + assert.Equal(t, expectedCommand, strings.Join(cmd, " "), "cmd") + return testStream(containerID, stdin, stdout, stderr, tty, done) + } + + ss.fakeRuntime.attachFunc = func(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { + return testStream(containerID, stdin, stdout, stderr, tty, done) + } + + var url string + if test.uid { + url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + testUID + "/" + expectedContainerName + "?ignore=1" + } else { + url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?ignore=1" + } + if verb == "exec" { + url += "&command=ls&command=-a" + } + if test.stdin { + url += "&" + api.ExecStdinParam + "=1" + } + if test.stdout { + url += "&" + api.ExecStdoutParam + "=1" + } + if test.stderr && !test.tty { + url += "&" + api.ExecStderrParam + "=1" + } + if test.tty { + url += "&" + api.ExecTTYParam + "=1" + } + + var ( + resp *http.Response + upgradeRoundTripper httpstream.UpgradeRoundTripper + c *http.Client + ) + if test.redirect { + c = &http.Client{} + // Don't follow redirects, since we want to inspect the redirect response. + c.CheckRedirect = func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + } + } else { + upgradeRoundTripper = spdy.NewRoundTripper(nil, true) + c = &http.Client{Transport: upgradeRoundTripper} + } + + resp, err = c.Post(url, "", nil) + require.NoError(t, err, "POSTing") + defer resp.Body.Close() + + _, err = ioutil.ReadAll(resp.Body) + assert.NoError(t, err, "reading response body") + + require.Equal(t, test.responseStatusCode, resp.StatusCode, "response status") + if test.responseStatusCode != http.StatusSwitchingProtocols { + return + } + + conn, err := upgradeRoundTripper.NewConnection(resp) + require.NoError(t, err, "creating streaming connection") + defer conn.Close() + + h := http.Header{} + h.Set(api.StreamType, api.StreamTypeError) + _, err = conn.CreateStream(h) + require.NoError(t, err, "creating error stream") + if test.stdin { - if in == nil { - t.Fatalf("%d: stdin: expected non-nil", i) - } - b := make([]byte, 10) - n, err := in.Read(b) - if err != nil { - t.Fatalf("%d: error reading from stdin: %v", i, err) - } - if e, a := expectedStdin, string(b[0:n]); e != a { - t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a) - } - } else if in != nil { - t.Fatalf("%d: stdin: expected nil: %#v", i, in) + h.Set(api.StreamType, api.StreamTypeStdin) + stream, err := conn.CreateStream(h) + require.NoError(t, err, "creating stdin stream") + _, err = stream.Write([]byte(expectedStdin)) + require.NoError(t, err, "writing to stdin stream") + } + + var stdoutStream httpstream.Stream + if test.stdout { + h.Set(api.StreamType, api.StreamTypeStdout) + stdoutStream, err = conn.CreateStream(h) + require.NoError(t, err, "creating stdout stream") + } + + var stderrStream httpstream.Stream + if test.stderr && !test.tty { + h.Set(api.StreamType, api.StreamTypeStderr) + stderrStream, err = conn.CreateStream(h) + require.NoError(t, err, "creating stderr stream") } if test.stdout { - if out == nil { - t.Fatalf("%d: stdout: expected non-nil", i) - } - _, err := out.Write([]byte(expectedStdout)) - if err != nil { - t.Fatalf("%d:, error writing to stdout: %v", i, err) - } - out.Close() - <-clientStdoutReadDone - } else if out != nil { - t.Fatalf("%d: stdout: expected nil: %#v", i, out) + output := make([]byte, 10) + n, err := stdoutStream.Read(output) + close(clientStdoutReadDone) + assert.NoError(t, err, "reading from stdout stream") + assert.Equal(t, expectedStdout, string(output[0:n]), "stdout") } - if tty { - if stderr != nil { - t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr) - } - } else if test.stderr { - if stderr == nil { - t.Fatalf("%d: stderr: expected non-nil", i) - } - _, err := stderr.Write([]byte(expectedStderr)) - if err != nil { - t.Fatalf("%d:, error writing to stderr: %v", i, err) - } - stderr.Close() - <-clientStderrReadDone - } else if stderr != nil { - t.Fatalf("%d: stderr: expected nil: %#v", i, stderr) + if test.stderr && !test.tty { + output := make([]byte, 10) + n, err := stderrStream.Read(output) + close(clientStderrReadDone) + assert.NoError(t, err, "reading from stderr stream") + assert.Equal(t, expectedStderr, string(output[0:n]), "stderr") } - return nil - } + // wait for the server to finish before checking if the attach/exec funcs were invoked + <-done - fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { - execInvoked = true - if strings.Join(cmd, " ") != expectedCommand { - t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd) + if verb == "exec" { + assert.True(t, execInvoked, "exec should be invoked") + assert.False(t, attachInvoked, "attach should not be invoked") + } else { + assert.True(t, attachInvoked, "attach should be invoked") + assert.False(t, execInvoked, "exec should not be invoked") } - return testStreamFunc(podFullName, uid, containerName, cmd, in, out, stderr, tty, done) - } - - fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { - attachInvoked = true - return testStreamFunc(podFullName, uid, containerName, nil, in, out, stderr, tty, done) - } - - var url string - if test.uid { - url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + testUID + "/" + expectedContainerName + "?ignore=1" - } else { - url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?ignore=1" - } - if verb == "exec" { - url += "&command=ls&command=-a" - } - if test.stdin { - url += "&" + api.ExecStdinParam + "=1" - } - if test.stdout { - url += "&" + api.ExecStdoutParam + "=1" - } - if test.stderr && !test.tty { - url += "&" + api.ExecStderrParam + "=1" - } - if test.tty { - url += "&" + api.ExecTTYParam + "=1" - } - - var ( - resp *http.Response - err error - upgradeRoundTripper httpstream.UpgradeRoundTripper - c *http.Client - ) - - if test.responseStatusCode != http.StatusSwitchingProtocols { - c = &http.Client{} - // Don't follow redirects, since we want to inspect the redirect response. - c.CheckRedirect = func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - } - } else { - upgradeRoundTripper = spdy.NewRoundTripper(nil, true) - c = &http.Client{Transport: upgradeRoundTripper} - } - - resp, err = c.Post(url, "", nil) - if err != nil { - t.Fatalf("%d: Got error POSTing: %v", i, err) - } - defer resp.Body.Close() - - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("%d: Error reading response body: %v", i, err) - } - - if e, a := test.responseStatusCode, resp.StatusCode; e != a { - t.Fatalf("%d: response status: expected %v, got %v", i, e, a) - } - - if e, a := test.responseLocation, resp.Header.Get("Location"); e != a { - t.Errorf("%d: response location: expected %v, got %v", i, e, a) - } - - if test.responseStatusCode != http.StatusSwitchingProtocols { - continue - } - - conn, err := upgradeRoundTripper.NewConnection(resp) - if err != nil { - t.Fatalf("Unexpected error creating streaming connection: %s", err) - } - if conn == nil { - t.Fatalf("%d: unexpected nil conn", i) - } - defer conn.Close() - - h := http.Header{} - h.Set(api.StreamType, api.StreamTypeError) - if _, err := conn.CreateStream(h); err != nil { - t.Fatalf("%d: error creating error stream: %v", i, err) - } - - if test.stdin { - h.Set(api.StreamType, api.StreamTypeStdin) - stream, err := conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stdin stream: %v", i, err) - } - _, err = stream.Write([]byte(expectedStdin)) - if err != nil { - t.Fatalf("%d: error writing to stdin stream: %v", i, err) - } - } - - var stdoutStream httpstream.Stream - if test.stdout { - h.Set(api.StreamType, api.StreamTypeStdout) - stdoutStream, err = conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stdout stream: %v", i, err) - } - } - - var stderrStream httpstream.Stream - if test.stderr && !test.tty { - h.Set(api.StreamType, api.StreamTypeStderr) - stderrStream, err = conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stderr stream: %v", i, err) - } - } - - if test.stdout { - output := make([]byte, 10) - n, err := stdoutStream.Read(output) - close(clientStdoutReadDone) - if err != nil { - t.Fatalf("%d: error reading from stdout stream: %v", i, err) - } - if e, a := expectedStdout, string(output[0:n]); e != a { - t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a) - } - } - - if test.stderr && !test.tty { - output := make([]byte, 10) - n, err := stderrStream.Read(output) - close(clientStderrReadDone) - if err != nil { - t.Fatalf("%d: error reading from stderr stream: %v", i, err) - } - if e, a := expectedStderr, string(output[0:n]); e != a { - t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a) - } - } - - // wait for the server to finish before checking if the attach/exec funcs were invoked - <-done - - if verb == "exec" { - if !execInvoked { - t.Errorf("%d: exec was not invoked", i) - } - if attachInvoked { - t.Errorf("%d: attach should not have been invoked", i) - } - } else { - if !attachInvoked { - t.Errorf("%d: attach was not invoked", i) - } - if execInvoked { - t.Errorf("%d: exec should not have been invoked", i) - } - } + }) } } @@ -1394,13 +1417,12 @@ func TestServeAttachContainer(t *testing.T) { } func TestServePortForwardIdleTimeout(t *testing.T) { - fw := newServerTest() + ss, err := newTestStreamingServer(100 * time.Millisecond) + require.NoError(t, err) + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, false, ss) defer fw.testHTTPServer.Close() - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 100 * time.Millisecond - } - podNamespace := "other" podName := "foo" @@ -1428,174 +1450,139 @@ func TestServePortForwardIdleTimeout(t *testing.T) { } func TestServePortForward(t *testing.T) { - tests := []struct { - port string - uid bool - clientData string - containerData string - shouldError bool - responseLocation string + tests := map[string]struct { + port string + uid bool + clientData string + containerData string + redirect bool + shouldError bool }{ - {port: "", shouldError: true}, - {port: "abc", shouldError: true}, - {port: "-1", shouldError: true}, - {port: "65536", shouldError: true}, - {port: "0", shouldError: true}, - {port: "1", shouldError: false}, - {port: "8000", shouldError: false}, - {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, - {port: "65535", shouldError: false}, - {port: "65535", uid: true, shouldError: false}, - {port: "65535", responseLocation: "http://localhost:12345/portforward", shouldError: false}, + "no port": {port: "", shouldError: true}, + "none number port": {port: "abc", shouldError: true}, + "negative port": {port: "-1", shouldError: true}, + "too large port": {port: "65536", shouldError: true}, + "0 port": {port: "0", shouldError: true}, + "min port": {port: "1", shouldError: false}, + "normal port": {port: "8000", shouldError: false}, + "normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, + "max port": {port: "65535", shouldError: false}, + "normal port with uid": {port: "8000", uid: true, shouldError: false}, + "normal port with redirect": {port: "8000", redirect: true, shouldError: false}, } podNamespace := "other" podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - for i, test := range tests { - fw := newServerTest() - defer fw.testHTTPServer.Close() - - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } - - if test.responseLocation != "" { - var err error - fw.fakeKubelet.redirectURL, err = url.Parse(test.responseLocation) + for desc, test := range tests { + test := test + t.Run(desc, func(t *testing.T) { + ss, err := newTestStreamingServer(0) require.NoError(t, err) - } + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, test.redirect, ss) + defer fw.testHTTPServer.Close() - portForwardFuncDone := make(chan struct{}) + portForwardFuncDone := make(chan struct{}) - fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { - defer close(portForwardFuncDone) - - if e, a := expectedPodName, name; e != a { - t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a) + fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) { + assert.Equal(t, podName, name, "pod name") + assert.Equal(t, podNamespace, namespace, "pod namespace") + if test.uid { + assert.Equal(t, testUID, string(uid), "uid") + } } - if e, a := testUID, uid; test.uid && e != string(a) { - t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a) + ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + defer close(portForwardFuncDone) + assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id") + // The port should be valid if it reaches here. + testPort, err := strconv.ParseInt(test.port, 10, 32) + require.NoError(t, err, "parse port") + assert.Equal(t, int32(testPort), port, "port") + + if test.clientData != "" { + fromClient := make([]byte, 32) + n, err := stream.Read(fromClient) + assert.NoError(t, err, "reading client data") + assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data") + } + + if test.containerData != "" { + _, err := stream.Write([]byte(test.containerData)) + assert.NoError(t, err, "writing container data") + } + + return nil } - p, err := strconv.ParseInt(test.port, 10, 32) - if err != nil { - t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) + var url string + if test.uid { + url = fmt.Sprintf("%s/portForward/%s/%s/%s", fw.testHTTPServer.URL, podNamespace, podName, testUID) + } else { + url = fmt.Sprintf("%s/portForward/%s/%s", fw.testHTTPServer.URL, podNamespace, podName) } - if e, a := int32(p), port; e != a { - t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a) + + var ( + upgradeRoundTripper httpstream.UpgradeRoundTripper + c *http.Client + ) + + if test.redirect { + c = &http.Client{} + // Don't follow redirects, since we want to inspect the redirect response. + c.CheckRedirect = func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + } + } else { + upgradeRoundTripper = spdy.NewRoundTripper(nil, true) + c = &http.Client{Transport: upgradeRoundTripper} } + resp, err := c.Post(url, "", nil) + require.NoError(t, err, "POSTing") + defer resp.Body.Close() + + if test.redirect { + assert.Equal(t, http.StatusFound, resp.StatusCode, "status code") + return + } else { + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "status code") + } + + conn, err := upgradeRoundTripper.NewConnection(resp) + require.NoError(t, err, "creating streaming connection") + defer conn.Close() + + headers := http.Header{} + headers.Set("streamType", "error") + headers.Set("port", test.port) + _, err = conn.CreateStream(headers) + assert.Equal(t, test.shouldError, err != nil, "expect error") + + if test.shouldError { + return + } + + headers.Set("streamType", "data") + headers.Set("port", test.port) + dataStream, err := conn.CreateStream(headers) + require.NoError(t, err, "create stream") + if test.clientData != "" { - fromClient := make([]byte, 32) - n, err := stream.Read(fromClient) - if err != nil { - t.Fatalf("%d: error reading client data: %v", i, err) - } - if e, a := test.clientData, string(fromClient[0:n]); e != a { - t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a) - } + _, err := dataStream.Write([]byte(test.clientData)) + assert.NoError(t, err, "writing client data") } if test.containerData != "" { - _, err := stream.Write([]byte(test.containerData)) - if err != nil { - t.Fatalf("%d: error writing container data: %v", i, err) - } + fromContainer := make([]byte, 32) + n, err := dataStream.Read(fromContainer) + assert.NoError(t, err, "reading container data") + assert.Equal(t, test.containerData, string(fromContainer[0:n]), "container data") } - return nil - } - - var url string - if test.uid { - url = fmt.Sprintf("%s/portForward/%s/%s/%s", fw.testHTTPServer.URL, podNamespace, podName, testUID) - } else { - url = fmt.Sprintf("%s/portForward/%s/%s", fw.testHTTPServer.URL, podNamespace, podName) - } - - var ( - upgradeRoundTripper httpstream.UpgradeRoundTripper - c *http.Client - ) - - if len(test.responseLocation) > 0 { - c = &http.Client{} - // Don't follow redirects, since we want to inspect the redirect response. - c.CheckRedirect = func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - } - } else { - upgradeRoundTripper = spdy.NewRoundTripper(nil, true) - c = &http.Client{Transport: upgradeRoundTripper} - } - - resp, err := c.Post(url, "", nil) - if err != nil { - t.Fatalf("%d: Got error POSTing: %v", i, err) - } - defer resp.Body.Close() - - if test.responseLocation != "" { - assert.Equal(t, http.StatusFound, resp.StatusCode, "%d: status code", i) - assert.Equal(t, test.responseLocation, resp.Header.Get("Location"), "%d: location", i) - continue - } else { - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "%d: status code", i) - } - - conn, err := upgradeRoundTripper.NewConnection(resp) - if err != nil { - t.Fatalf("Unexpected error creating streaming connection: %s", err) - } - if conn == nil { - t.Fatalf("%d: Unexpected nil connection", i) - } - defer conn.Close() - - headers := http.Header{} - headers.Set("streamType", "error") - headers.Set("port", test.port) - errorStream, err := conn.CreateStream(headers) - _ = errorStream - haveErr := err != nil - if e, a := test.shouldError, haveErr; e != a { - t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err) - } - - if test.shouldError { - continue - } - - headers.Set("streamType", "data") - headers.Set("port", test.port) - dataStream, err := conn.CreateStream(headers) - haveErr = err != nil - if e, a := test.shouldError, haveErr; e != a { - t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err) - } - - if test.clientData != "" { - _, err := dataStream.Write([]byte(test.clientData)) - if err != nil { - t.Fatalf("%d: unexpected error writing client data: %v", i, err) - } - } - - if test.containerData != "" { - fromContainer := make([]byte, 32) - n, err := dataStream.Read(fromContainer) - if err != nil { - t.Fatalf("%d: unexpected error reading container data: %v", i, err) - } - if e, a := test.containerData, string(fromContainer[0:n]); e != a { - t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) - } - } - - <-portForwardFuncDone + <-portForwardFuncDone + }) } } @@ -1616,7 +1603,7 @@ func TestCRIHandler(t *testing.T) { } func TestDebuggingDisabledHandlers(t *testing.T) { - fw := newServerTestWithDebug(false) + fw := newServerTestWithDebug(false, false, nil) defer fw.testHTTPServer.Close() paths := []string{ diff --git a/pkg/kubelet/server/server_websocket_test.go b/pkg/kubelet/server/server_websocket_test.go index 058b67d978a..daf6d356b63 100644 --- a/pkg/kubelet/server/server_websocket_test.go +++ b/pkg/kubelet/server/server_websocket_test.go @@ -23,11 +23,13 @@ import ( "strconv" "sync" "testing" - "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/websocket" "k8s.io/apimachinery/pkg/types" + "k8s.io/kubernetes/pkg/kubelet/server/portforward" ) const ( @@ -36,152 +38,114 @@ const ( ) func TestServeWSPortForward(t *testing.T) { - tests := []struct { + tests := map[string]struct { port string uid bool clientData string containerData string shouldError bool }{ - {port: "", shouldError: true}, - {port: "abc", shouldError: true}, - {port: "-1", shouldError: true}, - {port: "65536", shouldError: true}, - {port: "0", shouldError: true}, - {port: "1", shouldError: false}, - {port: "8000", shouldError: false}, - {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, - {port: "65535", shouldError: false}, - {port: "65535", uid: true, shouldError: false}, + "no port": {port: "", shouldError: true}, + "none number port": {port: "abc", shouldError: true}, + "negative port": {port: "-1", shouldError: true}, + "too large port": {port: "65536", shouldError: true}, + "0 port": {port: "0", shouldError: true}, + "min port": {port: "1", shouldError: false}, + "normal port": {port: "8000", shouldError: false}, + "normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, + "max port": {port: "65535", shouldError: false}, + "normal port with uid": {port: "8000", uid: true, shouldError: false}, } podNamespace := "other" podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647" - for i, test := range tests { - fw := newServerTest() - defer fw.testHTTPServer.Close() + for desc, test := range tests { + test := test + t.Run(desc, func(t *testing.T) { + ss, err := newTestStreamingServer(0) + require.NoError(t, err) + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, false, ss) + defer fw.testHTTPServer.Close() - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } + portForwardFuncDone := make(chan struct{}) - portForwardFuncDone := make(chan struct{}) - - fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { - defer close(portForwardFuncDone) - - if e, a := expectedPodName, name; e != a { - t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a) + fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) { + assert.Equal(t, podName, name, "pod name") + assert.Equal(t, podNamespace, namespace, "pod namespace") + if test.uid { + assert.Equal(t, testUID, string(uid), "uid") + } } - if e, a := expectedUid, uid; test.uid && e != string(a) { - t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a) + ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + defer close(portForwardFuncDone) + assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id") + // The port should be valid if it reaches here. + testPort, err := strconv.ParseInt(test.port, 10, 32) + require.NoError(t, err, "parse port") + assert.Equal(t, int32(testPort), port, "port") + + if test.clientData != "" { + fromClient := make([]byte, 32) + n, err := stream.Read(fromClient) + assert.NoError(t, err, "reading client data") + assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data") + } + + if test.containerData != "" { + _, err := stream.Write([]byte(test.containerData)) + assert.NoError(t, err, "writing container data") + } + + return nil } - p, err := strconv.ParseInt(test.port, 10, 32) - if err != nil { - t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) + var url string + if test.uid { + url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, testUID, test.port) + } else { + url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port) } - if e, a := int32(p), port; e != a { - t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a) + + ws, err := websocket.Dial(url, "", "http://127.0.0.1/") + assert.Equal(t, test.shouldError, err != nil, "websocket dial") + if test.shouldError { + return } + defer ws.Close() + + p, err := strconv.ParseUint(test.port, 10, 16) + require.NoError(t, err, "parse port") + p16 := uint16(p) + + channel, data, err := wsRead(ws) + require.NoError(t, err, "read") + assert.Equal(t, dataChannel, int(channel), "channel") + assert.Len(t, data, binary.Size(p16), "data size") + assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data") + + channel, data, err = wsRead(ws) + assert.NoError(t, err, "read") + assert.Equal(t, errorChannel, int(channel), "channel") + assert.Len(t, data, binary.Size(p16), "data size") + assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data") if test.clientData != "" { - fromClient := make([]byte, 32) - n, err := stream.Read(fromClient) - if err != nil { - t.Fatalf("%d: error reading client data: %v", i, err) - } - if e, a := test.clientData, string(fromClient[0:n]); e != a { - t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a) - } + println("writing the client data") + err := wsWrite(ws, dataChannel, []byte(test.clientData)) + assert.NoError(t, err, "writing client data") } if test.containerData != "" { - _, err := stream.Write([]byte(test.containerData)) - if err != nil { - t.Fatalf("%d: error writing container data: %v", i, err) - } + _, data, err = wsRead(ws) + assert.NoError(t, err, "reading container data") + assert.Equal(t, test.containerData, string(data), "container data") } - return nil - } - - var url string - if test.uid { - url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port) - } else { - url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port) - } - - ws, err := websocket.Dial(url, "", "http://127.0.0.1/") - if test.shouldError { - if err == nil { - t.Fatalf("%d: websocket dial expected err", i) - } - continue - } else if err != nil { - t.Fatalf("%d: websocket dial unexpected err: %v", i, err) - } - - defer ws.Close() - - p, err := strconv.ParseUint(test.port, 10, 16) - if err != nil { - t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) - } - p16 := uint16(p) - - channel, data, err := wsRead(ws) - if err != nil { - t.Fatalf("%d: read failed: expected no error: got %v", i, err) - } - if channel != dataChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel) - } - if len(data) != binary.Size(p16) { - t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16)) - } - if e, a := p16, binary.LittleEndian.Uint16(data); e != a { - t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port) - } - - channel, data, err = wsRead(ws) - if err != nil { - t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) - } - if channel != errorChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel) - } - if len(data) != binary.Size(p16) { - t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16)) - } - if e, a := p16, binary.LittleEndian.Uint16(data); e != a { - t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port) - } - - if test.clientData != "" { - println("writing the client data") - err := wsWrite(ws, dataChannel, []byte(test.clientData)) - if err != nil { - t.Fatalf("%d: unexpected error writing client data: %v", i, err) - } - } - - if test.containerData != "" { - _, data, err = wsRead(ws) - if err != nil { - t.Fatalf("%d: unexpected error reading container data: %v", i, err) - } - if e, a := test.containerData, string(data); e != a { - t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) - } - } - - <-portForwardFuncDone + <-portForwardFuncDone + }) } } @@ -190,27 +154,27 @@ func TestServeWSMultiplePortForward(t *testing.T) { ports := []uint16{7000, 8000, 9000} podNamespace := "other" podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - fw := newServerTest() + ss, err := newTestStreamingServer(0) + require.NoError(t, err) + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, false, ss) defer fw.testHTTPServer.Close() - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } - portForwardWG := sync.WaitGroup{} portForwardWG.Add(len(ports)) portsMutex := sync.Mutex{} portsForwarded := map[int32]struct{}{} - fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { - defer portForwardWG.Done() + fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) { + assert.Equal(t, podName, name, "pod name") + assert.Equal(t, podNamespace, namespace, "pod namespace") + } - if e, a := expectedPodName, name; e != a { - t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a) - } + ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + defer portForwardWG.Done() + assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id") portsMutex.Lock() portsForwarded[port] = struct{}{} @@ -218,17 +182,11 @@ func TestServeWSMultiplePortForward(t *testing.T) { fromClient := make([]byte, 32) n, err := stream.Read(fromClient) - if err != nil { - t.Fatalf("%d: error reading client data: %v", port, err) - } - if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a { - t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a) - } + assert.NoError(t, err, "reading client data") + assert.Equal(t, fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]), "client data") _, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port))) - if err != nil { - t.Fatalf("%d: error writing container data: %v", port, err) - } + assert.NoError(t, err, "writing container data") return nil } @@ -239,70 +197,42 @@ func TestServeWSMultiplePortForward(t *testing.T) { } ws, err := websocket.Dial(url, "", "http://127.0.0.1/") - if err != nil { - t.Fatalf("websocket dial unexpected err: %v", err) - } + require.NoError(t, err, "websocket dial") defer ws.Close() for i, port := range ports { channel, data, err := wsRead(ws) - if err != nil { - t.Fatalf("%d: read failed: expected no error: got %v", i, err) - } - if int(channel) != i*2+dataChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel) - } - if len(data) != binary.Size(port) { - t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port)) - } - if e, a := port, binary.LittleEndian.Uint16(data); e != a { - t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port) - } + assert.NoError(t, err, "port %d read", port) + assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port) + assert.Len(t, data, binary.Size(port), "port %d data size", port) + assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port) channel, data, err = wsRead(ws) - if err != nil { - t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) - } - if int(channel) != i*2+errorChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel) - } - if len(data) != binary.Size(port) { - t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port)) - } - if e, a := port, binary.LittleEndian.Uint16(data); e != a { - t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port) - } + assert.NoError(t, err, "port %d read", port) + assert.Equal(t, i*2+errorChannel, int(channel), "port %d channel", port) + assert.Len(t, data, binary.Size(port), "port %d data size", port) + assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port) } for i, port := range ports { - println("writing the client data", port) + t.Logf("port %d writing the client data", port) err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port))) - if err != nil { - t.Fatalf("%d: unexpected error writing client data: %v", i, err) - } + assert.NoError(t, err, "port %d write client data", port) channel, data, err := wsRead(ws) - if err != nil { - t.Fatalf("%d: unexpected error reading container data: %v", i, err) - } - - if int(channel) != i*2+dataChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel) - } - if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a { - t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) - } + assert.NoError(t, err, "port %d read container data", port) + assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port) + assert.Equal(t, fmt.Sprintf("container data on port %d", port), string(data), "port %d container data", port) } portForwardWG.Wait() portsMutex.Lock() defer portsMutex.Unlock() - if len(ports) != len(portsForwarded) { - t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded) - } + assert.Len(t, portsForwarded, len(ports), "all ports forwarded") } + func wsWrite(conn *websocket.Conn, channel byte, data []byte) error { frame := make([]byte, len(data)+1) frame[0] = channel diff --git a/pkg/kubelet/server/streaming/server.go b/pkg/kubelet/server/streaming/server.go index ae1c046b025..7cbc424c41e 100644 --- a/pkg/kubelet/server/streaming/server.go +++ b/pkg/kubelet/server/streaming/server.go @@ -20,6 +20,7 @@ import ( "crypto/tls" "errors" "io" + "net" "net/http" "net/url" "path" @@ -71,6 +72,7 @@ type Config struct { Addr string // The optional base URL for constructing streaming URLs. If empty, the baseURL will be // constructed from the serve address. + // Note that for port "0", the URL port will be set to actual port in use. BaseURL *url.URL // How long to leave idle connections open for. @@ -233,10 +235,16 @@ func (s *server) Start(stayUp bool) error { return errors.New("stayUp=false is not yet implemented") } + listener, err := net.Listen("tcp", s.config.Addr) + if err != nil { + return err + } + // Use the actual address as baseURL host. This handles the "0" port case. + s.config.BaseURL.Host = listener.Addr().String() if s.config.TLSConfig != nil { - return s.server.ListenAndServeTLS("", "") // Use certs from TLSConfig. + return s.server.ServeTLS(listener, "", "") // Use certs from TLSConfig. } else { - return s.server.ListenAndServe() + return s.server.Serve(listener) } }