Merge pull request #33684 from fraenkel/port_forward_ws

Automatic merge from submit-queue

Add websocket support for port forwarding

#32880

**Release note**:
```release-note
Port forwarding can forward over websockets or SPDY.
```
This commit is contained in:
Kubernetes Submit Queue 2017-02-01 23:19:02 -08:00 committed by GitHub
commit 0477100f98
35 changed files with 1409 additions and 413 deletions

View File

@ -3798,6 +3798,13 @@
"name": "namespace",
"in": "path",
"required": true
},
{
"uniqueItems": true,
"type": "integer",
"description": "List of ports to forward Required when using WebSockets",
"name": "ports",
"in": "query"
}
]
},

View File

@ -9095,6 +9095,14 @@
"summary": "connect GET requests to portforward of Pod",
"nickname": "connectGetNamespacedPodPortforward",
"parameters": [
{
"type": "integer",
"paramType": "query",
"name": "ports",
"description": "List of ports to forward Required when using WebSockets",
"required": false,
"allowMultiple": false
},
{
"type": "string",
"paramType": "path",
@ -9125,6 +9133,14 @@
"summary": "connect POST requests to portforward of Pod",
"nickname": "connectPostNamespacedPodPortforward",
"parameters": [
{
"type": "integer",
"paramType": "query",
"name": "ports",
"description": "List of ports to forward Required when using WebSockets",
"required": false,
"allowMultiple": false
},
{
"type": "string",
"paramType": "path",

View File

@ -27,7 +27,7 @@ spec:
command:
- /bin/sh
- -c
- "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.0 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;"
- "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.2 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;"
securityContext:
privileged: true
volumeMounts:

View File

@ -9047,6 +9047,14 @@ span.icon > [class^="icon-"], span.icon > [class*=" icon-"] { cursor: default; }
</thead>
<tbody>
<tr>
<td class="tableblock halign-left valign-top"><p class="tableblock">QueryParameter</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">ports</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">List of ports to forward Required when using WebSockets</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">false</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">integer (int32)</p></td>
<td class="tableblock halign-left valign-top"></td>
</tr>
<tr>
<td class="tableblock halign-left valign-top"><p class="tableblock">PathParameter</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">namespace</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">object name and auth scope, such as for teams and projects</p></td>
@ -9152,6 +9160,14 @@ span.icon > [class^="icon-"], span.icon > [class*=" icon-"] { cursor: default; }
</thead>
<tbody>
<tr>
<td class="tableblock halign-left valign-top"><p class="tableblock">QueryParameter</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">ports</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">List of ports to forward Required when using WebSockets</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">false</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">integer (int32)</p></td>
<td class="tableblock halign-left valign-top"></td>
</tr>
<tr>
<td class="tableblock halign-left valign-top"><p class="tableblock">PathParameter</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">namespace</p></td>
<td class="tableblock halign-left valign-top"><p class="tableblock">object name and auth scope, such as for teams and projects</p></td>
@ -33308,7 +33324,7 @@ span.icon > [class^="icon-"], span.icon > [class*=" icon-"] { cursor: default; }
</div>
<div id="footer">
<div id="footer-text">
Last updated 2017-01-06 18:13:51 UTC
Last updated 2017-02-01 12:44:12 UTC
</div>
</div>
</body>

View File

@ -42,16 +42,16 @@ import (
type fakePortForwarder struct {
lock sync.Mutex
// stores data expected from the stream per port
expected map[uint16]string
expected map[int32]string
// stores data received from the stream per port
received map[uint16]string
received map[int32]string
// data to be sent to the stream per port
send map[uint16]string
send map[int32]string
}
var _ portforward.PortForwarder = &fakePortForwarder{}
func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
defer stream.Close()
// read from the client
@ -77,14 +77,14 @@ func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16
// fakePortForwardServer creates an HTTP server that can handle port forwarding
// requests.
func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[uint16]string) http.HandlerFunc {
func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[int32]string) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
pf := &fakePortForwarder{
expected: expectedFromClient,
received: make(map[uint16]string),
received: make(map[int32]string),
send: serverSends,
}
portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
portforward.ServePortForward(w, req, pf, "pod", "uid", nil, 0, 10*time.Second, portforward.SupportedProtocols)
for port, expected := range expectedFromClient {
actual, ok := pf.received[port]
@ -109,19 +109,19 @@ func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedF
func TestForwardPorts(t *testing.T) {
tests := map[string]struct {
ports []string
clientSends map[uint16]string
serverSends map[uint16]string
clientSends map[int32]string
serverSends map[int32]string
}{
"forward 1 port with no data either direction": {
ports: []string{"5000"},
},
"forward 2 ports with bidirectional data": {
ports: []string{"5001", "6000"},
clientSends: map[uint16]string{
clientSends: map[int32]string{
5001: "abcd",
6000: "ghij",
},
serverSends: map[uint16]string{
serverSends: map[int32]string{
5001: "1234",
6000: "5678",
},

View File

@ -1047,10 +1047,14 @@ func typeToJSON(typeName string) string {
return "string"
case "byte", "*byte":
return "string"
// TODO: Fix these when go-restful supports a way to specify an array query param:
// https://github.com/emicklei/go-restful/issues/225
case "[]string", "[]*string":
// TODO: Fix this when go-restful supports a way to specify an array query param:
// https://github.com/emicklei/go-restful/issues/225
return "string"
case "[]int32", "[]*int32":
return "integer"
default:
return typeName
}

View File

@ -72,6 +72,7 @@ go_library(
"//pkg/kubelet/rkt:go_default_library",
"//pkg/kubelet/secret:go_default_library",
"//pkg/kubelet/server:go_default_library",
"//pkg/kubelet/server/portforward:go_default_library",
"//pkg/kubelet/server/remotecommand:go_default_library",
"//pkg/kubelet/server/stats:go_default_library",
"//pkg/kubelet/server/streaming:go_default_library",
@ -177,6 +178,7 @@ go_test(
"//pkg/kubelet/prober/results:go_default_library",
"//pkg/kubelet/prober/testing:go_default_library",
"//pkg/kubelet/secret:go_default_library",
"//pkg/kubelet/server/portforward:go_default_library",
"//pkg/kubelet/server/remotecommand:go_default_library",
"//pkg/kubelet/server/stats:go_default_library",
"//pkg/kubelet/status:go_default_library",

View File

@ -130,7 +130,7 @@ type DirectStreamingRuntime interface {
// tty.
ExecInContainer(containerID ContainerID, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan term.Size, timeout time.Duration) error
// Forward the specified port from the specified pod to the stream.
PortForward(pod *Pod, port uint16, stream io.ReadWriteCloser) error
PortForward(pod *Pod, port int32, stream io.ReadWriteCloser) error
// ContainerAttach encapsulates the attaching to containers for testability
ContainerAttacher
}
@ -141,7 +141,7 @@ type DirectStreamingRuntime interface {
type IndirectStreamingRuntime interface {
GetExec(id ContainerID, cmd []string, stdin, stdout, stderr, tty bool) (*url.URL, error)
GetAttach(id ContainerID, stdin, stdout, stderr, tty bool) (*url.URL, error)
GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error)
GetPortForward(podName, podNamespace string, podUID types.UID, ports []int32) (*url.URL, error)
}
type ImageService interface {

View File

@ -73,7 +73,7 @@ type FakeDirectStreamingRuntime struct {
TTY bool
// Port-forward args
Pod *Pod
Port uint16
Port int32
Stream io.ReadWriteCloser
}
}
@ -394,7 +394,7 @@ func (f *FakeRuntime) RemoveImage(image ImageSpec) error {
return f.Err
}
func (f *FakeDirectStreamingRuntime) PortForward(pod *Pod, port uint16, stream io.ReadWriteCloser) error {
func (f *FakeDirectStreamingRuntime) PortForward(pod *Pod, port int32, stream io.ReadWriteCloser) error {
f.Lock()
defer f.Unlock()
@ -471,7 +471,7 @@ func (f *FakeIndirectStreamingRuntime) GetAttach(id ContainerID, stdin, stdout,
return &url.URL{Host: FakeHost}, f.Err
}
func (f *FakeIndirectStreamingRuntime) GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error) {
func (f *FakeIndirectStreamingRuntime) GetPortForward(podName, podNamespace string, podUID types.UID, ports []int32) (*url.URL, error) {
f.Lock()
defer f.Unlock()

View File

@ -64,7 +64,7 @@ func (r *streamingRuntime) PortForward(podSandboxID string, port int32, stream i
if port < 0 || port > math.MaxUint16 {
return fmt.Errorf("invalid port %d", port)
}
return dockertools.PortForward(r.client, podSandboxID, uint16(port), stream)
return dockertools.PortForward(r.client, podSandboxID, port, stream)
}
// ExecSync executes a command in the container, and returns the stdout output.

View File

@ -1354,7 +1354,7 @@ func noPodInfraContainerError(podName, podNamespace string) error {
// - match cgroups of container
// - should we support nsenter + socat on the host? (current impl)
// - should we support nsenter + socat in a container, running with elevated privs and --pid=host?
func (dm *DockerManager) PortForward(pod *kubecontainer.Pod, port uint16, stream io.ReadWriteCloser) error {
func (dm *DockerManager) PortForward(pod *kubecontainer.Pod, port int32, stream io.ReadWriteCloser) error {
podInfraContainer := pod.FindContainerByName(PodInfraContainerName)
if podInfraContainer == nil {
return noPodInfraContainerError(pod.Name, pod.Namespace)
@ -1370,7 +1370,7 @@ func (dm *DockerManager) UpdatePodCIDR(podCIDR string) error {
}
// Temporarily export this function to share with dockershim.
func PortForward(client DockerInterface, podInfraContainerID string, port uint16, stream io.ReadWriteCloser) error {
func PortForward(client DockerInterface, podInfraContainerID string, port int32, stream io.ReadWriteCloser) error {
container, err := client.InspectContainer(podInfraContainerID)
if err != nil {
return err

View File

@ -2171,9 +2171,10 @@ func getStreamingConfig(kubeCfg *componentconfig.KubeletConfiguration, kubeDeps
BaseURL: &url.URL{
Path: "/cri/",
},
StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration,
StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout,
SupportedProtocols: streaming.DefaultConfig.SupportedProtocols,
StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration,
StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout,
SupportedRemoteCommandProtocols: streaming.DefaultConfig.SupportedRemoteCommandProtocols,
SupportedPortForwardProtocols: streaming.DefaultConfig.SupportedPortForwardProtocols,
}
if kubeDeps.TLSOptions != nil {
config.TLSConfig = kubeDeps.TLSOptions.Config

View File

@ -50,6 +50,7 @@ import (
"k8s.io/kubernetes/pkg/kubelet/envvars"
"k8s.io/kubernetes/pkg/kubelet/images"
"k8s.io/kubernetes/pkg/kubelet/qos"
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
"k8s.io/kubernetes/pkg/kubelet/status"
kubetypes "k8s.io/kubernetes/pkg/kubelet/types"
@ -1394,7 +1395,7 @@ func (kl *Kubelet) AttachContainer(podFullName string, podUID types.UID, contain
// PortForward connects to the pod's port and copies data between the port
// and the stream.
func (kl *Kubelet) PortForward(podFullName string, podUID types.UID, port uint16, stream io.ReadWriteCloser) error {
func (kl *Kubelet) PortForward(podFullName string, podUID types.UID, port int32, stream io.ReadWriteCloser) error {
streamingRuntime, ok := kl.containerRuntime.(kubecontainer.DirectStreamingRuntime)
if !ok {
return fmt.Errorf("streaming methods not supported by runtime")
@ -1467,7 +1468,7 @@ func (kl *Kubelet) GetAttach(podFullName string, podUID types.UID, containerName
}
// GetPortForward gets the URL the port-forward will be served from, or nil if the Kubelet will serve it.
func (kl *Kubelet) GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error) {
func (kl *Kubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) {
switch streamingRuntime := kl.containerRuntime.(type) {
case kubecontainer.DirectStreamingRuntime:
// Kubelet will serve the attach directly.
@ -1484,7 +1485,7 @@ func (kl *Kubelet) GetPortForward(podName, podNamespace string, podUID types.UID
return nil, fmt.Errorf("pod not found (%q)", podFullName)
}
return streamingRuntime.GetPortForward(podName, podNamespace, podUID)
return streamingRuntime.GetPortForward(podName, podNamespace, podUID, portForwardOpts.Ports)
default:
return nil, fmt.Errorf("container runtime does not support port-forward")
}

View File

@ -37,6 +37,7 @@ import (
"k8s.io/kubernetes/pkg/api/v1"
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
containertest "k8s.io/kubernetes/pkg/kubelet/container/testing"
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
)
@ -1607,7 +1608,7 @@ func TestPortForward(t *testing.T) {
podName = "podFoo"
podNamespace = "nsFoo"
podUID types.UID = "12345678"
port uint16 = 5000
port int32 = 5000
)
var (
stream = &fakeReadWriteCloser{}
@ -1646,7 +1647,7 @@ func TestPortForward(t *testing.T) {
podFullName := kubecontainer.GetPodFullName(podWithUidNameNs(podUID, tc.podName, podNamespace))
{ // No streaming case
description := "no streaming - " + tc.description
redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID)
redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{})
assert.Error(t, err, description)
assert.Nil(t, redirect, description)
@ -1658,7 +1659,7 @@ func TestPortForward(t *testing.T) {
fakeRuntime := &containertest.FakeDirectStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime}
kubelet.containerRuntime = fakeRuntime
redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID)
redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{})
assert.NoError(t, err, description)
assert.Nil(t, redirect, description)
@ -1677,7 +1678,7 @@ func TestPortForward(t *testing.T) {
fakeRuntime := &containertest.FakeIndirectStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime}
kubelet.containerRuntime = fakeRuntime
redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID)
redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{})
if tc.expectError {
assert.Error(t, err, description)
} else {

View File

@ -237,7 +237,7 @@ func (m *kubeGenericRuntimeManager) getSandboxIDByPodUID(podUID kubetypes.UID, s
}
// GetPortForward gets the endpoint the runtime will serve the port-forward request from.
func (m *kubeGenericRuntimeManager) GetPortForward(podName, podNamespace string, podUID kubetypes.UID) (*url.URL, error) {
func (m *kubeGenericRuntimeManager) GetPortForward(podName, podNamespace string, podUID kubetypes.UID, ports []int32) (*url.URL, error) {
sandboxIDs, err := m.getSandboxIDByPodUID(podUID, nil)
if err != nil {
return nil, fmt.Errorf("failed to find sandboxID for pod %s: %v", format.PodDesc(podName, podNamespace, podUID), err)
@ -245,9 +245,9 @@ func (m *kubeGenericRuntimeManager) GetPortForward(podName, podNamespace string,
if len(sandboxIDs) == 0 {
return nil, fmt.Errorf("failed to find sandboxID for pod %s", format.PodDesc(podName, podNamespace, podUID))
}
// TODO: Port is unused for now, but we may need it in the future.
req := &runtimeapi.PortForwardRequest{
PodSandboxId: sandboxIDs[0],
Port: ports,
}
resp, err := m.runtimeService.PortForward(req)
if err != nil {

View File

@ -2107,7 +2107,7 @@ func (r *Runtime) ExecInContainer(containerID kubecontainer.ContainerID, cmd []s
// - should we support nsenter + socat in a container, running with elevated privs and --pid=host?
//
// TODO(yifan): Merge with the same function in dockertools.
func (r *Runtime) PortForward(pod *kubecontainer.Pod, port uint16, stream io.ReadWriteCloser) error {
func (r *Runtime) PortForward(pod *kubecontainer.Pod, port int32, stream io.ReadWriteCloser) error {
glog.V(4).Infof("Rkt port forwarding in container.")
ctx, cancel := context.WithTimeout(context.Background(), r.requestTimeout)

View File

@ -55,6 +55,7 @@ go_test(
srcs = [
"auth_test.go",
"server_test.go",
"server_websocket_test.go",
],
library = ":go_default_library",
tags = ["automanaged"],
@ -64,6 +65,7 @@ go_test(
"//pkg/kubelet/cm:go_default_library",
"//pkg/kubelet/container:go_default_library",
"//pkg/kubelet/container/testing:go_default_library",
"//pkg/kubelet/server/portforward:go_default_library",
"//pkg/kubelet/server/remotecommand:go_default_library",
"//pkg/kubelet/server/stats:go_default_library",
"//pkg/util/term:go_default_library",
@ -72,6 +74,7 @@ go_test(
"//vendor:github.com/google/cadvisor/info/v2",
"//vendor:github.com/stretchr/testify/assert",
"//vendor:github.com/stretchr/testify/require",
"//vendor:golang.org/x/net/websocket",
"//vendor:k8s.io/apimachinery/pkg/api/errors",
"//vendor:k8s.io/apimachinery/pkg/apis/meta/v1",
"//vendor:k8s.io/apimachinery/pkg/types",

View File

@ -12,7 +12,9 @@ go_library(
name = "go_default_library",
srcs = [
"constants.go",
"httpstream.go",
"portforward.go",
"websocket.go",
],
tags = ["automanaged"],
deps = [
@ -22,12 +24,17 @@ go_library(
"//vendor:k8s.io/apimachinery/pkg/util/httpstream",
"//vendor:k8s.io/apimachinery/pkg/util/httpstream/spdy",
"//vendor:k8s.io/apimachinery/pkg/util/runtime",
"//vendor:k8s.io/apiserver/pkg/server/httplog",
"//vendor:k8s.io/apiserver/pkg/util/wsstream",
],
)
go_test(
name = "go_default_test",
srcs = ["portforward_test.go"],
srcs = [
"httpstream_test.go",
"websocket_test.go",
],
library = ":go_default_library",
tags = ["automanaged"],
deps = [

View File

@ -18,4 +18,6 @@ limitations under the License.
package portforward
// The subprotocol "portforward.k8s.io" is used for port forwarding.
const PortForwardProtocolV1Name = "portforward.k8s.io"
const ProtocolV1Name = "portforward.k8s.io"
var SupportedProtocols = []string{ProtocolV1Name}

View File

@ -0,0 +1,309 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"errors"
"fmt"
"net/http"
"strconv"
"sync"
"time"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/kubernetes/pkg/api"
"github.com/golang/glog"
)
func handleHttpStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
// negotiated protocol isn't currently used server side, but could be in the future
if err != nil {
// Handshake writes the error to the client
return err
}
streamChan := make(chan httpstream.Stream, 1)
glog.V(5).Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan))
if conn == nil {
return errors.New("Unable to upgrade websocket connection")
}
defer conn.Close()
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)
h := &httpStreamHandler{
conn: conn,
streamChan: streamChan,
streamPairs: make(map[string]*httpStreamPair),
streamCreationTimeout: streamCreationTimeout,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
return nil
}
// httpStreamReceived is the httpstream.NewStreamHandler for port
// forward streams. It checks each stream's port and stream type headers,
// rejecting any streams that with missing or invalid values. Each valid
// stream is sent to the streams channel.
func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
// make sure it has a valid port header
portString := stream.Headers().Get(api.PortHeader)
if len(portString) == 0 {
return fmt.Errorf("%q header is required", api.PortHeader)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("port %q must be > 0", portString)
}
// make sure it has a valid stream type header
streamType := stream.Headers().Get(api.StreamType)
if len(streamType) == 0 {
return fmt.Errorf("%q header is required", api.StreamType)
}
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
return fmt.Errorf("invalid stream type %q", streamType)
}
streams <- stream
return nil
}
}
// httpStreamHandler is capable of processing multiple port forward
// requests over a single httpstream.Connection.
type httpStreamHandler struct {
conn httpstream.Connection
streamChan chan httpstream.Stream
streamPairsLock sync.RWMutex
streamPairs map[string]*httpStreamPair
streamCreationTimeout time.Duration
pod string
uid types.UID
forwarder PortForwarder
}
// getStreamPair returns a httpStreamPair for requestID. This creates a
// new pair if one does not yet exist for the requestID. The returned bool is
// true if the pair was created.
func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
if p, ok := h.streamPairs[requestID]; ok {
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
return p, false
}
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
p := newPortForwardPair(requestID)
h.streamPairs[requestID] = p
return p, true
}
// monitorStreamPair waits for the pair to receive both its error and data
// streams, or for the timeout to expire (whichever happens first), and then
// removes the pair.
func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) {
select {
case <-timeout:
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
utilruntime.HandleError(err)
p.printError(err.Error())
case <-p.complete:
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
}
h.removeStreamPair(p.requestID)
}
// hasStreamPair returns a bool indicating if a stream pair for requestID
// exists.
func (h *httpStreamHandler) hasStreamPair(requestID string) bool {
h.streamPairsLock.RLock()
defer h.streamPairsLock.RUnlock()
_, ok := h.streamPairs[requestID]
return ok
}
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
func (h *httpStreamHandler) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
delete(h.streamPairs, requestID)
}
// requestID returns the request id for stream.
func (h *httpStreamHandler) requestID(stream httpstream.Stream) string {
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
if len(requestID) == 0 {
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
// If we get here, it's because the connection came from an older client
// that isn't generating the request id header
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
//
// This is a best-effort attempt at supporting older clients.
//
// When there aren't concurrent new forwarded connections, each connection
// will have a pair of streams (data, error), and the stream IDs will be
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
// the stream ID into a pseudo-request id by taking the stream type and
// using id = stream.Identifier() when the stream type is error,
// and id = stream.Identifier() - 2 when it's data.
//
// NOTE: this only works when there are not concurrent new streams from
// multiple forwarded connections; it's a best-effort attempt at supporting
// old clients that don't generate request ids. If there are concurrent
// new connections, it's possible that 1 connection gets streams whose IDs
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
requestID = strconv.Itoa(int(stream.Identifier()))
case api.StreamTypeData:
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
}
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
}
return requestID
}
// run is the main loop for the httpStreamHandler. It processes new
// streams, invoking portForward for each complete stream pair. The loop exits
// when the httpstream.Connection is closed.
func (h *httpStreamHandler) run() {
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
Loop:
for {
select {
case <-h.conn.CloseChan():
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
break Loop
case stream := <-h.streamChan:
requestID := h.requestID(stream)
streamType := stream.Headers().Get(api.StreamType)
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
p, created := h.getStreamPair(requestID)
if created {
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
}
if complete, err := p.add(stream); err != nil {
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
utilruntime.HandleError(errors.New(msg))
p.printError(msg)
} else if complete {
go h.portForward(p)
}
}
}
}
// portForward invokes the httpStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *httpStreamHandler) portForward(p *httpStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
portString := p.dataStream.Headers().Get(api.PortHeader)
port, _ := strconv.ParseInt(portString, 10, 32)
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
err := h.forwarder.PortForward(h.pod, h.uid, int32(port), p.dataStream)
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
utilruntime.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}
// httpStreamPair represents the error and data streams for a port
// forwarding request.
type httpStreamPair struct {
lock sync.RWMutex
requestID string
dataStream httpstream.Stream
errorStream httpstream.Stream
complete chan struct{}
}
// newPortForwardPair creates a new httpStreamPair.
func newPortForwardPair(requestID string) *httpStreamPair {
return &httpStreamPair{
requestID: requestID,
complete: make(chan struct{}),
}
}
// add adds the stream to the httpStreamPair. If the pair already
// contains a stream for the new stream's type, an error is returned. add
// returns true if both the data and error streams for this pair have been
// received.
func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) {
p.lock.Lock()
defer p.lock.Unlock()
switch stream.Headers().Get(api.StreamType) {
case api.StreamTypeError:
if p.errorStream != nil {
return false, errors.New("error stream already assigned")
}
p.errorStream = stream
case api.StreamTypeData:
if p.dataStream != nil {
return false, errors.New("data stream already assigned")
}
p.dataStream = stream
}
complete := p.errorStream != nil && p.dataStream != nil
if complete {
close(p.complete)
}
return complete, nil
}
// printError writes s to p.errorStream if p.errorStream has been set.
func (p *httpStreamPair) printError(s string) {
p.lock.RLock()
defer p.lock.RUnlock()
if p.errorStream != nil {
fmt.Fprint(p.errorStream, s)
}
}

View File

@ -25,7 +25,7 @@ import (
"k8s.io/kubernetes/pkg/api"
)
func TestPortForwardStreamReceived(t *testing.T) {
func TestHTTPStreamReceived(t *testing.T) {
tests := map[string]struct {
port string
streamType string
@ -62,7 +62,7 @@ func TestPortForwardStreamReceived(t *testing.T) {
}
for name, test := range tests {
streams := make(chan httpstream.Stream, 1)
f := portForwardStreamReceived(streams)
f := httpStreamReceived(streams)
stream := newFakeHttpStream()
if len(test.port) > 0 {
stream.headers.Set("port", test.port)
@ -92,48 +92,11 @@ func TestPortForwardStreamReceived(t *testing.T) {
}
}
type fakeHttpStream struct {
headers http.Header
id uint32
}
func newFakeHttpStream() *fakeHttpStream {
return &fakeHttpStream{
headers: make(http.Header),
}
}
var _ httpstream.Stream = &fakeHttpStream{}
func (s *fakeHttpStream) Read(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Write(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Close() error {
return nil
}
func (s *fakeHttpStream) Reset() error {
return nil
}
func (s *fakeHttpStream) Headers() http.Header {
return s.headers
}
func (s *fakeHttpStream) Identifier() uint32 {
return s.id
}
func TestGetStreamPair(t *testing.T) {
timeout := make(chan time.Time)
h := &portForwardStreamHandler{
streamPairs: make(map[string]*portForwardStreamPair),
h := &httpStreamHandler{
streamPairs: make(map[string]*httpStreamPair),
}
// test adding a new entry
@ -223,7 +186,7 @@ func TestGetStreamPair(t *testing.T) {
}
func TestRequestID(t *testing.T) {
h := &portForwardStreamHandler{}
h := &httpStreamHandler{}
s := newFakeHttpStream()
s.headers.Set(api.StreamType, api.StreamTypeError)
@ -244,3 +207,40 @@ func TestRequestID(t *testing.T) {
t.Errorf("expected %q, got %q", e, a)
}
}
type fakeHttpStream struct {
headers http.Header
id uint32
}
func newFakeHttpStream() *fakeHttpStream {
return &fakeHttpStream{
headers: make(http.Header),
}
}
var _ httpstream.Stream = &fakeHttpStream{}
func (s *fakeHttpStream) Read(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Write(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Close() error {
return nil
}
func (s *fakeHttpStream) Reset() error {
return nil
}
func (s *fakeHttpStream) Headers() http.Header {
return s.headers
}
func (s *fakeHttpStream) Identifier() uint32 {
return s.id
}

View File

@ -17,28 +17,20 @@ limitations under the License.
package portforward
import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
"sync"
"time"
"github.com/golang/glog"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/kubernetes/pkg/api"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apiserver/pkg/util/wsstream"
)
// PortForwarder knows how to forward content from a data stream to/from a port
// in a pod.
type PortForwarder interface {
// PortForwarder copies data between a data stream and a port in a pod.
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error
}
// ServePortForward handles a port forwarding request. A single request is
@ -46,278 +38,16 @@ type PortForwarder interface {
// been timed out due to idleness. This function handles multiple forwarded
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
// handled by a single invocation of ServePortForward.
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
supportedPortForwardProtocols := []string{PortForwardProtocolV1Name}
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
// negotiated protocol isn't currently used server side, but could be in the future
if err != nil {
// Handshake writes the error to the client
utilruntime.HandleError(err)
return
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, portForwardOptions *V4Options, idleTimeout time.Duration, streamCreationTimeout time.Duration, supportedProtocols []string) {
var err error
if wsstream.IsWebSocketRequest(req) {
err = handleWebSocketStreams(req, w, portForwarder, podName, uid, portForwardOptions, supportedProtocols, idleTimeout, streamCreationTimeout)
} else {
err = handleHttpStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout)
}
streamChan := make(chan httpstream.Stream, 1)
glog.V(5).Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
if conn == nil {
return
}
defer conn.Close()
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)
h := &portForwardStreamHandler{
conn: conn,
streamChan: streamChan,
streamPairs: make(map[string]*portForwardStreamPair),
streamCreationTimeout: streamCreationTimeout,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
}
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
// forward streams. It checks each stream's port and stream type headers,
// rejecting any streams that with missing or invalid values. Each valid
// stream is sent to the streams channel.
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
// make sure it has a valid port header
portString := stream.Headers().Get(api.PortHeader)
if len(portString) == 0 {
return fmt.Errorf("%q header is required", api.PortHeader)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("port %q must be > 0", portString)
}
// make sure it has a valid stream type header
streamType := stream.Headers().Get(api.StreamType)
if len(streamType) == 0 {
return fmt.Errorf("%q header is required", api.StreamType)
}
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
return fmt.Errorf("invalid stream type %q", streamType)
}
streams <- stream
return nil
}
}
// portForwardStreamHandler is capable of processing multiple port forward
// requests over a single httpstream.Connection.
type portForwardStreamHandler struct {
conn httpstream.Connection
streamChan chan httpstream.Stream
streamPairsLock sync.RWMutex
streamPairs map[string]*portForwardStreamPair
streamCreationTimeout time.Duration
pod string
uid types.UID
forwarder PortForwarder
}
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
// new pair if one does not yet exist for the requestID. The returned bool is
// true if the pair was created.
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
if p, ok := h.streamPairs[requestID]; ok {
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
return p, false
}
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
p := newPortForwardPair(requestID)
h.streamPairs[requestID] = p
return p, true
}
// monitorStreamPair waits for the pair to receive both its error and data
// streams, or for the timeout to expire (whichever happens first), and then
// removes the pair.
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
select {
case <-timeout:
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
utilruntime.HandleError(err)
p.printError(err.Error())
case <-p.complete:
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
}
h.removeStreamPair(p.requestID)
}
// hasStreamPair returns a bool indicating if a stream pair for requestID
// exists.
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
h.streamPairsLock.RLock()
defer h.streamPairsLock.RUnlock()
_, ok := h.streamPairs[requestID]
return ok
}
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
delete(h.streamPairs, requestID)
}
// requestID returns the request id for stream.
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
if len(requestID) == 0 {
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
// If we get here, it's because the connection came from an older client
// that isn't generating the request id header
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
//
// This is a best-effort attempt at supporting older clients.
//
// When there aren't concurrent new forwarded connections, each connection
// will have a pair of streams (data, error), and the stream IDs will be
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
// the stream ID into a pseudo-request id by taking the stream type and
// using id = stream.Identifier() when the stream type is error,
// and id = stream.Identifier() - 2 when it's data.
//
// NOTE: this only works when there are not concurrent new streams from
// multiple forwarded connections; it's a best-effort attempt at supporting
// old clients that don't generate request ids. If there are concurrent
// new connections, it's possible that 1 connection gets streams whose IDs
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
requestID = strconv.Itoa(int(stream.Identifier()))
case api.StreamTypeData:
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
}
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
}
return requestID
}
// run is the main loop for the portForwardStreamHandler. It processes new
// streams, invoking portForward for each complete stream pair. The loop exits
// when the httpstream.Connection is closed.
func (h *portForwardStreamHandler) run() {
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
Loop:
for {
select {
case <-h.conn.CloseChan():
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
break Loop
case stream := <-h.streamChan:
requestID := h.requestID(stream)
streamType := stream.Headers().Get(api.StreamType)
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
p, created := h.getStreamPair(requestID)
if created {
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
}
if complete, err := p.add(stream); err != nil {
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
utilruntime.HandleError(errors.New(msg))
p.printError(msg)
} else if complete {
go h.portForward(p)
}
}
}
}
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
portString := p.dataStream.Headers().Get(api.PortHeader)
port, _ := strconv.ParseUint(portString, 10, 16)
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
utilruntime.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}
// portForwardStreamPair represents the error and data streams for a port
// forwarding request.
type portForwardStreamPair struct {
lock sync.RWMutex
requestID string
dataStream httpstream.Stream
errorStream httpstream.Stream
complete chan struct{}
}
// newPortForwardPair creates a new portForwardStreamPair.
func newPortForwardPair(requestID string) *portForwardStreamPair {
return &portForwardStreamPair{
requestID: requestID,
complete: make(chan struct{}),
}
}
// add adds the stream to the portForwardStreamPair. If the pair already
// contains a stream for the new stream's type, an error is returned. add
// returns true if both the data and error streams for this pair have been
// received.
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
p.lock.Lock()
defer p.lock.Unlock()
switch stream.Headers().Get(api.StreamType) {
case api.StreamTypeError:
if p.errorStream != nil {
return false, errors.New("error stream already assigned")
}
p.errorStream = stream
case api.StreamTypeData:
if p.dataStream != nil {
return false, errors.New("data stream already assigned")
}
p.dataStream = stream
}
complete := p.errorStream != nil && p.dataStream != nil
if complete {
close(p.complete)
}
return complete, nil
}
// printError writes s to p.errorStream if p.errorStream has been set.
func (p *portForwardStreamPair) printError(s string) {
p.lock.RLock()
defer p.lock.RUnlock()
if p.errorStream != nil {
fmt.Fprint(p.errorStream, s)
runtime.HandleError(err)
return
}
}

View File

@ -0,0 +1,191 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"encoding/binary"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/golang/glog"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apiserver/pkg/server/httplog"
"k8s.io/apiserver/pkg/util/wsstream"
"k8s.io/kubernetes/pkg/api"
)
const (
dataChannel = iota
errorChannel
v4BinaryWebsocketProtocol = "v4." + wsstream.ChannelWebSocketProtocol
v4Base64WebsocketProtocol = "v4." + wsstream.Base64ChannelWebSocketProtocol
)
// options contains details about which streams are required for
// port forwarding.
type V4Options struct {
Ports []int32
}
// newOptions creates a new options from the Request.
func NewV4Options(req *http.Request) (*V4Options, error) {
if !wsstream.IsWebSocketRequest(req) {
return &V4Options{}, nil
}
portStrings := req.URL.Query()[api.PortHeader]
if len(portStrings) == 0 {
return nil, fmt.Errorf("query parameter %q is required", api.PortHeader)
}
ports := make([]int32, 0, len(portStrings))
for _, portString := range portStrings {
if len(portString) == 0 {
return nil, fmt.Errorf("query parameter %q cannot be empty", api.PortHeader)
}
for _, p := range strings.Split(portString, ",") {
port, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return nil, fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return nil, fmt.Errorf("port %q must be > 0", portString)
}
ports = append(ports, int32(port))
}
}
return &V4Options{
Ports: ports,
}, nil
}
// handleWebSocketStreams handles requests to forward ports to a pod via
// a PortForwarder. A pair of streams are created per port (DATA n,
// ERROR n+1). The associated port is written to each stream as a unsigned 16
// bit integer in little endian format.
func handleWebSocketStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, opts *V4Options, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
channels := make([]wsstream.ChannelType, 0, len(opts.Ports)*2)
for i := 0; i < len(opts.Ports); i++ {
channels = append(channels, wsstream.ReadWriteChannel, wsstream.WriteChannel)
}
conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
"": {
Binary: true,
Channels: channels,
},
v4BinaryWebsocketProtocol: {
Binary: true,
Channels: channels,
},
v4Base64WebsocketProtocol: {
Binary: false,
Channels: channels,
},
})
conn.SetIdleTimeout(idleTimeout)
_, streams, err := conn.Open(httplog.Unlogged(w), req)
if err != nil {
err = fmt.Errorf("Unable to upgrade websocket connection: %v", err)
return err
}
defer conn.Close()
streamPairs := make([]*websocketStreamPair, len(opts.Ports))
for i := range streamPairs {
streamPair := websocketStreamPair{
port: opts.Ports[i],
dataStream: streams[i*2+dataChannel],
errorStream: streams[i*2+errorChannel],
}
streamPairs[i] = &streamPair
portBytes := make([]byte, 2)
// port is always positive so conversion is allowable
binary.LittleEndian.PutUint16(portBytes, uint16(streamPair.port))
streamPair.dataStream.Write(portBytes)
streamPair.errorStream.Write(portBytes)
}
h := &websocketStreamHandler{
conn: conn,
streamPairs: streamPairs,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
return nil
}
// websocketStreamPair represents the error and data streams for a port
// forwarding request.
type websocketStreamPair struct {
port int32
dataStream io.ReadWriteCloser
errorStream io.WriteCloser
}
// websocketStreamHandler is capable of processing a single port forward
// request over a websocket connection
type websocketStreamHandler struct {
conn *wsstream.Conn
ports []int32
streamPairs []*websocketStreamPair
pod string
uid types.UID
forwarder PortForwarder
}
// run invokes the websocketStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *websocketStreamHandler) run() {
wg := sync.WaitGroup{}
wg.Add(len(h.streamPairs))
for _, pair := range h.streamPairs {
p := pair
go func() {
defer wg.Done()
h.portForward(p)
}()
}
wg.Wait()
}
func (h *websocketStreamHandler) portForward(p *websocketStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
glog.V(5).Infof("(conn=%p) invoking forwarder.PortForward for port %d", h.conn, p.port)
err := h.forwarder.PortForward(h.pod, h.uid, p.port, p.dataStream)
glog.V(5).Infof("(conn=%p) done invoking forwarder.PortForward for port %d", h.conn, p.port)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err)
runtime.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}

View File

@ -0,0 +1,101 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"net/http"
"reflect"
"testing"
)
func TestV4Options(t *testing.T) {
tests := map[string]struct {
url string
websocket bool
expectedOpts *V4Options
expectedError string
}{
"non-ws request": {
url: "http://example.com",
expectedOpts: &V4Options{},
},
"missing port": {
url: "http://example.com",
websocket: true,
expectedError: `query parameter "port" is required`,
},
"unable to parse port": {
url: "http://example.com?port=abc",
websocket: true,
expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`,
},
"negative port": {
url: "http://example.com?port=-1",
websocket: true,
expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`,
},
"one port": {
url: "http://example.com?port=80",
websocket: true,
expectedOpts: &V4Options{
Ports: []int32{80},
},
},
"multiple ports": {
url: "http://example.com?port=80,90,100",
websocket: true,
expectedOpts: &V4Options{
Ports: []int32{80, 90, 100},
},
},
"multiple port": {
url: "http://example.com?port=80&port=90",
websocket: true,
expectedOpts: &V4Options{
Ports: []int32{80, 90},
},
},
}
for name, test := range tests {
req, err := http.NewRequest(http.MethodGet, test.url, nil)
if err != nil {
t.Errorf("%s: invalid url %q err=%q", name, test.url, err)
continue
}
if test.websocket {
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
}
opts, err := NewV4Options(req)
if len(test.expectedError) > 0 {
if err == nil {
t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError)
}
if e, a := test.expectedError, err.Error(); e != a {
t.Errorf("%s: expected err=%q, got %q", name, e, a)
}
continue
}
if err != nil {
t.Errorf("%s: unexpected error %v", name, err)
continue
}
if !reflect.DeepEqual(test.expectedOpts, opts) {
t.Errorf("%s: expected options %#v, got %#v", name, test.expectedOpts, err)
}
}
}

View File

@ -172,7 +172,7 @@ type HostInterface interface {
AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan term.Size) error
GetKubeletContainerLogs(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error
ServeLogs(w http.ResponseWriter, req *http.Request)
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error
StreamingConnectionIdleTimeout() time.Duration
ResyncInterval() time.Duration
GetHostname() string
@ -184,7 +184,7 @@ type HostInterface interface {
ListVolumesForPod(podUID types.UID) (map[string]volume.Volume, bool)
GetExec(podFullName string, podUID types.UID, containerName string, cmd []string, streamOpts remotecommand.Options) (*url.URL, error)
GetAttach(podFullName string, podUID types.UID, containerName string, streamOpts remotecommand.Options) (*url.URL, error)
GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error)
GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error)
}
// NewServer initializes and configures a kubelet.Server object to handle HTTP requests.
@ -335,9 +335,15 @@ func (s *Server) InstallDebuggingHandlers(criHandler http.Handler) {
ws = new(restful.WebService)
ws.
Path("/portForward")
ws.Route(ws.GET("/{podNamespace}/{podID}").
To(s.getPortForward).
Operation("getPortForward"))
ws.Route(ws.POST("/{podNamespace}/{podID}").
To(s.getPortForward).
Operation("getPortForward"))
ws.Route(ws.GET("/{podNamespace}/{podID}/{uid}").
To(s.getPortForward).
Operation("getPortForward"))
ws.Route(ws.POST("/{podNamespace}/{podID}/{uid}").
To(s.getPortForward).
Operation("getPortForward"))
@ -562,7 +568,7 @@ func (s *Server) getSpec(request *restful.Request, response *restful.Response) {
response.WriteEntity(info)
}
type requestParams struct {
type execRequestParams struct {
podNamespace string
podName string
podUID types.UID
@ -570,8 +576,8 @@ type requestParams struct {
cmd []string
}
func getRequestParams(req *restful.Request) requestParams {
return requestParams{
func getExecRequestParams(req *restful.Request) execRequestParams {
return execRequestParams{
podNamespace: req.PathParameter("podNamespace"),
podName: req.PathParameter("podID"),
podUID: types.UID(req.PathParameter("uid")),
@ -580,9 +586,23 @@ func getRequestParams(req *restful.Request) requestParams {
}
}
type portForwardRequestParams struct {
podNamespace string
podName string
podUID types.UID
}
func getPortForwardRequestParams(req *restful.Request) portForwardRequestParams {
return portForwardRequestParams{
podNamespace: req.PathParameter("podNamespace"),
podName: req.PathParameter("podID"),
podUID: types.UID(req.PathParameter("uid")),
}
}
// getAttach handles requests to attach to a container.
func (s *Server) getAttach(request *restful.Request, response *restful.Response) {
params := getRequestParams(request)
params := getExecRequestParams(request)
streamOpts, err := remotecommand.NewOptions(request.Request)
if err != nil {
utilruntime.HandleError(err)
@ -620,7 +640,7 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response)
// getExec handles requests to run a command inside a container.
func (s *Server) getExec(request *restful.Request, response *restful.Response) {
params := getRequestParams(request)
params := getExecRequestParams(request)
streamOpts, err := remotecommand.NewOptions(request.Request)
if err != nil {
utilruntime.HandleError(err)
@ -659,7 +679,7 @@ func (s *Server) getExec(request *restful.Request, response *restful.Response) {
// getRun handles requests to run a command inside a container.
func (s *Server) getRun(request *restful.Request, response *restful.Response) {
params := getRequestParams(request)
params := getExecRequestParams(request)
pod, ok := s.host.GetPodByName(params.podNamespace, params.podName)
if !ok {
response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist"))
@ -693,7 +713,14 @@ func writeJsonResponse(response *restful.Response, data []byte) {
// getPortForward handles a new restful port forward request. It determines the
// pod name and uid and then calls ServePortForward.
func (s *Server) getPortForward(request *restful.Request, response *restful.Response) {
params := getRequestParams(request)
params := getPortForwardRequestParams(request)
portForwardOptions, err := portforward.NewV4Options(request.Request)
if err != nil {
utilruntime.HandleError(err)
response.WriteError(http.StatusBadRequest, err)
return
}
pod, ok := s.host.GetPodByName(params.podNamespace, params.podName)
if !ok {
response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist"))
@ -704,7 +731,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
return
}
redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID)
redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID, *portForwardOptions)
if err != nil {
streaming.WriteError(err, response.ResponseWriter)
return
@ -719,8 +746,10 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
s.host,
kubecontainer.GetPodFullName(pod),
params.podUID,
portForwardOptions,
s.host.StreamingConnectionIdleTimeout(),
remotecommand.DefaultStreamCreationTimeout)
remotecommand.DefaultStreamCreationTimeout,
portforward.SupportedProtocols)
}
// ServeHTTP responds to HTTP requests on the Kubelet.

View File

@ -52,6 +52,7 @@ import (
"k8s.io/kubernetes/pkg/kubelet/cm"
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
kubecontainertesting "k8s.io/kubernetes/pkg/kubelet/container/testing"
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
"k8s.io/kubernetes/pkg/kubelet/server/stats"
"k8s.io/kubernetes/pkg/util/term"
@ -73,7 +74,7 @@ type fakeKubelet struct {
runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error)
execFunc func(pod string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
attachFunc func(pod string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error
portForwardFunc func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
portForwardFunc func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error
containerLogsFunc func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error
streamingConnectionIdleTimeoutFunc func() time.Duration
hostnameFunc func() string
@ -139,7 +140,7 @@ func (fk *fakeKubelet) AttachContainer(name string, uid types.UID, container str
return fk.attachFunc(name, uid, container, in, out, err, tty)
}
func (fk *fakeKubelet) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
func (fk *fakeKubelet) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
return fk.portForwardFunc(name, uid, port, stream)
}
@ -151,7 +152,7 @@ func (fk *fakeKubelet) GetAttach(podFullName string, podUID types.UID, container
return fk.redirectURL, nil
}
func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error) {
func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) {
return fk.redirectURL, nil
}
@ -1503,7 +1504,7 @@ func TestServePortForward(t *testing.T) {
portForwardFuncDone := make(chan struct{})
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
defer close(portForwardFuncDone)
if e, a := expectedPodName, name; e != a {
@ -1514,11 +1515,11 @@ func TestServePortForward(t *testing.T) {
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
}
p, err := strconv.ParseUint(test.port, 10, 16)
p, err := strconv.ParseInt(test.port, 10, 32)
if err != nil {
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
}
if e, a := uint16(p), port; e != a {
if e, a := int32(p), port; e != a {
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
}

View File

@ -0,0 +1,331 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"encoding/binary"
"fmt"
"io"
"strconv"
"sync"
"testing"
"time"
"golang.org/x/net/websocket"
"k8s.io/apimachinery/pkg/types"
)
const (
dataChannel = iota
errorChannel
)
func TestServeWSPortForward(t *testing.T) {
tests := []struct {
port string
uid bool
clientData string
containerData string
shouldError bool
}{
{port: "", shouldError: true},
{port: "abc", shouldError: true},
{port: "-1", shouldError: true},
{port: "65536", shouldError: true},
{port: "0", shouldError: true},
{port: "1", shouldError: false},
{port: "8000", shouldError: false},
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
{port: "65535", shouldError: false},
{port: "65535", uid: true, shouldError: false},
}
podNamespace := "other"
podName := "foo"
expectedPodName := getPodName(podName, podNamespace)
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
for i, test := range tests {
fw := newServerTest()
defer fw.testHTTPServer.Close()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
portForwardFuncDone := make(chan struct{})
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
defer close(portForwardFuncDone)
if e, a := expectedPodName, name; e != a {
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a)
}
if e, a := expectedUid, uid; test.uid && e != string(a) {
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
}
p, err := strconv.ParseInt(test.port, 10, 32)
if err != nil {
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
}
if e, a := int32(p), port; e != a {
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
}
if test.clientData != "" {
fromClient := make([]byte, 32)
n, err := stream.Read(fromClient)
if err != nil {
t.Fatalf("%d: error reading client data: %v", i, err)
}
if e, a := test.clientData, string(fromClient[0:n]); e != a {
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
}
}
if test.containerData != "" {
_, err := stream.Write([]byte(test.containerData))
if err != nil {
t.Fatalf("%d: error writing container data: %v", i, err)
}
}
return nil
}
var url string
if test.uid {
url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port)
} else {
url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port)
}
ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
if test.shouldError {
if err == nil {
t.Fatalf("%d: websocket dial expected err", i)
}
continue
} else if err != nil {
t.Fatalf("%d: websocket dial unexpected err: %v", i, err)
}
defer ws.Close()
p, err := strconv.ParseUint(test.port, 10, 16)
if err != nil {
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
}
p16 := uint16(p)
channel, data, err := wsRead(ws)
if err != nil {
t.Fatalf("%d: read failed: expected no error: got %v", i, err)
}
if channel != dataChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel)
}
if len(data) != binary.Size(p16) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
}
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
}
channel, data, err = wsRead(ws)
if err != nil {
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err)
}
if channel != errorChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel)
}
if len(data) != binary.Size(p16) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
}
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
}
if test.clientData != "" {
println("writing the client data")
err := wsWrite(ws, dataChannel, []byte(test.clientData))
if err != nil {
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
}
}
if test.containerData != "" {
channel, data, err = wsRead(ws)
if err != nil {
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
}
if e, a := test.containerData, string(data); e != a {
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
}
}
<-portForwardFuncDone
}
}
func TestServeWSMultiplePortForward(t *testing.T) {
portsText := []string{"7000,8000", "9000"}
ports := []uint16{7000, 8000, 9000}
podNamespace := "other"
podName := "foo"
expectedPodName := getPodName(podName, podNamespace)
fw := newServerTest()
defer fw.testHTTPServer.Close()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
portForwardWG := sync.WaitGroup{}
portForwardWG.Add(len(ports))
portsMutex := sync.Mutex{}
portsForwarded := map[int32]struct{}{}
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
defer portForwardWG.Done()
if e, a := expectedPodName, name; e != a {
t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a)
}
portsMutex.Lock()
portsForwarded[port] = struct{}{}
portsMutex.Unlock()
fromClient := make([]byte, 32)
n, err := stream.Read(fromClient)
if err != nil {
t.Fatalf("%d: error reading client data: %v", port, err)
}
if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a {
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a)
}
_, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port)))
if err != nil {
t.Fatalf("%d: error writing container data: %v", port, err)
}
return nil
}
url := fmt.Sprintf("ws://%s/portForward/%s/%s?", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName)
for _, port := range portsText {
url = url + fmt.Sprintf("port=%s&", port)
}
ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
if err != nil {
t.Fatalf("websocket dial unexpected err: %v", err)
}
defer ws.Close()
for i, port := range ports {
channel, data, err := wsRead(ws)
if err != nil {
t.Fatalf("%d: read failed: expected no error: got %v", i, err)
}
if int(channel) != i*2+dataChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel)
}
if len(data) != binary.Size(port) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
}
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
}
channel, data, err = wsRead(ws)
if err != nil {
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err)
}
if int(channel) != i*2+errorChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel)
}
if len(data) != binary.Size(port) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
}
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
}
}
for i, port := range ports {
println("writing the client data", port)
err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port)))
if err != nil {
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
}
channel, data, err := wsRead(ws)
if err != nil {
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
}
if int(channel) != i*2+dataChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel)
}
if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a {
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
}
}
portForwardWG.Wait()
portsMutex.Lock()
defer portsMutex.Unlock()
if len(ports) != len(portsForwarded) {
t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded)
}
}
func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
frame := make([]byte, len(data)+1)
frame[0] = channel
copy(frame[1:], data)
err := websocket.Message.Send(conn, frame)
return err
}
func wsRead(conn *websocket.Conn) (byte, []byte, error) {
for {
var data []byte
err := websocket.Message.Receive(conn, &data)
if err != nil {
return 0, nil, err
}
if len(data) == 0 {
continue
}
channel := data[0]
data = data[1:]
return channel, data, err
}
}

View File

@ -80,7 +80,12 @@ type Config struct {
// The streaming protocols the server supports (understands and permits). See
// k8s.io/kubernetes/pkg/kubelet/server/remotecommand/constants.go for available protocols.
// Only used for SPDY streaming.
SupportedProtocols []string
SupportedRemoteCommandProtocols []string
// The streaming protocols the server supports (understands and permits). See
// k8s.io/kubernetes/pkg/kubelet/server/portforward/constants.go for available protocols.
// Only used for SPDY streaming.
SupportedPortForwardProtocols []string
// The config for serving over TLS. If nil, TLS will not be used.
TLSConfig *tls.Config
@ -89,9 +94,10 @@ type Config struct {
// DefaultConfig provides default values for server Config. The DefaultConfig is partial, so
// some fields like Addr must still be provided.
var DefaultConfig = Config{
StreamIdleTimeout: 4 * time.Hour,
StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout,
SupportedProtocols: remotecommand.SupportedStreamingProtocols,
StreamIdleTimeout: 4 * time.Hour,
StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout,
SupportedRemoteCommandProtocols: remotecommand.SupportedStreamingProtocols,
SupportedPortForwardProtocols: portforward.SupportedProtocols,
}
// TODO(timstclair): Add auth(n/z) interface & handling.
@ -248,7 +254,7 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) {
streamOpts,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
s.config.SupportedProtocols)
s.config.SupportedRemoteCommandProtocols)
}
func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
@ -280,7 +286,7 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
streamOpts,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
s.config.SupportedProtocols)
s.config.SupportedRemoteCommandProtocols)
}
func (s *server) servePortForward(req *restful.Request, resp *restful.Response) {
@ -296,14 +302,22 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response)
return
}
portForwardOptions, err := portforward.NewV4Options(req.Request)
if err != nil {
resp.WriteError(http.StatusBadRequest, err)
return
}
portforward.ServePortForward(
resp.ResponseWriter,
req.Request,
s.runtime,
pf.PodSandboxId,
"", // unused: podUID
portForwardOptions,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout)
s.config.StreamCreationTimeout,
s.config.SupportedPortForwardProtocols)
}
// criAdapter wraps the Runtime functions to conform to the remotecommand interfaces.
@ -324,6 +338,6 @@ func (a *criAdapter) AttachContainer(podName string, podUID types.UID, container
return a.Attach(container, in, out, err, tty, resize)
}
func (a *criAdapter) PortForward(podName string, podUID types.UID, port uint16, stream io.ReadWriteCloser) error {
return a.Runtime.PortForward(podName, int32(port), stream)
func (a *criAdapter) PortForward(podName string, podUID types.UID, port int32, stream io.ReadWriteCloser) error {
return a.Runtime.PortForward(podName, port, stream)
}

View File

@ -240,7 +240,7 @@ func TestServePortForward(t *testing.T) {
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL)
require.NoError(t, err)
streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name)
streamConn, _, err := exec.Dial(kubeletportforward.ProtocolV1Name)
require.NoError(t, err)
defer streamConn.Close()

View File

@ -43,12 +43,14 @@ go_test(
deps = [
"//pkg/api:go_default_library",
"//pkg/api/testing:go_default_library",
"//pkg/kubelet/client:go_default_library",
"//vendor:k8s.io/apimachinery/pkg/api/errors",
"//vendor:k8s.io/apimachinery/pkg/api/resource",
"//vendor:k8s.io/apimachinery/pkg/apis/meta/v1",
"//vendor:k8s.io/apimachinery/pkg/fields",
"//vendor:k8s.io/apimachinery/pkg/labels",
"//vendor:k8s.io/apimachinery/pkg/runtime",
"//vendor:k8s.io/apimachinery/pkg/types",
"//vendor:k8s.io/apiserver/pkg/endpoints/request",
],
)

View File

@ -165,9 +165,10 @@ func (r *PortForwardREST) New() runtime.Object {
return &api.Pod{}
}
// NewConnectOptions returns nil since portforward doesn't take additional parameters
// NewConnectOptions returns the versioned object that represents the
// portforward parameters
func (r *PortForwardREST) NewConnectOptions() (runtime.Object, bool, string) {
return nil, false, ""
return &api.PodPortForwardOptions{}, false, ""
}
// ConnectMethods returns the methods supported by portforward
@ -177,7 +178,11 @@ func (r *PortForwardREST) ConnectMethods() []string {
// Connect returns a handler for the pod portforward proxy
func (r *PortForwardREST) Connect(ctx genericapirequest.Context, name string, opts runtime.Object, responder rest.Responder) (http.Handler, error) {
location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name)
portForwardOpts, ok := opts.(*api.PodPortForwardOptions)
if !ok {
return nil, fmt.Errorf("invalid options object: %#v", opts)
}
location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name, portForwardOpts)
if err != nil {
return nil, err
}

View File

@ -383,6 +383,14 @@ func streamParams(params url.Values, opts runtime.Object) error {
if opts.TTY {
params.Add(api.ExecTTYParam, "1")
}
case *api.PodPortForwardOptions:
if len(opts.Ports) > 0 {
ports := make([]string, len(opts.Ports))
for i, p := range opts.Ports {
ports[i] = strconv.FormatInt(int64(p), 10)
}
params.Add(api.PortHeader, strings.Join(ports, ","))
}
default:
return fmt.Errorf("Unknown object for streaming: %v", opts)
}
@ -477,6 +485,7 @@ func PortForwardLocation(
connInfo client.ConnectionInfoGetter,
ctx genericapirequest.Context,
name string,
opts *api.PodPortForwardOptions,
) (*url.URL, http.RoundTripper, error) {
pod, err := getPod(getter, ctx, name)
if err != nil {
@ -492,10 +501,15 @@ func PortForwardLocation(
if err != nil {
return nil, nil, err
}
params := url.Values{}
if err := streamParams(params, opts); err != nil {
return nil, nil, err
}
loc := &url.URL{
Scheme: nodeInfo.Scheme,
Host: net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port),
Path: fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name),
Scheme: nodeInfo.Scheme,
Host: net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port),
Path: fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name),
RawQuery: params.Encode(),
}
return loc, nodeInfo.Transport, nil
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package pod
import (
"net/url"
"reflect"
"testing"
@ -26,9 +27,11 @@ import (
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/kubernetes/pkg/api"
apitesting "k8s.io/kubernetes/pkg/api/testing"
"k8s.io/kubernetes/pkg/kubelet/client"
)
func TestMatchPod(t *testing.T) {
@ -333,3 +336,69 @@ func TestSelectableFieldLabelConversions(t *testing.T) {
nil,
)
}
type mockConnectionInfoGetter struct {
info *client.ConnectionInfo
}
func (g mockConnectionInfoGetter) GetConnectionInfo(nodeName types.NodeName) (*client.ConnectionInfo, error) {
return g.info, nil
}
func TestPortForwardLocation(t *testing.T) {
ctx := genericapirequest.NewDefaultContext()
tcs := []struct {
in *api.Pod
info *client.ConnectionInfo
opts *api.PodPortForwardOptions
expectedErr error
expectedURL *url.URL
}{
{
in: &api.Pod{
Spec: api.PodSpec{},
},
opts: &api.PodPortForwardOptions{},
expectedErr: errors.NewBadRequest("pod test does not have a host assigned"),
},
{
in: &api.Pod{
ObjectMeta: metav1.ObjectMeta{
Namespace: "ns",
Name: "pod1",
},
Spec: api.PodSpec{
NodeName: "node1",
},
},
info: &client.ConnectionInfo{},
opts: &api.PodPortForwardOptions{},
expectedURL: &url.URL{Host: ":", Path: "/portForward/ns/pod1"},
},
{
in: &api.Pod{
ObjectMeta: metav1.ObjectMeta{
Namespace: "ns",
Name: "pod1",
},
Spec: api.PodSpec{
NodeName: "node1",
},
},
info: &client.ConnectionInfo{},
opts: &api.PodPortForwardOptions{Ports: []int32{80}},
expectedURL: &url.URL{Host: ":", Path: "/portForward/ns/pod1", RawQuery: "port=80"},
},
}
for _, tc := range tcs {
getter := &mockPodGetter{tc.in}
connectionGetter := &mockConnectionInfoGetter{tc.info}
loc, _, err := PortForwardLocation(getter, connectionGetter, ctx, "test", tc.opts)
if !reflect.DeepEqual(err, tc.expectedErr) {
t.Errorf("expected %v, got %v", tc.expectedErr, err)
}
if !reflect.DeepEqual(loc, tc.expectedURL) {
t.Errorf("expected %v, got %v", tc.expectedURL, loc)
}
}
}

View File

@ -169,6 +169,7 @@ go_library(
"//vendor:github.com/onsi/gomega",
"//vendor:github.com/stretchr/testify/assert",
"//vendor:golang.org/x/crypto/ssh",
"//vendor:golang.org/x/net/websocket",
"//vendor:google.golang.org/api/compute/v1",
"//vendor:google.golang.org/api/googleapi",
"//vendor:gopkg.in/inf.v0",

View File

@ -17,6 +17,8 @@ limitations under the License.
package e2e
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
@ -28,6 +30,7 @@ import (
"syscall"
"time"
"golang.org/x/net/websocket"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/kubernetes/pkg/api/v1"
@ -36,6 +39,7 @@ import (
testutils "k8s.io/kubernetes/test/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const (
@ -368,36 +372,144 @@ func doTestMustConnectSendDisconnect(bindAddress string, f *framework.Framework)
verifyLogMessage(logOutput, "^Done$")
}
func doTestOverWebSockets(bindAddress string, f *framework.Framework) {
config, err := framework.LoadConfig()
Expect(err).NotTo(HaveOccurred(), "unable to get base config")
By("creating the pod")
pod := pfPod("def", "10", "10", "100", fmt.Sprintf("%s", bindAddress))
if _, err := f.ClientSet.Core().Pods(f.Namespace.Name).Create(pod); err != nil {
framework.Failf("Couldn't create pod: %v", err)
}
if err := f.WaitForPodReady(pod.Name); err != nil {
framework.Failf("Pod did not start running: %v", err)
}
defer func() {
logs, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester")
if err != nil {
framework.Logf("Error getting pod log: %v", err)
} else {
framework.Logf("Pod log:\n%s", logs)
}
}()
req := f.ClientSet.Core().RESTClient().Get().
Namespace(f.Namespace.Name).
Resource("pods").
Name(pod.Name).
Suffix("portforward").
Param("ports", "80")
url := req.URL()
ws, err := framework.OpenWebSocketForURL(url, config, []string{"v4.channel.k8s.io"})
if err != nil {
framework.Failf("Failed to open websocket to %s: %v", url.String(), err)
}
defer ws.Close()
Eventually(func() error {
channel, msg, err := wsRead(ws)
if err != nil {
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
}
if channel != 0 {
return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg)
}
if p := binary.LittleEndian.Uint16(msg); p != 80 {
return fmt.Errorf("Received the wrong port: %d", p)
}
return nil
}, time.Minute, 10*time.Second).Should(BeNil())
Eventually(func() error {
channel, msg, err := wsRead(ws)
if err != nil {
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
}
if channel != 1 {
return fmt.Errorf("Got message from server that didn't start with channel 1 (error): %v", msg)
}
if p := binary.LittleEndian.Uint16(msg); p != 80 {
return fmt.Errorf("Received the wrong port: %d", p)
}
return nil
}, time.Minute, 10*time.Second).Should(BeNil())
By("sending the expected data to the local port")
err = wsWrite(ws, 0, []byte("def"))
if err != nil {
framework.Failf("Failed to write to websocket %s: %v", url.String(), err)
}
By("reading data from the local port")
buf := bytes.Buffer{}
expectedData := bytes.Repeat([]byte("x"), 100)
Eventually(func() error {
channel, msg, err := wsRead(ws)
if err != nil {
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
}
if channel != 0 {
return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg)
}
buf.Write(msg)
if bytes.Equal(expectedData, buf.Bytes()) {
return fmt.Errorf("Expected %q from server, got %q", expectedData, buf.Bytes())
}
return nil
}, time.Minute, 10*time.Second).Should(BeNil())
By("verifying logs")
logOutput, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester")
if err != nil {
framework.Failf("Error retrieving pod logs: %v", err)
}
verifyLogMessage(logOutput, "^Accepted client connection$")
verifyLogMessage(logOutput, "^Received expected client data$")
}
var _ = framework.KubeDescribe("Port forwarding", func() {
f := framework.NewDefaultFramework("port-forwarding")
framework.KubeDescribe("With a server listening on 0.0.0.0 that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects", func() {
doTestMustConnectSendNothing("0.0.0.0", f)
framework.KubeDescribe("With a server listening on 0.0.0.0", func() {
framework.KubeDescribe("that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects", func() {
doTestMustConnectSendNothing("0.0.0.0", f)
})
It("should support a client that connects, sends data, and disconnects", func() {
doTestMustConnectSendDisconnect("0.0.0.0", f)
})
})
It("should support a client that connects, sends data, and disconnects", func() {
doTestMustConnectSendDisconnect("0.0.0.0", f)
framework.KubeDescribe("that expects no client request", func() {
It("should support a client that connects, sends data, and disconnects", func() {
doTestConnectSendDisconnect("0.0.0.0", f)
})
})
It("should support forwarding over websockets", func() {
doTestOverWebSockets("0.0.0.0", f)
})
})
framework.KubeDescribe("With a server listening on 0.0.0.0 that expects no client request", func() {
It("should support a client that connects, sends data, and disconnects", func() {
doTestConnectSendDisconnect("0.0.0.0", f)
framework.KubeDescribe("With a server listening on localhost", func() {
framework.KubeDescribe("that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects [Conformance]", func() {
doTestMustConnectSendNothing("localhost", f)
})
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
doTestMustConnectSendDisconnect("localhost", f)
})
})
})
framework.KubeDescribe("With a server listening on localhost that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects [Conformance]", func() {
doTestMustConnectSendNothing("localhost", f)
framework.KubeDescribe("that expects no client request", func() {
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
doTestConnectSendDisconnect("localhost", f)
})
})
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
doTestMustConnectSendDisconnect("localhost", f)
})
})
framework.KubeDescribe("With a server listening on localhost that expects no client request", func() {
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
doTestConnectSendDisconnect("localhost", f)
It("should support forwarding over websockets", func() {
doTestOverWebSockets("localhost", f)
})
})
})
@ -412,3 +524,30 @@ func verifyLogMessage(log, expected string) {
}
framework.Failf("Missing %q from log: %s", expected, log)
}
func wsRead(conn *websocket.Conn) (byte, []byte, error) {
for {
var data []byte
err := websocket.Message.Receive(conn, &data)
if err != nil {
return 0, nil, err
}
if len(data) == 0 {
continue
}
channel := data[0]
data = data[1:]
return channel, data, err
}
}
func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
frame := make([]byte, len(data)+1)
frame[0] = channel
copy(frame[1:], data)
err := websocket.Message.Send(conn, frame)
return err
}