Add streaming command execution & port forwarding

Add streaming command execution & port forwarding via HTTP connection
upgrades (currently using SPDY).
This commit is contained in:
Andy Goldstein
2015-01-08 15:41:38 -05:00
parent 25d38c175b
commit 5bd0e9ab05
45 changed files with 4439 additions and 157 deletions

View File

@@ -1349,3 +1349,32 @@ type SecretList struct {
Items []Secret `json:"items"`
}
// These constants are for remote command execution and port forwarding and are
// used by both the client side and server side components.
//
// This is probably not the ideal place for them, but it didn't seem worth it
// to create pkg/exec and pkg/portforward just to contain a single file with
// constants in it. Suggestions for more appropriate alternatives are
// definitely welcome!
const (
// Enable stdin for remote command execution
ExecStdinParam = "input"
// Enable stdout for remote command execution
ExecStdoutParam = "output"
// Enable stderr for remote command execution
ExecStderrParam = "error"
// Enable TTY for remote command execution
ExecTTYParam = "tty"
// Command to run for remote command execution
ExecCommandParamm = "command"
StreamType = "streamType"
StreamTypeStdin = "stdin"
StreamTypeStdout = "stdout"
StreamTypeStderr = "stderr"
StreamTypeData = "data"
StreamTypeError = "error"
PortHeader = "port"
)

View File

@@ -99,6 +99,7 @@ func RecoverPanics(handler http.Handler) http.Handler {
http.StatusConflict,
http.StatusNotFound,
errors.StatusUnprocessableEntity,
http.StatusSwitchingProtocols,
),
).Log()

View File

@@ -22,6 +22,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"net/url"
@@ -34,6 +35,7 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/golang/glog"
"golang.org/x/net/html"
@@ -176,14 +178,67 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
httpCode = http.StatusOK
newReq.Header = req.Header
proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: "http", Host: destURL.Host})
proxy.Transport = &proxyTransport{
proxyScheme: req.URL.Scheme,
proxyHost: req.URL.Host,
proxyPathPrepend: path.Join(r.prefix, "ns", namespace, resource, id),
// TODO convert this entire proxy to an UpgradeAwareProxy similar to
// https://github.com/openshift/origin/blob/master/pkg/util/httpproxy/upgradeawareproxy.go.
// That proxy needs to be modified to support multiple backends, not just 1.
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
if strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) && len(req.Header.Get(httpstream.HeaderUpgrade)) > 0 {
//TODO support TLS? Doesn't look like proxyTransport does anything special ...
dialAddr := util.CanonicalAddr(destURL)
backendConn, err := net.Dial("tcp", dialAddr)
if err != nil {
status := errToAPIStatus(err)
writeJSON(status.Code, r.codec, status, w)
return
}
defer backendConn.Close()
// TODO should we use _ (a bufio.ReadWriter) instead of requestHijackedConn
// when copying between the client and the backend? Docker doesn't when they
// hijack, just for reference...
requestHijackedConn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
status := errToAPIStatus(err)
writeJSON(status.Code, r.codec, status, w)
return
}
defer requestHijackedConn.Close()
if err = newReq.Write(backendConn); err != nil {
status := errToAPIStatus(err)
writeJSON(status.Code, r.codec, status, w)
return
}
done := make(chan struct{}, 2)
go func() {
_, err := io.Copy(backendConn, requestHijackedConn)
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
glog.Errorf("Error proxying data from client to backend: %v", err)
}
done <- struct{}{}
}()
go func() {
_, err := io.Copy(requestHijackedConn, backendConn)
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
glog.Errorf("Error proxying data from backend to client: %v", err)
}
done <- struct{}{}
}()
<-done
} else {
proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: "http", Host: destURL.Host})
proxy.Transport = &proxyTransport{
proxyScheme: req.URL.Scheme,
proxyHost: req.URL.Host,
proxyPathPrepend: path.Join(r.prefix, "ns", namespace, resource, id),
}
proxy.FlushInterval = 200 * time.Millisecond
proxy.ServeHTTP(w, newReq)
}
proxy.FlushInterval = 200 * time.Millisecond
proxy.ServeHTTP(w, newReq)
}
type proxyTransport struct {

View File

@@ -29,6 +29,7 @@ import (
"testing"
"golang.org/x/net/html"
"golang.org/x/net/websocket"
)
func parseURLOrDie(inURL string) *url.URL {
@@ -327,3 +328,45 @@ func TestProxy(t *testing.T) {
}
}
}
func TestProxyUpgrade(t *testing.T) {
backendServer := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
defer ws.Close()
body := make([]byte, 5)
ws.Read(body)
ws.Write([]byte("hello " + string(body)))
}))
defer backendServer.Close()
simpleStorage := &SimpleRESTStorage{
errors: map[string]error{},
resourceLocation: backendServer.URL,
expectedResourceNamespace: "myns",
}
namespaceHandler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix", "version", selfLinker, admissionControl, requestContextMapper, namespaceMapper)
server := httptest.NewServer(namespaceHandler)
defer server.Close()
ws, err := websocket.Dial("ws://"+server.Listener.Addr().String()+"/prefix/version/proxy/namespaces/myns/foo/123", "", "http://127.0.0.1/")
if err != nil {
t.Fatalf("websocket dial err: %s", err)
}
defer ws.Close()
if _, err := ws.Write([]byte("world")); err != nil {
t.Fatalf("write err: %s", err)
}
response := make([]byte, 20)
n, err := ws.Read(response)
if err != nil {
t.Fatalf("read err: %s", err)
}
if e, a := "hello world", string(response[0:n]); e != a {
t.Fatalf("expected '%#v', got '%#v'", e, a)
}
}

View File

@@ -0,0 +1,19 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package portforward adds support for SSH-like port forwarding from the client's
// local host to remote containers.
package portforward

View File

@@ -0,0 +1,300 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
"github.com/golang/glog"
)
type upgrader interface {
upgrade(*client.Request, *client.Config) (httpstream.Connection, error)
}
type defaultUpgrader struct{}
func (u *defaultUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return req.Upgrade(config, spdy.NewRoundTripper)
}
// PortForwarder knows how to listen for local connections and forward them to
// a remote pod via an upgraded HTTP request.
type PortForwarder struct {
req *client.Request
config *client.Config
ports []ForwardedPort
stopChan <-chan struct{}
streamConn httpstream.Connection
listeners []io.Closer
upgrader upgrader
Ready chan struct{}
}
// ForwardedPort contains a Local:Remote port pairing.
type ForwardedPort struct {
Local uint16
Remote uint16
}
/*
valid port specifications:
5000
- forwards from localhost:5000 to pod:5000
8888:5000
- forwards from localhost:8888 to pod:5000
0:5000
:5000
- selects a random available local port,
forwards from localhost:<random port> to pod:5000
*/
func parsePorts(ports []string) ([]ForwardedPort, error) {
var forwards []ForwardedPort
for _, portString := range ports {
parts := strings.Split(portString, ":")
var localString, remoteString string
if len(parts) == 1 {
localString = parts[0]
remoteString = parts[0]
} else if len(parts) == 2 {
localString = parts[0]
if localString == "" {
// support :5000
localString = "0"
}
remoteString = parts[1]
} else {
return nil, fmt.Errorf("Invalid port format '%s'", portString)
}
localPort, err := strconv.ParseUint(localString, 10, 16)
if err != nil {
return nil, fmt.Errorf("Error parsing local port '%s': %s", localString, err)
}
remotePort, err := strconv.ParseUint(remoteString, 10, 16)
if err != nil {
return nil, fmt.Errorf("Error parsing remote port '%s': %s", remoteString, err)
}
if remotePort == 0 {
return nil, fmt.Errorf("Remote port must be > 0")
}
forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
}
return forwards, nil
}
// New creates a new PortForwarder.
func New(req *client.Request, config *client.Config, ports []string, stopChan <-chan struct{}) (*PortForwarder, error) {
if len(ports) == 0 {
return nil, errors.New("You must specify at least 1 port")
}
parsedPorts, err := parsePorts(ports)
if err != nil {
return nil, err
}
return &PortForwarder{
req: req,
config: config,
ports: parsedPorts,
stopChan: stopChan,
Ready: make(chan struct{}),
}, nil
}
// ForwardPorts formats and executes a port forwarding request. The connection will remain
// open until stopChan is closed.
func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
if pf.upgrader == nil {
pf.upgrader = &defaultUpgrader{}
}
var err error
pf.streamConn, err = pf.upgrader.upgrade(pf.req, pf.config)
if err != nil {
return fmt.Errorf("Error upgrading connection: %s", err)
}
defer pf.streamConn.Close()
return pf.forward()
}
// forward dials the remote host specific in req, upgrades the request, starts
// listeners for each port specified in ports, and forwards local connections
// to the remote host via streams.
func (pf *PortForwarder) forward() error {
var err error
listenSuccess := false
for _, port := range pf.ports {
err = pf.listenOnPort(&port)
if err != nil {
glog.Warningf("Unable to listen on port %d: %v", port, err)
}
listenSuccess = true
}
if !listenSuccess {
return fmt.Errorf("Unable to listen on any of the requested ports: %v", pf.ports)
}
close(pf.Ready)
// wait for interrupt or conn closure
select {
case <-pf.stopChan:
case <-pf.streamConn.CloseChan():
glog.Errorf("Lost connection to pod")
}
return nil
}
// listenOnPort creates a new listener on port and waits for new connections
// in the background.
func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port.Local))
if err != nil {
return err
}
parts := strings.Split(listener.Addr().String(), ":")
localPort, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return fmt.Errorf("Error parsing local part: %s", err)
}
port.Local = uint16(localPort)
glog.Infof("Forwarding from %d -> %d", localPort, port.Remote)
pf.listeners = append(pf.listeners, listener)
go pf.waitForConnection(listener, *port)
return nil
}
// waitForConnection waits for new connections to listener and handles them in
// the background.
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
for {
conn, err := listener.Accept()
if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
glog.Errorf("Error accepting connection on port %d: %v", port.Local, err)
}
return
}
go pf.handleConnection(conn, port)
}
}
// handleConnection copies data between the local connection and the stream to
// the remote server.
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
defer conn.Close()
glog.Infof("Handling connection for %d", port.Local)
errorChan := make(chan error)
doneChan := make(chan struct{}, 2)
// create error stream
headers := http.Header{}
headers.Set(api.StreamType, api.StreamTypeError)
headers.Set(api.PortHeader, fmt.Sprintf("%d", port.Remote))
errorStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
glog.Errorf("Error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)
return
}
defer errorStream.Reset()
go func() {
message, err := ioutil.ReadAll(errorStream)
if err != nil && err != io.EOF {
errorChan <- fmt.Errorf("Error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
}
if len(message) > 0 {
errorChan <- fmt.Errorf("An error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
}
}()
// create data stream
headers.Set(api.StreamType, api.StreamTypeData)
dataStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
glog.Errorf("Error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)
return
}
// Send a Reset when this function exits to completely tear down the stream here
// and in the remote server.
defer dataStream.Reset()
go func() {
// Copy from the remote side to the local port. We won't get an EOF from
// the server as it has no way of knowing when to close the stream. We'll
// take care of closing both ends of the stream with the call to
// stream.Reset() when this function exits.
if _, err := io.Copy(conn, dataStream); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
glog.Errorf("Error copying from remote stream to local connection: %v", err)
}
doneChan <- struct{}{}
}()
go func() {
// Copy from the local port to the remote side. Here we will be able to know
// when the Copy gets an EOF from conn, as that will happen as soon as conn is
// closed (i.e. client disconnected).
if _, err := io.Copy(dataStream, conn); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
glog.Errorf("Error copying from local connection to remote stream: %v", err)
}
doneChan <- struct{}{}
}()
select {
case err := <-errorChan:
glog.Error(err)
case <-doneChan:
}
}
func (pf *PortForwarder) Close() {
// stop all listeners
for _, l := range pf.listeners {
if err := l.Close(); err != nil {
glog.Errorf("Error closing listener: %v", err)
}
}
}

View File

@@ -0,0 +1,321 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"sync"
"testing"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
)
func TestParsePortsAndNew(t *testing.T) {
tests := []struct {
input []string
expected []ForwardedPort
expectParseError bool
expectNewError bool
}{
{input: []string{}, expectNewError: true},
{input: []string{"a"}, expectParseError: true, expectNewError: true},
{input: []string{":a"}, expectParseError: true, expectNewError: true},
{input: []string{"-1"}, expectParseError: true, expectNewError: true},
{input: []string{"65536"}, expectParseError: true, expectNewError: true},
{input: []string{"0"}, expectParseError: true, expectNewError: true},
{input: []string{"0:0"}, expectParseError: true, expectNewError: true},
{input: []string{"a:5000"}, expectParseError: true, expectNewError: true},
{input: []string{"5000:a"}, expectParseError: true, expectNewError: true},
{
input: []string{"5000", "5000:5000", "8888:5000", "5000:8888", ":5000", "0:5000"},
expected: []ForwardedPort{
{5000, 5000},
{5000, 5000},
{8888, 5000},
{5000, 8888},
{0, 5000},
{0, 5000},
},
},
}
for i, test := range tests {
parsed, err := parsePorts(test.input)
haveError := err != nil
if e, a := test.expectParseError, haveError; e != a {
t.Fatalf("%d: parsePorts: error expected=%t, got %t: %s", i, e, a, err)
}
expectedRequest := &client.Request{}
expectedConfig := &client.Config{}
expectedStopChan := make(chan struct{})
pf, err := New(expectedRequest, expectedConfig, test.input, expectedStopChan)
haveError = err != nil
if e, a := test.expectNewError, haveError; e != a {
t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err)
}
if test.expectParseError || test.expectNewError {
continue
}
for pi, expectedPort := range test.expected {
if e, a := expectedPort.Local, parsed[pi].Local; e != a {
t.Fatalf("%d: local expected: %d, got: %d", i, e, a)
}
if e, a := expectedPort.Remote, parsed[pi].Remote; e != a {
t.Fatalf("%d: remote expected: %d, got: %d", i, e, a)
}
}
if e, a := expectedRequest, pf.req; e != a {
t.Fatalf("%d: req: expected %#v, got %#v", i, e, a)
}
if e, a := expectedConfig, pf.config; e != a {
t.Fatalf("%d: config: expected %#v, got %#v", i, e, a)
}
if e, a := test.expected, pf.ports; !reflect.DeepEqual(e, a) {
t.Fatalf("%d: ports: expected %#v, got %#v", i, e, a)
}
if e, a := expectedStopChan, pf.stopChan; e != a {
t.Fatalf("%d: stopChan: expected %#v, got %#v", i, e, a)
}
if pf.Ready == nil {
t.Fatalf("%d: Ready should be non-nil", i)
}
}
}
type fakeUpgrader struct {
conn *fakeUpgradeConnection
err error
}
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return u.conn, u.err
}
type fakeUpgradeConnection struct {
closeCalled bool
lock sync.Mutex
streams map[string]*fakeUpgradeStream
portData map[string]string
}
func newFakeUpgradeConnection() *fakeUpgradeConnection {
return &fakeUpgradeConnection{
streams: make(map[string]*fakeUpgradeStream),
portData: make(map[string]string),
}
}
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
c.lock.Lock()
defer c.lock.Unlock()
stream := &fakeUpgradeStream{}
c.streams[headers.Get(api.PortHeader)] = stream
stream.data = c.portData[headers.Get(api.PortHeader)]
return stream, nil
}
func (c *fakeUpgradeConnection) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
c.closeCalled = true
return nil
}
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
return make(chan bool)
}
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
}
type fakeUpgradeStream struct {
readCalled bool
writeCalled bool
dataWritten []byte
closeCalled bool
resetCalled bool
data string
lock sync.Mutex
}
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.readCalled = true
b := []byte(s.data)
n := copy(p, b)
return n, io.EOF
}
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.writeCalled = true
s.dataWritten = make([]byte, len(p))
copy(s.dataWritten, p)
return len(p), io.EOF
}
func (s *fakeUpgradeStream) Close() error {
s.lock.Lock()
defer s.lock.Unlock()
s.closeCalled = true
return nil
}
func (s *fakeUpgradeStream) Reset() error {
s.lock.Lock()
defer s.lock.Unlock()
s.resetCalled = true
return nil
}
func (s *fakeUpgradeStream) Headers() http.Header {
s.lock.Lock()
defer s.lock.Unlock()
return http.Header{}
}
func TestForwardPorts(t *testing.T) {
testCases := []struct {
Upgrader *fakeUpgrader
Ports []string
Send map[uint16]string
Receive map[uint16]string
Err bool
}{
{
Upgrader: &fakeUpgrader{err: errors.New("bail")},
Err: true,
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Ports: []string{"5000"},
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Ports: []string{"5000", "6000"},
Send: map[uint16]string{
5000: "abcd",
6000: "ghij",
},
Receive: map[uint16]string{
5000: "1234",
6000: "5678",
},
},
}
for i, testCase := range testCases {
stopChan := make(chan struct{}, 1)
pf, err := New(&client.Request{}, &client.Config{}, testCase.Ports, stopChan)
hasErr := err != nil
if hasErr != testCase.Err {
t.Fatalf("%d: New: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
}
if pf == nil {
continue
}
pf.upgrader = testCase.Upgrader
if testCase.Upgrader.err != nil {
err := pf.ForwardPorts()
hasErr := err != nil
if hasErr != testCase.Err {
t.Fatalf("%d: ForwardPorts: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
}
continue
}
doneChan := make(chan error)
go func() {
doneChan <- pf.ForwardPorts()
}()
select {
case <-pf.Ready:
case <-time.After(500 * time.Millisecond):
t.Fatalf("%d: timed out waiting for listeners", i)
}
conn := testCase.Upgrader.conn
for port, data := range testCase.Send {
conn.lock.Lock()
conn.portData[fmt.Sprintf("%d", port)] = testCase.Receive[port]
conn.lock.Unlock()
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Fatalf("%d: error dialing %d: %s", i, port, err)
}
defer clientConn.Close()
n, err := clientConn.Write([]byte(data))
if err != nil && err != io.EOF {
t.Fatalf("%d: Error sending data '%s': %s", i, data, err)
}
if n == 0 {
t.Fatalf("%d: unexpected write of 0 bytes", i)
}
b := make([]byte, 4)
n, err = clientConn.Read(b)
if err != nil && err != io.EOF {
t.Fatalf("%d: Error reading data: %s", i, err)
}
if !bytes.Equal([]byte(testCase.Receive[port]), b) {
t.Fatalf("%d: expected to read '%s', got '%s'", i, testCase.Receive[port], b)
}
}
// tell r.ForwardPorts to stop
close(stopChan)
// wait for r.ForwardPorts to actually return
select {
case err := <-doneChan:
if err != nil {
t.Fatalf("%d: unexpected error: %s", err)
}
case <-time.After(200 * time.Millisecond):
t.Fatalf("%d: timeout waiting for ForwardPorts to finish")
}
if e, a := len(testCase.Send), len(conn.streams); e != a {
t.Fatalf("%d: expected %d streams to be created, got %d", e, a)
}
if !conn.closeCalled {
t.Fatalf("%d: expected conn closure", i)
}
}
}

View File

@@ -0,0 +1,20 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package remotecommand adds support for executing commands in containers,
// with support for separate stdin, stdout, and stderr streams, as well as
// TTY.
package remotecommand

View File

@@ -0,0 +1,186 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
"github.com/golang/glog"
)
type upgrader interface {
upgrade(*client.Request, *client.Config) (httpstream.Connection, error)
}
type defaultUpgrader struct{}
func (u *defaultUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return req.Upgrade(config, spdy.NewRoundTripper)
}
type RemoteCommandExecutor struct {
req *client.Request
config *client.Config
command []string
stdin io.Reader
stdout io.Writer
stderr io.Writer
tty bool
upgrader upgrader
}
func New(req *client.Request, config *client.Config, command []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) *RemoteCommandExecutor {
return &RemoteCommandExecutor{
req: req,
config: config,
command: command,
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
}
}
// Execute sends a remote command execution request, upgrading the
// connection and creating streams to represent stdin/stdout/stderr. Data is
// copied between these streams and the supplied stdin/stdout/stderr parameters.
func (e *RemoteCommandExecutor) Execute() error {
doStdin := (e.stdin != nil)
doStdout := (e.stdout != nil)
doStderr := (!e.tty && e.stderr != nil)
if doStdin {
e.req.Param(api.ExecStdinParam, "1")
}
if doStdout {
e.req.Param(api.ExecStdoutParam, "1")
}
if doStderr {
e.req.Param(api.ExecStderrParam, "1")
}
if e.tty {
e.req.Param(api.ExecTTYParam, "1")
}
for _, s := range e.command {
e.req.Param(api.ExecCommandParamm, s)
}
if e.upgrader == nil {
e.upgrader = &defaultUpgrader{}
}
conn, err := e.upgrader.upgrade(e.req, e.config)
if err != nil {
return err
}
defer conn.Close()
doneChan := make(chan struct{}, 2)
errorChan := make(chan error)
cp := func(s string, dst io.Writer, src io.Reader) {
glog.V(4).Infof("Copying %s", s)
defer glog.V(4).Infof("Done copying %s", s)
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
glog.Errorf("Error copying %s: %v", s, err)
}
if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
doneChan <- struct{}{}
}
}
headers := http.Header{}
headers.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(headers)
if err != nil {
return err
}
go func() {
message, err := ioutil.ReadAll(errorStream)
if err != nil && err != io.EOF {
errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
return
}
if len(message) > 0 {
errorChan <- fmt.Errorf("Error executing remote command: %s", message)
return
}
}()
defer errorStream.Reset()
if doStdin {
headers.Set(api.StreamType, api.StreamTypeStdin)
remoteStdin, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStdin.Reset()
// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
// because stdin is not closed until the process exits. If we try to call
// stdin.Close(), it returns no error but doesn't unblock the copy. It will
// exit when the process exits, instead.
go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
}
waitCount := 0
completedStreams := 0
if doStdout {
waitCount++
headers.Set(api.StreamType, api.StreamTypeStdout)
remoteStdout, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStdout.Reset()
go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
}
if doStderr && !e.tty {
waitCount++
headers.Set(api.StreamType, api.StreamTypeStderr)
remoteStderr, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStderr.Reset()
go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
}
Loop:
for {
select {
case <-doneChan:
completedStreams++
if completedStreams == waitCount {
break Loop
}
case err := <-errorChan:
return err
}
}
return nil
}

View File

@@ -0,0 +1,288 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"bytes"
"errors"
"io"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
)
type fakeUpgrader struct {
conn *fakeUpgradeConnection
err error
}
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return u.conn, u.err
}
type fakeUpgradeConnection struct {
closeCalled bool
lock sync.Mutex
stdin *fakeUpgradeStream
stdout *fakeUpgradeStream
stdoutData string
stderr *fakeUpgradeStream
stderrData string
errorStream *fakeUpgradeStream
errorData string
unexpectedStreamCreated bool
}
func newFakeUpgradeConnection() *fakeUpgradeConnection {
return &fakeUpgradeConnection{}
}
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
c.lock.Lock()
defer c.lock.Unlock()
stream := &fakeUpgradeStream{}
switch headers.Get(api.StreamType) {
case api.StreamTypeStdin:
c.stdin = stream
case api.StreamTypeStdout:
c.stdout = stream
stream.data = c.stdoutData
case api.StreamTypeStderr:
c.stderr = stream
stream.data = c.stderrData
case api.StreamTypeError:
c.errorStream = stream
stream.data = c.errorData
default:
c.unexpectedStreamCreated = true
}
return stream, nil
}
func (c *fakeUpgradeConnection) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
c.closeCalled = true
return nil
}
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
return make(chan bool)
}
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
}
type fakeUpgradeStream struct {
readCalled bool
writeCalled bool
dataWritten []byte
closeCalled bool
resetCalled bool
data string
lock sync.Mutex
}
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.readCalled = true
b := []byte(s.data)
n := copy(p, b)
return n, io.EOF
}
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.writeCalled = true
s.dataWritten = make([]byte, len(p))
copy(s.dataWritten, p)
return len(p), io.EOF
}
func (s *fakeUpgradeStream) Close() error {
s.lock.Lock()
defer s.lock.Unlock()
s.closeCalled = true
return nil
}
func (s *fakeUpgradeStream) Reset() error {
s.lock.Lock()
defer s.lock.Unlock()
s.resetCalled = true
return nil
}
func (s *fakeUpgradeStream) Headers() http.Header {
s.lock.Lock()
defer s.lock.Unlock()
return http.Header{}
}
func TestRequestExecuteRemoteCommand(t *testing.T) {
testCases := []struct {
Upgrader *fakeUpgrader
Stdin string
Stdout string
Stderr string
Error string
Tty bool
ShouldError bool
}{
{
Upgrader: &fakeUpgrader{err: errors.New("bail")},
ShouldError: true,
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Stdin: "a",
Stdout: "b",
Stderr: "c",
Error: "bail",
ShouldError: true,
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Stdin: "a",
Stdout: "b",
Stderr: "c",
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Stdin: "a",
Stdout: "b",
Stderr: "c",
Tty: true,
},
}
for i, testCase := range testCases {
if testCase.Error != "" {
testCase.Upgrader.conn.errorData = testCase.Error
}
if testCase.Stdout != "" {
testCase.Upgrader.conn.stdoutData = testCase.Stdout
}
if testCase.Stderr != "" {
testCase.Upgrader.conn.stderrData = testCase.Stderr
}
var localOut, localErr *bytes.Buffer
if testCase.Stdout != "" {
localOut = &bytes.Buffer{}
}
if testCase.Stderr != "" {
localErr = &bytes.Buffer{}
}
e := New(&client.Request{}, &client.Config{}, []string{"ls", "/"}, strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty)
e.upgrader = testCase.Upgrader
err := e.Execute()
hasErr := err != nil
if hasErr != testCase.ShouldError {
t.Fatalf("%d: expected %t, got %t: %v", i, testCase.ShouldError, hasErr, err)
}
conn := testCase.Upgrader.conn
if testCase.Error != "" {
if conn.errorStream == nil {
t.Fatalf("%d: expected error stream creation", i)
}
if !conn.errorStream.readCalled {
t.Fatalf("%d: expected error stream read", i)
}
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Fatalf("%d: expected error stream read '%v', got '%v'", i, e, a)
}
if !conn.errorStream.resetCalled {
t.Fatalf("%d: expected error reset", i)
}
}
if testCase.ShouldError {
continue
}
if testCase.Stdin != "" {
if conn.stdin == nil {
t.Fatalf("%d: expected stdin stream creation", i)
}
if !conn.stdin.writeCalled {
t.Fatalf("%d: expected stdin stream write", i)
}
if e, a := testCase.Stdin, string(conn.stdin.dataWritten); e != a {
t.Fatalf("%d: expected stdin write %v, got %v", i, e, a)
}
if !conn.stdin.resetCalled {
t.Fatalf("%d: expected stdin reset", i)
}
}
if testCase.Stdout != "" {
if conn.stdout == nil {
t.Fatalf("%d: expected stdout stream creation", i)
}
if !conn.stdout.readCalled {
t.Fatalf("%d: expected stdout stream read", i)
}
if e, a := testCase.Stdout, localOut; e != a.String() {
t.Fatalf("%d: expected stdout data '%s', got '%s'", i, e, a)
}
if !conn.stdout.resetCalled {
t.Fatalf("%d: expected stdout reset", i)
}
}
if testCase.Stderr != "" {
if testCase.Tty {
if conn.stderr != nil {
t.Fatalf("%d: unexpected stderr stream creation", i)
}
if localErr.String() != "" {
t.Fatalf("%d: unexpected stderr data '%s'", i, localErr)
}
} else {
if conn.stderr == nil {
t.Fatalf("%d: expected stderr stream creation", i)
}
if !conn.stderr.readCalled {
t.Fatalf("%d: expected stderr stream read", i)
}
if e, a := testCase.Stderr, localErr; e != a.String() {
t.Fatalf("%d: expected stderr data '%s', got '%s'", i, e, a)
}
if !conn.stderr.resetCalled {
t.Fatalf("%d: expected stderr reset", i)
}
}
}
if !conn.closeCalled {
t.Fatalf("%d: expected upgraded connection to get closed")
}
}
}

View File

@@ -18,6 +18,7 @@ package client
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
@@ -33,6 +34,7 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
watchjson "github.com/GoogleCloudPlatform/kubernetes/pkg/watch/json"
"github.com/golang/glog"
@@ -277,7 +279,7 @@ func (r *Request) setParam(paramName, value string) *Request {
if r.params == nil {
r.params = make(url.Values)
}
r.params[paramName] = []string{value}
r.params[paramName] = append(r.params[paramName], value)
return r
}
@@ -347,8 +349,10 @@ func (r *Request) finalURL() string {
finalURL.Path = p
query := url.Values{}
for key, value := range r.params {
query[key] = value
for key, values := range r.params {
for _, value := range values {
query.Add(key, value)
}
}
if r.namespaceSet && r.namespaceInQuery {
@@ -434,6 +438,41 @@ func (r *Request) Stream() (io.ReadCloser, error) {
return resp.Body, nil
}
// Upgrade upgrades the request so that it supports multiplexed bidirectional
// streams. The current implementation uses SPDY, but this could be replaced
// with HTTP/2 once it's available, or something else.
func (r *Request) Upgrade(config *Config, newRoundTripperFunc func(*tls.Config) httpstream.UpgradeRoundTripper) (httpstream.Connection, error) {
if r.err != nil {
return nil, r.err
}
tlsConfig, err := TLSConfigFor(config)
if err != nil {
return nil, err
}
upgradeRoundTripper := newRoundTripperFunc(tlsConfig)
wrapper, err := HTTPWrappersForConfig(config, upgradeRoundTripper)
if err != nil {
return nil, err
}
r.client = &http.Client{Transport: wrapper}
req, err := http.NewRequest(r.verb, r.finalURL(), nil)
if err != nil {
return nil, fmt.Errorf("Error creating request: %s", err)
}
resp, err := r.client.Do(req)
if err != nil {
return nil, fmt.Errorf("Error sending request: %s", err)
}
defer resp.Body.Close()
return upgradeRoundTripper.NewConnection(resp)
}
// Do formats and executes the request. Returns a Result object for easy response
// processing.
//
@@ -513,6 +552,8 @@ func (r *Request) transformResponse(resp *http.Response, req *http.Request) ([]b
}
switch {
case resp.StatusCode == http.StatusSwitchingProtocols:
// no-op, we've been upgraded
case resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusPartialContent:
if !isStatusResponse {
var err error = &UnexpectedStatusError{

View File

@@ -18,6 +18,7 @@ package client
import (
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"io"
@@ -40,6 +41,7 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
watchjson "github.com/GoogleCloudPlatform/kubernetes/pkg/watch/json"
)
@@ -151,16 +153,22 @@ func TestRequestParam(t *testing.T) {
if !api.Semantic.DeepDerivative(r.params, url.Values{"foo": []string{"a"}}) {
t.Errorf("should have set a param: %#v", r)
}
r.Param("bar", "1")
r.Param("bar", "2")
if !api.Semantic.DeepDerivative(r.params, url.Values{"foo": []string{"a"}, "bar": []string{"1", "2"}}) {
t.Errorf("should have set a param: %#v", r)
}
}
func TestRequestURI(t *testing.T) {
r := (&Request{}).Param("foo", "a")
r.Prefix("other")
r.RequestURI("/test?foo=b&a=b")
r.RequestURI("/test?foo=b&a=b&c=1&c=2")
if r.path != "/test" {
t.Errorf("path is wrong: %#v", r)
}
if !api.Semantic.DeepDerivative(r.params, url.Values{"a": []string{"b"}, "foo": []string{"b"}}) {
if !api.Semantic.DeepDerivative(r.params, url.Values{"a": []string{"b"}, "foo": []string{"b"}, "c": []string{"1", "2"}}) {
t.Errorf("should have set a param: %#v", r)
}
}
@@ -443,6 +451,122 @@ func TestRequestStream(t *testing.T) {
}
}
type fakeUpgradeConnection struct{}
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
return nil, nil
}
func (c *fakeUpgradeConnection) Close() error {
return nil
}
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
return make(chan bool)
}
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
}
type fakeUpgradeRoundTripper struct {
req *http.Request
conn httpstream.Connection
}
func (f *fakeUpgradeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
f.req = req
b := []byte{}
body := ioutil.NopCloser(bytes.NewReader(b))
resp := &http.Response{
StatusCode: 101,
Body: body,
}
return resp, nil
}
func (f *fakeUpgradeRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) {
return f.conn, nil
}
func TestRequestUpgrade(t *testing.T) {
uri, _ := url.Parse("http://localhost/")
testCases := []struct {
Request *Request
Config *Config
RoundTripper *fakeUpgradeRoundTripper
Err bool
AuthBasicHeader bool
AuthBearerHeader bool
}{
{
Request: &Request{err: errors.New("bail")},
Err: true,
},
{
Request: &Request{},
Config: &Config{
TLSClientConfig: TLSClientConfig{
CAFile: "foo",
},
Insecure: true,
},
Err: true,
},
{
Request: &Request{},
Config: &Config{
Username: "u",
Password: "p",
BearerToken: "b",
},
Err: true,
},
{
Request: NewRequest(nil, "", uri, testapi.Codec(), true, true),
Config: &Config{
Username: "u",
Password: "p",
},
AuthBasicHeader: true,
Err: false,
},
{
Request: NewRequest(nil, "", uri, testapi.Codec(), true, true),
Config: &Config{
BearerToken: "b",
},
AuthBearerHeader: true,
Err: false,
},
}
for i, testCase := range testCases {
r := testCase.Request
rt := &fakeUpgradeRoundTripper{}
expectedConn := &fakeUpgradeConnection{}
conn, err := r.Upgrade(testCase.Config, func(config *tls.Config) httpstream.UpgradeRoundTripper {
rt.conn = expectedConn
return rt
})
_ = conn
hasErr := err != nil
if hasErr != testCase.Err {
t.Errorf("%d: expected %t, got %t: %v", i, testCase.Err, hasErr, r.err)
}
if testCase.Err {
continue
}
if testCase.AuthBasicHeader && !strings.Contains(rt.req.Header.Get("Authorization"), "Basic") {
t.Errorf("%d: expected basic auth header, got: %s", rt.req.Header.Get("Authorization"))
}
if testCase.AuthBearerHeader && !strings.Contains(rt.req.Header.Get("Authorization"), "Bearer") {
t.Errorf("%d: expected bearer auth header, got: %s", rt.req.Header.Get("Authorization"))
}
if e, a := expectedConn, conn; e != a {
t.Errorf("%d: conn: expected %#v, got %#v", i, e, a)
}
}
}
func TestRequestDo(t *testing.T) {
testCases := []struct {
Request *Request

View File

@@ -355,7 +355,7 @@ func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool
}
// DeepDerivative is similar to DeepEqual except that unset fields in a1 are
// ignored (not compared). This allows we to focus on the fields that matter to
// ignored (not compared). This allows us to focus on the fields that matter to
// the semantic comparison.
//
// The unset fields include a nil pointer and an empty string.

View File

@@ -17,7 +17,9 @@ limitations under the License.
package httplog
import (
"bufio"
"fmt"
"net"
"net/http"
"runtime"
"time"
@@ -46,6 +48,10 @@ type logger interface {
// Add a layer on top of ResponseWriter, so we can track latency and error
// message sources.
//
// TODO now that we're using go-restful, we shouldn't need to be wrapping
// the http.ResponseWriter. We can recover panics from go-restful, and
// the logging value is questionable.
type respLogger struct {
status int
statusStack string
@@ -68,7 +74,7 @@ func (passthroughLogger) Addf(format string, data ...interface{}) {
// DefaultStacktracePred is the default implementation of StacktracePred.
func DefaultStacktracePred(status int) bool {
return status < http.StatusOK || status >= http.StatusBadRequest
return (status < http.StatusOK || status >= http.StatusBadRequest) && status != http.StatusSwitchingProtocols
}
// NewLogged turns a normal response writer into a logged response writer.
@@ -186,3 +192,8 @@ func (rl *respLogger) WriteHeader(status int) {
}
rl.w.WriteHeader(status)
}
// Hijack implements http.Hijacker.
func (rl *respLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return rl.w.(http.Hijacker).Hijack()
}

View File

@@ -183,7 +183,7 @@ func (f *Factory) BindFlags(flags *pflag.FlagSet) {
}
// NewKubectlCommand creates the `kubectl` command and its nested children.
func (f *Factory) NewKubectlCommand(out io.Writer) *cobra.Command {
func (f *Factory) NewKubectlCommand(in io.Reader, out, err io.Writer) *cobra.Command {
// Parent command to which all subcommands are added.
cmds := &cobra.Command{
Use: "kubectl",
@@ -211,6 +211,9 @@ Find more information at https://github.com/GoogleCloudPlatform/kubernetes.`,
cmds.AddCommand(f.NewCmdRollingUpdate(out))
cmds.AddCommand(f.NewCmdResize(out))
cmds.AddCommand(f.NewCmdExec(in, out, err))
cmds.AddCommand(f.NewCmdPortForward())
cmds.AddCommand(f.NewCmdRunContainer(out))
cmds.AddCommand(f.NewCmdStop(out))
cmds.AddCommand(f.NewCmdExposeService(out))

133
pkg/kubectl/cmd/exec.go Normal file
View File

@@ -0,0 +1,133 @@
/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd
import (
"io"
"os"
"os/signal"
"syscall"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client/remotecommand"
"github.com/GoogleCloudPlatform/kubernetes/pkg/kubectl/cmd/util"
"github.com/docker/docker/pkg/term"
"github.com/golang/glog"
"github.com/spf13/cobra"
)
func (f *Factory) NewCmdExec(cmdIn io.Reader, cmdOut, cmdErr io.Writer) *cobra.Command {
flags := &struct {
pod string
container string
stdin bool
tty bool
}{}
cmd := &cobra.Command{
Use: "exec -p <pod> -c <container> -- <command> [<args...>]",
Short: "Execute a command in a container.",
Long: `Execute a command in a container.
Examples:
$ kubectl exec -p 123456-7890 -c ruby-container date
<returns output from running 'date' in ruby-container from pod 123456-7890>
$ kubectl exec -p 123456-7890 -c ruby-container -i -t -- bash -il
<switches to raw terminal mode, sends stdin to 'bash' in ruby-container from
pod 123456-780 and sends stdout/stderr from 'bash' back to the client`,
Run: func(cmd *cobra.Command, args []string) {
if len(flags.pod) == 0 {
usageError(cmd, "<pod> is required for exec")
}
if len(args) < 1 {
usageError(cmd, "<command> is required for exec")
}
namespace, err := f.DefaultNamespace(cmd)
checkErr(err)
client, err := f.Client(cmd)
checkErr(err)
pod, err := client.Pods(namespace).Get(flags.pod)
checkErr(err)
if pod.Status.Phase != api.PodRunning {
glog.Fatalf("Unable to execute command because pod is not running. Current status=%v", pod.Status.Phase)
}
if len(flags.container) == 0 {
flags.container = pod.Spec.Containers[0].Name
}
var stdin io.Reader
if util.GetFlagBool(cmd, "stdin") {
stdin = cmdIn
if flags.tty {
if file, ok := cmdIn.(*os.File); ok {
inFd := file.Fd()
if term.IsTerminal(inFd) {
oldState, err := term.SetRawTerminal(inFd)
if err != nil {
glog.Fatal(err)
}
// this handles a clean exit, where the command finished
defer term.RestoreTerminal(inFd, oldState)
// SIGINT is handled by term.SetRawTerminal (it runs a goroutine that listens
// for SIGINT and restores the terminal before exiting)
// this handles SIGTERM
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM)
go func() {
<-sigChan
term.RestoreTerminal(inFd, oldState)
os.Exit(0)
}()
} else {
glog.Warning("Stdin is not a terminal")
}
} else {
flags.tty = false
glog.Warning("Unable to use a TTY")
}
}
}
config, err := f.ClientConfig(cmd)
checkErr(err)
req := client.RESTClient.Get().
Prefix("proxy").
Resource("minions").
Name(pod.Status.Host).
Suffix("exec", namespace, flags.pod, flags.container)
e := remotecommand.New(req, config, args, stdin, cmdOut, cmdErr, flags.tty)
err = e.Execute()
checkErr(err)
},
}
cmd.Flags().StringVarP(&flags.pod, "pod", "p", "", "Pod name")
// TODO support UID
cmd.Flags().StringVarP(&flags.container, "container", "c", "", "Container name")
cmd.Flags().BoolVarP(&flags.stdin, "stdin", "i", false, "Pass stdin to the container")
cmd.Flags().BoolVarP(&flags.tty, "tty", "t", false, "Stdin is a TTY")
return cmd
}

View File

@@ -0,0 +1,104 @@
/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd
import (
"os"
"os/signal"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client/portforward"
"github.com/golang/glog"
"github.com/spf13/cobra"
)
func (f *Factory) NewCmdPortForward() *cobra.Command {
flags := &struct {
pod string
container string
}{}
cmd := &cobra.Command{
Use: "port-forward -p <pod> [<local port>:]<remote port> [<port>...]",
Short: "Forward 1 or more local ports to a pod.",
Long: `Forward 1 or more local ports to a pod.
Examples:
$ kubectl port-forward -p mypod 5000 6000
<listens on ports 5000 and 6000 locally, forwarding data to/from ports 5000
and 6000 in the pod>
$ kubectl port-forward -p mypod 8888:5000
<listens on port 8888 locally, forwarding to 5000 in the pod>
$ kubectl port-forward -p mypod :5000
<listens on a random port locally, forwarding to 5000 in the pod>
$ kubectl port-forward -p mypod 0:5000
<listens on a random port locally, forwarding to 5000 in the pod>
`,
Run: func(cmd *cobra.Command, args []string) {
if len(flags.pod) == 0 {
usageError(cmd, "<pod> is required for exec")
}
if len(args) < 1 {
usageError(cmd, "at least 1 <port> is required for port-forward")
}
namespace, err := f.DefaultNamespace(cmd)
checkErr(err)
client, err := f.Client(cmd)
checkErr(err)
pod, err := client.Pods(namespace).Get(flags.pod)
checkErr(err)
if pod.Status.Phase != api.PodRunning {
glog.Fatalf("Unable to execute command because pod is not running. Current status=%v", pod.Status.Phase)
}
config, err := f.ClientConfig(cmd)
checkErr(err)
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt)
defer signal.Stop(signals)
stopCh := make(chan struct{}, 1)
go func() {
<-signals
close(stopCh)
}()
req := client.RESTClient.Get().
Prefix("proxy").
Resource("minions").
Name(pod.Status.Host).
Suffix("portForward", namespace, flags.pod)
pf, err := portforward.New(req, config, args, stopCh)
checkErr(err)
err = pf.ForwardPorts()
checkErr(err)
},
}
cmd.Flags().StringVarP(&flags.pod, "pod", "p", "", "Pod name")
// TODO support UID
return cmd
}

View File

@@ -198,6 +198,127 @@ func (d *dockerContainerCommandRunner) RunInContainer(containerID string, cmd []
return buf.Bytes(), <-errChan
}
// ExecInContainer uses nsenter to run the command inside the container identified by containerID.
//
// TODO:
// - match cgroups of container
// - should we support `docker exec`?
// - should we support nsenter in a container, running with elevated privs and --pid=host?
func (d *dockerContainerCommandRunner) ExecInContainer(containerId string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool) error {
container, err := d.client.InspectContainer(containerId)
if err != nil {
return err
}
if !container.State.Running {
return fmt.Errorf("container not running (%s)", container)
}
containerPid := container.State.Pid
// TODO what if the container doesn't have `env`???
args := []string{"-t", fmt.Sprintf("%d", containerPid), "-m", "-i", "-u", "-n", "-p", "--", "env", "-i"}
args = append(args, fmt.Sprintf("HOSTNAME=%s", container.Config.Hostname))
args = append(args, container.Config.Env...)
args = append(args, cmd...)
glog.Infof("ARGS %#v", args)
command := exec.Command("nsenter", args...)
// TODO use exec.LookPath
if tty {
p, err := StartPty(command)
if err != nil {
return err
}
defer p.Close()
// make sure to close the stdout stream
defer stdout.Close()
if stdin != nil {
go io.Copy(p, stdin)
}
if stdout != nil {
go io.Copy(stdout, p)
}
return command.Wait()
} else {
cp := func(dst io.WriteCloser, src io.Reader, closeDst bool) {
defer func() {
if closeDst {
dst.Close()
}
}()
io.Copy(dst, src)
}
if stdin != nil {
inPipe, err := command.StdinPipe()
if err != nil {
return err
}
go func() {
cp(inPipe, stdin, false)
inPipe.Close()
}()
}
if stdout != nil {
outPipe, err := command.StdoutPipe()
if err != nil {
return err
}
go cp(stdout, outPipe, true)
}
if stderr != nil {
errPipe, err := command.StderrPipe()
if err != nil {
return err
}
go cp(stderr, errPipe, true)
}
return command.Run()
}
}
// PortForward executes socat in the pod's network namespace and copies
// data between stream (representing the user's local connection on their
// computer) and the specified port in the container.
//
// TODO:
// - match cgroups of container
// - should we support nsenter + socat on the host? (current impl)
// - should we support nsenter + socat in a container, running with elevated privs and --pid=host?
func (d *dockerContainerCommandRunner) PortForward(podInfraContainerID string, port uint16, stream io.ReadWriteCloser) error {
container, err := d.client.InspectContainer(podInfraContainerID)
if err != nil {
return err
}
if !container.State.Running {
return fmt.Errorf("container not running (%s)", container)
}
containerPid := container.State.Pid
// TODO use exec.LookPath for socat / what if the host doesn't have it???
args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)}
// TODO use exec.LookPath
command := exec.Command("nsenter", args...)
in, err := command.StdinPipe()
if err != nil {
return err
}
out, err := command.StdoutPipe()
if err != nil {
return err
}
go io.Copy(in, stream)
go io.Copy(stream, out)
return command.Run()
}
// NewDockerContainerCommandRunner creates a ContainerCommandRunner which uses nsinit to run a command
// inside a container.
func NewDockerContainerCommandRunner(client DockerInterface) ContainerCommandRunner {
@@ -690,4 +811,6 @@ func ConnectToDockerOrDie(dockerEndpoint string) DockerInterface {
type ContainerCommandRunner interface {
RunInContainer(containerID string, cmd []string) ([]byte, error)
GetDockerServerVersion() ([]uint, error)
ExecInContainer(containerID string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
PortForward(podInfraContainerID string, port uint16, stream io.ReadWriteCloser) error
}

View File

@@ -0,0 +1,30 @@
// +build linux
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package dockertools
import (
"os"
"os/exec"
"github.com/kr/pty"
)
func StartPty(c *exec.Cmd) (*os.File, error) {
return pty.Start(c)
}

View File

@@ -0,0 +1,28 @@
// +build !linux
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package dockertools
import (
"os"
"os/exec"
)
func StartPty(c *exec.Cmd) (pty *os.File, err error) {
return nil, nil
}

View File

@@ -86,7 +86,8 @@ func NewMainKubelet(
clusterDomain string,
clusterDNS net.IP,
masterServiceNamespace string,
volumePlugins []volume.Plugin) (*Kubelet, error) {
volumePlugins []volume.Plugin,
streamingConnectionIdleTimeout time.Duration) (*Kubelet, error) {
if rootDirectory == "" {
return nil, fmt.Errorf("invalid root directory %q", rootDirectory)
}
@@ -104,28 +105,29 @@ func NewMainKubelet(
serviceLister := &cache.StoreToServiceLister{serviceStore}
klet := &Kubelet{
hostname: hostname,
dockerClient: dockerClient,
etcdClient: etcdClient,
kubeClient: kubeClient,
rootDirectory: rootDirectory,
resyncInterval: resyncInterval,
podInfraContainerImage: podInfraContainerImage,
podWorkers: newPodWorkers(),
dockerIDToRef: map[dockertools.DockerID]*api.ObjectReference{},
runner: dockertools.NewDockerContainerCommandRunner(dockerClient),
httpClient: &http.Client{},
pullQPS: pullQPS,
pullBurst: pullBurst,
minimumGCAge: minimumGCAge,
maxContainerCount: maxContainerCount,
sourceReady: sourceReady,
clusterDomain: clusterDomain,
clusterDNS: clusterDNS,
serviceLister: serviceLister,
masterServiceNamespace: masterServiceNamespace,
prober: newProbeHolder(),
readiness: newReadinessStates(),
hostname: hostname,
dockerClient: dockerClient,
etcdClient: etcdClient,
kubeClient: kubeClient,
rootDirectory: rootDirectory,
resyncInterval: resyncInterval,
podInfraContainerImage: podInfraContainerImage,
podWorkers: newPodWorkers(),
dockerIDToRef: map[dockertools.DockerID]*api.ObjectReference{},
runner: dockertools.NewDockerContainerCommandRunner(dockerClient),
httpClient: &http.Client{},
pullQPS: pullQPS,
pullBurst: pullBurst,
minimumGCAge: minimumGCAge,
maxContainerCount: maxContainerCount,
sourceReady: sourceReady,
clusterDomain: clusterDomain,
clusterDNS: clusterDNS,
serviceLister: serviceLister,
masterServiceNamespace: masterServiceNamespace,
prober: newProbeHolder(),
readiness: newReadinessStates(),
streamingConnectionIdleTimeout: streamingConnectionIdleTimeout,
}
if err := klet.setupDataDirs(); err != nil {
@@ -207,6 +209,10 @@ type Kubelet struct {
prober probeHolder
// container readiness state holder
readiness *readinessStates
// how long to keep idle streaming command execution/port forwarding
// connections open before terminating them
streamingConnectionIdleTimeout time.Duration
}
// getRootDir returns the full path to the directory under which kubelet can
@@ -1686,6 +1692,40 @@ func (kl *Kubelet) RunInContainer(podFullName string, uid types.UID, container s
return kl.runner.RunInContainer(dockerContainer.ID, cmd)
}
// ExecInContainer executes a command in a container, connecting the supplied
// stdin/stdout/stderr to the command's IO streams.
func (kl *Kubelet) ExecInContainer(podFullName string, uid types.UID, container string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool) error {
if kl.runner == nil {
return fmt.Errorf("no runner specified.")
}
dockerContainers, err := dockertools.GetKubeletDockerContainers(kl.dockerClient, false)
if err != nil {
return err
}
dockerContainer, found, _ := dockerContainers.FindPodContainer(podFullName, uid, container)
if !found {
return fmt.Errorf("container not found (%q)", container)
}
return kl.runner.ExecInContainer(dockerContainer.ID, cmd, stdin, stdout, stderr, tty)
}
// PortForward connects to the pod's port and copies data between the port
// and the stream.
func (kl *Kubelet) PortForward(podFullName string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
if kl.runner == nil {
return fmt.Errorf("no runner specified.")
}
dockerContainers, err := dockertools.GetKubeletDockerContainers(kl.dockerClient, false)
if err != nil {
return err
}
podInfraContainer, found, _ := dockerContainers.FindPodContainer(podFullName, uid, dockertools.PodInfraContainerName)
if !found {
return fmt.Errorf("Unable to find pod infra container for pod %s, uid %v", podFullName, uid)
}
return kl.runner.PortForward(podInfraContainer.ID, port, stream)
}
// BirthCry sends an event that the kubelet has started up.
func (kl *Kubelet) BirthCry() {
// Make an event that kubelet restarted.
@@ -1699,3 +1739,7 @@ func (kl *Kubelet) BirthCry() {
}
record.Eventf(ref, "starting", "Starting kubelet.")
}
func (kl *Kubelet) StreamingConnectionIdleTimeout() time.Duration {
return kl.streamingConnectionIdleTimeout
}

View File

@@ -17,7 +17,9 @@ limitations under the License.
package kubelet
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
@@ -1486,9 +1488,15 @@ func TestGetContainerInfoWithNoMatchingContainers(t *testing.T) {
}
type fakeContainerCommandRunner struct {
Cmd []string
ID string
E error
Cmd []string
ID string
E error
Stdin io.Reader
Stdout io.WriteCloser
Stderr io.WriteCloser
TTY bool
Port uint16
Stream io.ReadWriteCloser
}
func (f *fakeContainerCommandRunner) RunInContainer(id string, cmd []string) ([]byte, error) {
@@ -1501,6 +1509,23 @@ func (f *fakeContainerCommandRunner) GetDockerServerVersion() ([]uint, error) {
return nil, nil
}
func (f *fakeContainerCommandRunner) ExecInContainer(id string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
f.Cmd = cmd
f.ID = id
f.Stdin = in
f.Stdout = out
f.Stderr = err
f.TTY = tty
return f.E
}
func (f *fakeContainerCommandRunner) PortForward(podInfraContainerID string, port uint16, stream io.ReadWriteCloser) error {
f.ID = podInfraContainerID
f.Port = port
f.Stream = stream
return nil
}
func TestRunInContainerNoSuchPod(t *testing.T) {
fakeCommandRunner := fakeContainerCommandRunner{}
kubelet, fakeDocker := newTestKubelet(t)
@@ -2805,5 +2830,252 @@ func TestGetPodReadyCondition(t *testing.T) {
t.Errorf("On test case %v, expected:\n%+v\ngot\n%+v\n", i, test.expected, condition)
}
}
}
func TestExecInContainerNoSuchPod(t *testing.T) {
fakeCommandRunner := fakeContainerCommandRunner{}
kubelet, fakeDocker := newTestKubelet(t)
fakeDocker.ContainerList = []docker.APIContainers{}
kubelet.runner = &fakeCommandRunner
podName := "podFoo"
podNamespace := "etcd"
containerName := "containerFoo"
err := kubelet.ExecInContainer(
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{Name: podName, Namespace: podNamespace}}),
"",
containerName,
[]string{"ls"},
nil,
nil,
nil,
false,
)
if err == nil {
t.Fatal("unexpected non-error")
}
if fakeCommandRunner.ID != "" {
t.Fatal("unexpected invocation of runner.ExecInContainer")
}
}
func TestExecInContainerNoSuchContainer(t *testing.T) {
fakeCommandRunner := fakeContainerCommandRunner{}
kubelet, fakeDocker := newTestKubelet(t)
kubelet.runner = &fakeCommandRunner
podName := "podFoo"
podNamespace := "etcd"
containerID := "containerFoo"
fakeDocker.ContainerList = []docker.APIContainers{
{
ID: "notfound",
Names: []string{"/k8s_notfound_" + podName + "." + podNamespace + ".test_12345678_42"},
},
}
err := kubelet.ExecInContainer(
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{
UID: "12345678",
Name: podName,
Namespace: podNamespace,
Annotations: map[string]string{ConfigSourceAnnotationKey: "test"},
}}),
"",
containerID,
[]string{"ls"},
nil,
nil,
nil,
false,
)
if err == nil {
t.Fatal("unexpected non-error")
}
if fakeCommandRunner.ID != "" {
t.Fatal("unexpected invocation of runner.ExecInContainer")
}
}
type fakeReadWriteCloser struct{}
func (f *fakeReadWriteCloser) Write(data []byte) (int, error) {
return 0, nil
}
func (f *fakeReadWriteCloser) Read(data []byte) (int, error) {
return 0, nil
}
func (f *fakeReadWriteCloser) Close() error {
return nil
}
func TestExecInContainer(t *testing.T) {
fakeCommandRunner := fakeContainerCommandRunner{}
kubelet, fakeDocker := newTestKubelet(t)
kubelet.runner = &fakeCommandRunner
podName := "podFoo"
podNamespace := "etcd"
containerID := "containerFoo"
command := []string{"ls"}
stdin := &bytes.Buffer{}
stdout := &fakeReadWriteCloser{}
stderr := &fakeReadWriteCloser{}
tty := true
fakeDocker.ContainerList = []docker.APIContainers{
{
ID: containerID,
Names: []string{"/k8s_" + containerID + "_" + podName + "." + podNamespace + ".test_12345678_42"},
},
}
err := kubelet.ExecInContainer(
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{
UID: "12345678",
Name: podName,
Namespace: podNamespace,
Annotations: map[string]string{ConfigSourceAnnotationKey: "test"},
}}),
"",
containerID,
[]string{"ls"},
stdin,
stdout,
stderr,
tty,
)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if e, a := containerID, fakeCommandRunner.ID; e != a {
t.Fatalf("container id: expected %s, got %s", e, a)
}
if e, a := command, fakeCommandRunner.Cmd; !reflect.DeepEqual(e, a) {
t.Fatalf("command: expected '%v', got '%v'", e, a)
}
if e, a := stdin, fakeCommandRunner.Stdin; e != a {
t.Fatalf("stdin: expected %#v, got %#v", e, a)
}
if e, a := stdout, fakeCommandRunner.Stdout; e != a {
t.Fatalf("stdout: expected %#v, got %#v", e, a)
}
if e, a := stderr, fakeCommandRunner.Stderr; e != a {
t.Fatalf("stderr: expected %#v, got %#v", e, a)
}
if e, a := tty, fakeCommandRunner.TTY; e != a {
t.Fatalf("tty: expected %t, got %t", e, a)
}
}
func TestPortForwardNoSuchPod(t *testing.T) {
fakeCommandRunner := fakeContainerCommandRunner{}
kubelet, fakeDocker := newTestKubelet(t)
fakeDocker.ContainerList = []docker.APIContainers{}
kubelet.runner = &fakeCommandRunner
podName := "podFoo"
podNamespace := "etcd"
var port uint16 = 5000
err := kubelet.PortForward(
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{Name: podName, Namespace: podNamespace}}),
"",
port,
nil,
)
if err == nil {
t.Fatal("unexpected non-error")
}
if fakeCommandRunner.ID != "" {
t.Fatal("unexpected invocation of runner.PortForward")
}
}
func TestPortForwardNoSuchContainer(t *testing.T) {
fakeCommandRunner := fakeContainerCommandRunner{}
kubelet, fakeDocker := newTestKubelet(t)
kubelet.runner = &fakeCommandRunner
podName := "podFoo"
podNamespace := "etcd"
var port uint16 = 5000
fakeDocker.ContainerList = []docker.APIContainers{
{
ID: "notfound",
Names: []string{"/k8s_notfound_" + podName + "." + podNamespace + ".test_12345678_42"},
},
}
err := kubelet.PortForward(
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{
UID: "12345678",
Name: podName,
Namespace: podNamespace,
Annotations: map[string]string{ConfigSourceAnnotationKey: "test"},
}}),
"",
port,
nil,
)
if err == nil {
t.Fatal("unexpected non-error")
}
if fakeCommandRunner.ID != "" {
t.Fatal("unexpected invocation of runner.PortForward")
}
}
func TestPortForward(t *testing.T) {
fakeCommandRunner := fakeContainerCommandRunner{}
kubelet, fakeDocker := newTestKubelet(t)
kubelet.runner = &fakeCommandRunner
podName := "podFoo"
podNamespace := "etcd"
containerID := "containerFoo"
var port uint16 = 5000
stream := &fakeReadWriteCloser{}
infraContainerID := "infra"
kubelet.podInfraContainerImage = "POD"
fakeDocker.ContainerList = []docker.APIContainers{
{
ID: infraContainerID,
Names: []string{"/k8s_" + kubelet.podInfraContainerImage + "_" + podName + "." + podNamespace + ".test_12345678_42"},
},
{
ID: containerID,
Names: []string{"/k8s_" + containerID + "_" + podName + "." + podNamespace + ".test_12345678_42"},
},
}
err := kubelet.PortForward(
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{
UID: "12345678",
Name: podName,
Namespace: podNamespace,
Annotations: map[string]string{ConfigSourceAnnotationKey: "test"},
}}),
"",
port,
stream,
)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if e, a := infraContainerID, fakeCommandRunner.ID; e != a {
t.Fatalf("container id: expected %s, got %s", e, a)
}
if e, a := port, fakeCommandRunner.Port; e != a {
t.Fatalf("port: expected %v, got %v", e, a)
}
if e, a := stream, fakeCommandRunner.Stream; e != a {
t.Fatalf("stream: expected %v, got %v", e, a)
}
}

View File

@@ -27,6 +27,7 @@ import (
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
@@ -34,6 +35,8 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/types"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
"github.com/golang/glog"
"github.com/google/cadvisor/info"
)
@@ -69,8 +72,11 @@ type HostInterface interface {
GetPodByName(namespace, name string) (*api.BoundPod, bool)
GetPodStatus(name string, uid types.UID) (api.PodStatus, error)
RunInContainer(name string, uid types.UID, container string, cmd []string) ([]byte, error)
ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
GetKubeletContainerLogs(podFullName, containerName, tail string, follow bool, stdout, stderr io.Writer) error
ServeLogs(w http.ResponseWriter, req *http.Request)
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
StreamingConnectionIdleTimeout() time.Duration
}
// NewServer initializes and configures a kubelet.Server object to handle HTTP requests.
@@ -99,6 +105,8 @@ func (s *Server) InstallDefaultHandlers() {
// InstallDeguggingHandlers registers the HTTP request patterns that serve logs or run commands/containers
func (s *Server) InstallDebuggingHandlers() {
s.mux.HandleFunc("/run/", s.handleRun)
s.mux.HandleFunc("/exec/", s.handleExec)
s.mux.HandleFunc("/portForward/", s.handlePortForward)
s.mux.HandleFunc("/logs/", s.handleLogs)
s.mux.HandleFunc("/containerLogs/", s.handleContainerLogs)
@@ -301,6 +309,28 @@ func (s *Server) handleSpec(w http.ResponseWriter, req *http.Request) {
w.Write(data)
}
func parseContainerCoordinates(path string) (namespace, pod string, uid types.UID, container string, err error) {
parts := strings.Split(path, "/")
if len(parts) == 5 {
namespace = parts[2]
pod = parts[3]
container = parts[4]
return
}
if len(parts) == 6 {
namespace = parts[2]
pod = parts[3]
uid = types.UID(parts[4])
container = parts[5]
return
}
err = fmt.Errorf("Unexpected path %s. Expected /.../.../<namespace>/<pod>/<container> or /.../.../<namespace>/<pod>/<uid>/<container>", path)
return
}
// handleRun handles requests to run a command inside a container.
func (s *Server) handleRun(w http.ResponseWriter, req *http.Request) {
u, err := url.ParseRequestURI(req.RequestURI)
@@ -308,20 +338,9 @@ func (s *Server) handleRun(w http.ResponseWriter, req *http.Request) {
s.error(w, err)
return
}
parts := strings.Split(u.Path, "/")
var podNamespace, podID, container string
var uid types.UID
if len(parts) == 5 {
podNamespace = parts[2]
podID = parts[3]
container = parts[4]
} else if len(parts) == 6 {
podNamespace = parts[2]
podID = parts[3]
uid = types.UID(parts[4])
container = parts[5]
} else {
http.Error(w, "Unexpected path for command running", http.StatusBadRequest)
podNamespace, podID, uid, container, err := parseContainerCoordinates(u.Path)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
pod, ok := s.host.GetPodByName(podNamespace, podID)
@@ -339,6 +358,227 @@ func (s *Server) handleRun(w http.ResponseWriter, req *http.Request) {
w.Write(data)
}
// handleExec handles requests to run a command inside a container.
func (s *Server) handleExec(w http.ResponseWriter, req *http.Request) {
u, err := url.ParseRequestURI(req.RequestURI)
if err != nil {
s.error(w, err)
return
}
podNamespace, podID, uid, container, err := parseContainerCoordinates(u.Path)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
pod, ok := s.host.GetPodByName(podNamespace, podID)
if !ok {
http.Error(w, "Pod does not exist", http.StatusNotFound)
return
}
req.ParseForm()
// start at 1 for error stream
expectedStreams := 1
if req.FormValue(api.ExecStdinParam) == "1" {
expectedStreams++
}
if req.FormValue(api.ExecStdoutParam) == "1" {
expectedStreams++
}
tty := req.FormValue(api.ExecTTYParam) == "1"
if !tty && req.FormValue(api.ExecStderrParam) == "1" {
expectedStreams++
}
if expectedStreams == 1 {
http.Error(w, "You must specify at least 1 of stdin, stdout, stderr", http.StatusBadRequest)
return
}
streamCh := make(chan httpstream.Stream)
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error {
streamCh <- stream
return nil
})
// from this point on, we can no longer call methods on w
if conn == nil {
// The upgrader is responsible for notifying the client of any errors that
// occurred during upgrading. All we can do is return here at this point
// if we weren't successful in upgrading.
return
}
defer conn.Close()
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
// TODO find a good default timeout value
// TODO make it configurable?
expired := time.NewTimer(2 * time.Second)
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
receivedStreams := 0
WaitForStreams:
for {
select {
case stream := <-streamCh:
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
errorStream = stream
defer errorStream.Reset()
receivedStreams++
case api.StreamTypeStdin:
stdinStream = stream
receivedStreams++
case api.StreamTypeStdout:
stdoutStream = stream
receivedStreams++
case api.StreamTypeStderr:
stderrStream = stream
receivedStreams++
default:
glog.Errorf("Unexpected stream type: '%s'", streamType)
}
if receivedStreams == expectedStreams {
break WaitForStreams
}
case <-expired.C:
// TODO find a way to return the error to the user. Maybe use a separate
// stream to report errors?
glog.Error("Timed out waiting for client to create streams")
return
}
}
if stdinStream != nil {
// close our half of the input stream, since we won't be writing to it
stdinStream.Close()
}
err = s.host.ExecInContainer(GetPodFullName(pod), uid, container, u.Query()[api.ExecCommandParamm], stdinStream, stdoutStream, stderrStream, tty)
if err != nil {
msg := fmt.Sprintf("Error executing command in container: %v", err)
glog.Error(msg)
errorStream.Write([]byte(msg))
}
}
func parsePodCoordinates(path string) (namespace, pod string, uid types.UID, err error) {
parts := strings.Split(path, "/")
if len(parts) == 4 {
namespace = parts[2]
pod = parts[3]
return
}
if len(parts) == 5 {
namespace = parts[2]
pod = parts[3]
uid = types.UID(parts[4])
return
}
err = fmt.Errorf("Unexpected path %s. Expected /.../.../<namespace>/<pod> or /.../.../<namespace>/<pod>/<uid>", path)
return
}
func (s *Server) handlePortForward(w http.ResponseWriter, req *http.Request) {
u, err := url.ParseRequestURI(req.RequestURI)
if err != nil {
s.error(w, err)
return
}
podNamespace, podID, uid, err := parsePodCoordinates(u.Path)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
pod, ok := s.host.GetPodByName(podNamespace, podID)
if !ok {
http.Error(w, "Pod does not exist", http.StatusNotFound)
return
}
streamChan := make(chan httpstream.Stream, 1)
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error {
portString := stream.Headers().Get(api.PortHeader)
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("Unable to parse '%s' as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("Port '%d' must be greater than 0", port)
}
streamChan <- stream
return nil
})
if conn == nil {
return
}
defer conn.Close()
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
var dataStreamLock sync.Mutex
dataStreamChans := make(map[string]chan httpstream.Stream)
Loop:
for {
select {
case <-conn.CloseChan():
break Loop
case stream := <-streamChan:
streamType := stream.Headers().Get(api.StreamType)
port := stream.Headers().Get(api.PortHeader)
dataStreamLock.Lock()
switch streamType {
case "error":
ch := make(chan httpstream.Stream)
dataStreamChans[port] = ch
go waitForPortForwardDataStreamAndRun(GetPodFullName(pod), uid, stream, ch, s.host)
case "data":
ch, ok := dataStreamChans[port]
if ok {
ch <- stream
delete(dataStreamChans, port)
} else {
glog.Errorf("Unable to locate data stream channel for port %s", port)
}
default:
glog.Errorf("streamType header must be 'error' or 'data', got: '%s'", streamType)
stream.Reset()
}
dataStreamLock.Unlock()
}
}
}
func waitForPortForwardDataStreamAndRun(pod string, uid types.UID, errorStream httpstream.Stream, dataStreamChan chan httpstream.Stream, host HostInterface) {
defer errorStream.Reset()
var dataStream httpstream.Stream
select {
case dataStream = <-dataStreamChan:
case <-time.After(1 * time.Second):
errorStream.Write([]byte("Timed out waiting for data stream"))
//TODO delete from dataStreamChans[port]
return
}
portString := dataStream.Headers().Get(api.PortHeader)
port, _ := strconv.ParseUint(portString, 10, 16)
err := host.PortForward(pod, uid, uint16(port), dataStream)
if err != nil {
msg := fmt.Errorf("Error forwarding port %d to pod %s, uid %v: %v", port, pod, uid, err)
glog.Error(msg)
errorStream.Write([]byte(msg.Error()))
}
}
// ServeHTTP responds to HTTP requests on the Kubelet.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
defer httplog.NewLogged(req, &w).StacktraceWhen(
@@ -347,6 +587,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.StatusMovedPermanently,
http.StatusTemporaryRedirect,
http.StatusNotFound,
http.StatusSwitchingProtocols,
),
).Log()
s.mux.ServeHTTP(w, req)

View File

@@ -46,35 +46,36 @@ const defaultRootDir = "/var/lib/kubelet"
// KubeletServer encapsulates all of the parameters necessary for starting up
// a kubelet. These can either be set via command line or directly.
type KubeletServer struct {
Config string
SyncFrequency time.Duration
FileCheckFrequency time.Duration
HTTPCheckFrequency time.Duration
ManifestURL string
EnableServer bool
Address util.IP
Port uint
HostnameOverride string
PodInfraContainerImage string
DockerEndpoint string
EtcdServerList util.StringList
EtcdConfigFile string
RootDirectory string
AllowPrivileged bool
RegistryPullQPS float64
RegistryBurst int
RunOnce bool
EnableDebuggingHandlers bool
MinimumGCAge time.Duration
MaxContainerCount int
AuthPath string
CAdvisorPort uint
OOMScoreAdj int
APIServerList util.StringList
ClusterDomain string
MasterServiceNamespace string
ClusterDNS util.IP
ReallyCrashForTesting bool
Config string
SyncFrequency time.Duration
FileCheckFrequency time.Duration
HTTPCheckFrequency time.Duration
ManifestURL string
EnableServer bool
Address util.IP
Port uint
HostnameOverride string
PodInfraContainerImage string
DockerEndpoint string
EtcdServerList util.StringList
EtcdConfigFile string
RootDirectory string
AllowPrivileged bool
RegistryPullQPS float64
RegistryBurst int
RunOnce bool
EnableDebuggingHandlers bool
MinimumGCAge time.Duration
MaxContainerCount int
AuthPath string
CAdvisorPort uint
OOMScoreAdj int
APIServerList util.StringList
ClusterDomain string
MasterServiceNamespace string
ClusterDNS util.IP
ReallyCrashForTesting bool
StreamingConnectionIdleTimeout time.Duration
}
// NewKubeletServer will create a new KubeletServer with default values.
@@ -149,6 +150,7 @@ func (s *KubeletServer) AddFlags(fs *pflag.FlagSet) {
fs.StringVar(&s.MasterServiceNamespace, "master_service_namespace", s.MasterServiceNamespace, "The namespace from which the kubernetes master services should be injected into pods")
fs.Var(&s.ClusterDNS, "cluster_dns", "IP address for a cluster DNS server. If set, kubelet will configure all containers to use this for DNS resolution in addition to the host's DNS servers")
fs.BoolVar(&s.ReallyCrashForTesting, "really_crash_for_testing", s.ReallyCrashForTesting, "If true, crash with panics more often.")
fs.DurationVar(&s.StreamingConnectionIdleTimeout, "streaming_connection_idle_timeout", 0, "Maximum time a streaming connection can be idle before the connection is automatically closed. Example: '5m'")
}
// Run runs the specified KubeletServer. This should never exit.
@@ -184,32 +186,33 @@ func (s *KubeletServer) Run(_ []string) error {
credentialprovider.SetPreferredDockercfgPath(s.RootDirectory)
kcfg := KubeletConfig{
Address: s.Address,
AllowPrivileged: s.AllowPrivileged,
HostnameOverride: s.HostnameOverride,
RootDirectory: s.RootDirectory,
ConfigFile: s.Config,
ManifestURL: s.ManifestURL,
FileCheckFrequency: s.FileCheckFrequency,
HTTPCheckFrequency: s.HTTPCheckFrequency,
PodInfraContainerImage: s.PodInfraContainerImage,
SyncFrequency: s.SyncFrequency,
RegistryPullQPS: s.RegistryPullQPS,
RegistryBurst: s.RegistryBurst,
MinimumGCAge: s.MinimumGCAge,
MaxContainerCount: s.MaxContainerCount,
ClusterDomain: s.ClusterDomain,
ClusterDNS: s.ClusterDNS,
Runonce: s.RunOnce,
Port: s.Port,
CAdvisorPort: s.CAdvisorPort,
EnableServer: s.EnableServer,
EnableDebuggingHandlers: s.EnableDebuggingHandlers,
DockerClient: dockertools.ConnectToDockerOrDie(s.DockerEndpoint),
KubeClient: client,
EtcdClient: kubelet.EtcdClientOrDie(s.EtcdServerList, s.EtcdConfigFile),
MasterServiceNamespace: s.MasterServiceNamespace,
VolumePlugins: ProbeVolumePlugins(),
Address: s.Address,
AllowPrivileged: s.AllowPrivileged,
HostnameOverride: s.HostnameOverride,
RootDirectory: s.RootDirectory,
ConfigFile: s.Config,
ManifestURL: s.ManifestURL,
FileCheckFrequency: s.FileCheckFrequency,
HTTPCheckFrequency: s.HTTPCheckFrequency,
PodInfraContainerImage: s.PodInfraContainerImage,
SyncFrequency: s.SyncFrequency,
RegistryPullQPS: s.RegistryPullQPS,
RegistryBurst: s.RegistryBurst,
MinimumGCAge: s.MinimumGCAge,
MaxContainerCount: s.MaxContainerCount,
ClusterDomain: s.ClusterDomain,
ClusterDNS: s.ClusterDNS,
Runonce: s.RunOnce,
Port: s.Port,
CAdvisorPort: s.CAdvisorPort,
EnableServer: s.EnableServer,
EnableDebuggingHandlers: s.EnableDebuggingHandlers,
DockerClient: dockertools.ConnectToDockerOrDie(s.DockerEndpoint),
KubeClient: client,
EtcdClient: kubelet.EtcdClientOrDie(s.EtcdServerList, s.EtcdConfigFile),
MasterServiceNamespace: s.MasterServiceNamespace,
VolumePlugins: ProbeVolumePlugins(),
StreamingConnectionIdleTimeout: s.StreamingConnectionIdleTimeout,
}
RunKubelet(&kcfg)
@@ -368,33 +371,34 @@ func makePodSourceConfig(kc *KubeletConfig) *config.PodConfig {
// KubeletConfig is all of the parameters necessary for running a kubelet.
// TODO: This should probably be merged with KubeletServer. The extra object is a consequence of refactoring.
type KubeletConfig struct {
EtcdClient tools.EtcdClient
KubeClient *client.Client
DockerClient dockertools.DockerInterface
CAdvisorPort uint
Address util.IP
AllowPrivileged bool
HostnameOverride string
RootDirectory string
ConfigFile string
ManifestURL string
FileCheckFrequency time.Duration
HTTPCheckFrequency time.Duration
Hostname string
PodInfraContainerImage string
SyncFrequency time.Duration
RegistryPullQPS float64
RegistryBurst int
MinimumGCAge time.Duration
MaxContainerCount int
ClusterDomain string
ClusterDNS util.IP
EnableServer bool
EnableDebuggingHandlers bool
Port uint
Runonce bool
MasterServiceNamespace string
VolumePlugins []volume.Plugin
EtcdClient tools.EtcdClient
KubeClient *client.Client
DockerClient dockertools.DockerInterface
CAdvisorPort uint
Address util.IP
AllowPrivileged bool
HostnameOverride string
RootDirectory string
ConfigFile string
ManifestURL string
FileCheckFrequency time.Duration
HTTPCheckFrequency time.Duration
Hostname string
PodInfraContainerImage string
SyncFrequency time.Duration
RegistryPullQPS float64
RegistryBurst int
MinimumGCAge time.Duration
MaxContainerCount int
ClusterDomain string
ClusterDNS util.IP
EnableServer bool
EnableDebuggingHandlers bool
Port uint
Runonce bool
MasterServiceNamespace string
VolumePlugins []volume.Plugin
StreamingConnectionIdleTimeout time.Duration
}
func createAndInitKubelet(kc *KubeletConfig, pc *config.PodConfig) (*kubelet.Kubelet, error) {
@@ -417,7 +421,8 @@ func createAndInitKubelet(kc *KubeletConfig, pc *config.PodConfig) (*kubelet.Kub
kc.ClusterDomain,
net.IP(kc.ClusterDNS),
kc.MasterServiceNamespace,
kc.VolumePlugins)
kc.VolumePlugins,
kc.StreamingConnectionIdleTimeout)
if err != nil {
return nil, err

View File

@@ -25,25 +25,32 @@ import (
"net/http/httptest"
"net/http/httputil"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/types"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
"github.com/google/cadvisor/info"
)
type fakeKubelet struct {
podByNameFunc func(namespace, name string) (*api.BoundPod, bool)
statusFunc func(name string) (api.PodStatus, error)
containerInfoFunc func(podFullName string, uid types.UID, containerName string, req *info.ContainerInfoRequest) (*info.ContainerInfo, error)
rootInfoFunc func(query *info.ContainerInfoRequest) (*info.ContainerInfo, error)
machineInfoFunc func() (*info.MachineInfo, error)
boundPodsFunc func() ([]api.BoundPod, error)
logFunc func(w http.ResponseWriter, req *http.Request)
runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error)
dockerVersionFunc func() ([]uint, error)
containerLogsFunc func(podFullName, containerName, tail string, follow bool, stdout, stderr io.Writer) error
podByNameFunc func(namespace, name string) (*api.BoundPod, bool)
statusFunc func(name string) (api.PodStatus, error)
containerInfoFunc func(podFullName string, uid types.UID, containerName string, req *info.ContainerInfoRequest) (*info.ContainerInfo, error)
rootInfoFunc func(query *info.ContainerInfoRequest) (*info.ContainerInfo, error)
machineInfoFunc func() (*info.MachineInfo, error)
boundPodsFunc func() ([]api.BoundPod, error)
logFunc func(w http.ResponseWriter, req *http.Request)
runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error)
dockerVersionFunc func() ([]uint, error)
execFunc func(pod string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
portForwardFunc func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
containerLogsFunc func(podFullName, containerName, tail string, follow bool, stdout, stderr io.Writer) error
streamingConnectionIdleTimeoutFunc func() time.Duration
}
func (fk *fakeKubelet) GetPodByName(namespace, name string) (*api.BoundPod, bool) {
@@ -86,6 +93,18 @@ func (fk *fakeKubelet) RunInContainer(podFullName string, uid types.UID, contain
return fk.runFunc(podFullName, uid, containerName, cmd)
}
func (fk *fakeKubelet) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
return fk.execFunc(name, uid, container, cmd, in, out, err, tty)
}
func (fk *fakeKubelet) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
return fk.portForwardFunc(name, uid, port, stream)
}
func (fk *fakeKubelet) StreamingConnectionIdleTimeout() time.Duration {
return fk.streamingConnectionIdleTimeoutFunc()
}
type serverTestFramework struct {
updateChan chan interface{}
updateReader *channelReader
@@ -542,3 +561,503 @@ func TestContainerLogsWithFollow(t *testing.T) {
t.Errorf("Expected: '%v', got: '%v'", output, result)
}
}
func TestServeExecInContainerIdleTimeout(t *testing.T) {
fw := newServerTest()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 100 * time.Millisecond
}
idleSuccess := make(chan struct{})
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
select {
case <-idleSuccess:
case <-time.After(150 * time.Millisecond):
t.Fatalf("execFunc timed out waiting for idle timeout")
}
return nil
}
podNamespace := "other"
podName := "foo"
expectedContainerName := "baz"
url := fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?c=ls&c=-a&" + api.ExecStdinParam + "=1"
upgradeRoundTripper := spdy.NewRoundTripper(nil)
c := &http.Client{Transport: upgradeRoundTripper}
resp, err := c.Get(url)
if err != nil {
t.Fatalf("Got error GETing: %v", err)
}
defer resp.Body.Close()
conn, err := upgradeRoundTripper.NewConnection(resp)
if err != nil {
t.Fatalf("Unexpected error creating streaming connection: %s", err)
}
if conn == nil {
t.Fatal("Unexpected nil connection")
}
defer conn.Close()
h := http.Header{}
h.Set("type", "input")
stream, err := conn.CreateStream(h)
if err != nil {
t.Fatalf("error creating input stream: %v", err)
}
defer stream.Reset()
select {
case <-conn.CloseChan():
close(idleSuccess)
case <-time.After(150 * time.Millisecond):
t.Fatalf("Timed out waiting for connection closure due to idle timeout")
}
}
func TestServeExecInContainer(t *testing.T) {
tests := []struct {
stdin bool
stdout bool
stderr bool
tty bool
responseStatusCode int
uid bool
}{
{responseStatusCode: http.StatusBadRequest},
{stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
{stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
}
for i, test := range tests {
fw := newServerTest()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
podNamespace := "other"
podName := "foo"
expectedPodName := podName + "." + podNamespace + ".etcd"
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
expectedContainerName := "baz"
expectedCommand := "ls -a"
expectedStdin := "stdin"
expectedStdout := "stdout"
expectedStderr := "stderr"
execFuncDone := make(chan struct{})
clientStdoutReadDone := make(chan struct{})
clientStderrReadDone := make(chan struct{})
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
defer close(execFuncDone)
if podFullName != expectedPodName {
t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
}
if test.uid && string(uid) != expectedUid {
t.Fatalf("%d: uid: expected %v, got %v", i, expectedUid, uid)
}
if containerName != expectedContainerName {
t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
}
if strings.Join(cmd, " ") != expectedCommand {
t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd)
}
if test.stdin {
if in == nil {
t.Fatalf("%d: stdin: expected non-nil", i)
}
b := make([]byte, 10)
n, err := in.Read(b)
if err != nil {
t.Fatalf("%d: error reading from stdin: %v", i, err)
}
if e, a := expectedStdin, string(b[0:n]); e != a {
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
}
} else if in != nil {
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
}
if test.stdout {
if out == nil {
t.Fatalf("%d: stdout: expected non-nil", i)
}
_, err := out.Write([]byte(expectedStdout))
if err != nil {
t.Fatalf("%d:, error writing to stdout: %v", i, err)
}
out.Close()
select {
case <-clientStdoutReadDone:
case <-time.After(10 * time.Millisecond):
t.Fatalf("%d: timed out waiting for client to read stdout", i)
}
} else if out != nil {
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
}
if tty {
if stderr != nil {
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
}
} else if test.stderr {
if stderr == nil {
t.Fatalf("%d: stderr: expected non-nil", i)
}
_, err := stderr.Write([]byte(expectedStderr))
if err != nil {
t.Fatalf("%d:, error writing to stderr: %v", i, err)
}
stderr.Close()
select {
case <-clientStderrReadDone:
case <-time.After(10 * time.Millisecond):
t.Fatalf("%d: timed out waiting for client to read stderr", i)
}
} else if stderr != nil {
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
}
return nil
}
var url string
if test.uid {
url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?command=ls&command=-a"
} else {
url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?command=ls&command=-a"
}
if test.stdin {
url += "&" + api.ExecStdinParam + "=1"
}
if test.stdout {
url += "&" + api.ExecStdoutParam + "=1"
}
if test.stderr && !test.tty {
url += "&" + api.ExecStderrParam + "=1"
}
if test.tty {
url += "&" + api.ExecTTYParam + "=1"
}
var (
resp *http.Response
err error
upgradeRoundTripper httpstream.UpgradeRoundTripper
c *http.Client
)
if test.responseStatusCode != http.StatusSwitchingProtocols {
c = &http.Client{}
} else {
upgradeRoundTripper = spdy.NewRoundTripper(nil)
c = &http.Client{Transport: upgradeRoundTripper}
}
resp, err = c.Get(url)
if err != nil {
t.Fatalf("%d: Got error GETing: %v", i, err)
}
defer resp.Body.Close()
_, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("%d: Error reading response body: %v", i, err)
}
if e, a := test.responseStatusCode, resp.StatusCode; e != a {
t.Fatalf("%d: response status: expected %v, got %v", e, a)
}
if test.responseStatusCode != http.StatusSwitchingProtocols {
continue
}
conn, err := upgradeRoundTripper.NewConnection(resp)
if err != nil {
t.Fatalf("Unexpected error creating streaming connection: %s", err)
}
if conn == nil {
t.Fatalf("%d: unexpected nil conn", i)
}
defer conn.Close()
h := http.Header{}
h.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(h)
if err != nil {
t.Fatalf("%d: error creating error stream: %v", i, err)
}
defer errorStream.Reset()
if test.stdin {
h.Set(api.StreamType, api.StreamTypeStdin)
stream, err := conn.CreateStream(h)
if err != nil {
t.Fatalf("%d: error creating stdin stream: %v", i, err)
}
defer stream.Reset()
_, err = stream.Write([]byte(expectedStdin))
if err != nil {
t.Fatalf("%d: error writing to stdin stream: %v", i, err)
}
}
var stdoutStream httpstream.Stream
if test.stdout {
h.Set(api.StreamType, api.StreamTypeStdout)
stdoutStream, err = conn.CreateStream(h)
if err != nil {
t.Fatalf("%d: error creating stdout stream: %v", i, err)
}
defer stdoutStream.Reset()
}
var stderrStream httpstream.Stream
if test.stderr && !test.tty {
h.Set(api.StreamType, api.StreamTypeStderr)
stderrStream, err = conn.CreateStream(h)
if err != nil {
t.Fatalf("%d: error creating stderr stream: %v", i, err)
}
defer stderrStream.Reset()
}
if test.stdout {
output := make([]byte, 10)
n, err := stdoutStream.Read(output)
close(clientStdoutReadDone)
if err != nil {
t.Fatalf("%d: error reading from stdout stream: %v", i, err)
}
if e, a := expectedStdout, string(output[0:n]); e != a {
t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a)
}
}
if test.stderr && !test.tty {
output := make([]byte, 10)
n, err := stderrStream.Read(output)
close(clientStderrReadDone)
if err != nil {
t.Fatalf("%d: error reading from stderr stream: %v", i, err)
}
if e, a := expectedStderr, string(output[0:n]); e != a {
t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a)
}
}
select {
case <-execFuncDone:
case <-time.After(10 * time.Millisecond):
t.Fatalf("%d: timed out waiting for execFunc to complete", i)
}
}
}
func TestServePortForwardIdleTimeout(t *testing.T) {
fw := newServerTest()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 100 * time.Millisecond
}
idleSuccess := make(chan struct{})
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
select {
case <-idleSuccess:
case <-time.After(150 * time.Millisecond):
t.Fatalf("execFunc timed out waiting for idle timeout")
}
return nil
}
podNamespace := "other"
podName := "foo"
url := fw.testHTTPServer.URL + "/portForward/" + podNamespace + "/" + podName
upgradeRoundTripper := spdy.NewRoundTripper(nil)
c := &http.Client{Transport: upgradeRoundTripper}
resp, err := c.Get(url)
if err != nil {
t.Fatalf("Got error GETing: %v", err)
}
defer resp.Body.Close()
conn, err := upgradeRoundTripper.NewConnection(resp)
if err != nil {
t.Fatalf("Unexpected error creating streaming connection: %s", err)
}
if conn == nil {
t.Fatal("Unexpected nil connection")
}
defer conn.Close()
select {
case <-conn.CloseChan():
close(idleSuccess)
case <-time.After(150 * time.Millisecond):
t.Fatalf("Timed out waiting for connection closure due to idle timeout")
}
}
func TestServePortForward(t *testing.T) {
tests := []struct {
port string
uid bool
clientData string
containerData string
shouldError bool
}{
{port: "", shouldError: true},
{port: "abc", shouldError: true},
{port: "-1", shouldError: true},
{port: "65536", shouldError: true},
{port: "0", shouldError: true},
{port: "1", shouldError: false},
{port: "8000", shouldError: false},
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
{port: "65535", shouldError: false},
{port: "65535", uid: true, shouldError: false},
}
podNamespace := "other"
podName := "foo"
expectedPodName := podName + "." + podNamespace + ".etcd"
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
for i, test := range tests {
fw := newServerTest()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
portForwardFuncDone := make(chan struct{})
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
defer close(portForwardFuncDone)
if e, a := expectedPodName, name; e != a {
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a)
}
if e, a := expectedUid, uid; test.uid && e != string(a) {
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
}
p, err := strconv.ParseUint(test.port, 10, 16)
if err != nil {
t.Fatalf("%d: error parsing port string '%s': %v", i, port, err)
}
if e, a := uint16(p), port; e != a {
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
}
if test.clientData != "" {
fromClient := make([]byte, 32)
n, err := stream.Read(fromClient)
if err != nil {
t.Fatalf("%d: error reading client data: %v", i, err)
}
if e, a := test.clientData, string(fromClient[0:n]); e != a {
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
}
}
if test.containerData != "" {
_, err := stream.Write([]byte(test.containerData))
if err != nil {
t.Fatalf("%d: error writing container data: %v", i, err)
}
}
return nil
}
var url string
if test.uid {
url = fmt.Sprintf("%s/portForward/%s/%s/%s", fw.testHTTPServer.URL, podNamespace, podName, expectedUid)
} else {
url = fmt.Sprintf("%s/portForward/%s/%s", fw.testHTTPServer.URL, podNamespace, podName)
}
upgradeRoundTripper := spdy.NewRoundTripper(nil)
c := &http.Client{Transport: upgradeRoundTripper}
resp, err := c.Get(url)
if err != nil {
t.Fatalf("%d: Got error GETing: %v", i, err)
}
defer resp.Body.Close()
conn, err := upgradeRoundTripper.NewConnection(resp)
if err != nil {
t.Fatalf("Unexpected error creating streaming connection: %s", err)
}
if conn == nil {
t.Fatal("%d: Unexpected nil connection", i)
}
defer conn.Close()
headers := http.Header{}
headers.Set("streamType", "error")
headers.Set("port", test.port)
errorStream, err := conn.CreateStream(headers)
_ = errorStream
haveErr := err != nil
if e, a := test.shouldError, haveErr; e != a {
t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err)
}
if test.shouldError {
continue
}
headers.Set("streamType", "data")
headers.Set("port", test.port)
dataStream, err := conn.CreateStream(headers)
haveErr = err != nil
if e, a := test.shouldError, haveErr; e != a {
t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err)
}
if test.clientData != "" {
_, err := dataStream.Write([]byte(test.clientData))
if err != nil {
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
}
}
if test.containerData != "" {
fromContainer := make([]byte, 32)
n, err := dataStream.Read(fromContainer)
if err != nil {
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
}
if e, a := test.containerData, string(fromContainer[0:n]); e != a {
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
}
}
select {
case <-portForwardFuncDone:
case <-time.After(100 * time.Millisecond):
t.Fatalf("%d: timed out waiting for portForwardFuncDone", i)
}
}
}

View File

@@ -0,0 +1,19 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package httpstream adds multiplexed streaming support to HTTP requests and
// responses via connection upgrades.
package httpstream

View File

@@ -0,0 +1,80 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package httpstream
import (
"io"
"net/http"
"time"
)
const (
HeaderConnection = "Connection"
HeaderUpgrade = "Upgrade"
)
// NewStreamHandler defines a function that is called when a new Stream is
// received. If no error is returned, the Stream is accepted; otherwise,
// the stream is rejected.
type NewStreamHandler func(Stream) error
// NoOpNewStreamHandler is a stream handler that accepts a new stream and
// performs no other logic.
func NoOpNewStreamHandler(stream Stream) error { return nil }
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
// HTTP requests to support multiplexed bidirectional streams. After RoundTrip()
// is invoked, if the upgrade is successful, clients may retrieve the upgraded
// connection by calling UpgradeRoundTripper.Connection().
type UpgradeRoundTripper interface {
http.RoundTripper
// NewConnection validates the response and creates a new Connection.
NewConnection(resp *http.Response) (Connection, error)
}
// ResponseUpgrader knows how to upgrade HTTP requests and responses to
// add streaming support to them.
type ResponseUpgrader interface {
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called synchronously whenever the
// other end of the upgraded connection creates a new stream.
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
}
// Connection represents an upgraded HTTP connection.
type Connection interface {
// CreateStream creates a new Stream with the supplied headers.
CreateStream(headers http.Header) (Stream, error)
// Close resets all streams and closes the connection.
Close() error
// CloseChan returns a channel that is closed when the underlying connection is closed.
CloseChan() <-chan bool
// SetIdleTimeout sets the amount of time the connection may remain idle before
// it is automatically closed.
SetIdleTimeout(timeout time.Duration)
}
// Stream represents a bidirectional communications channel that is part of an
// upgraded connection.
type Stream interface {
io.ReadWriteCloser
// Reset closes both directions of the stream, indicating that neither client
// or server can use it any more.
Reset() error
// Headers returns the headers used to create the stream.
Headers() http.Header
}

View File

@@ -0,0 +1,139 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"net"
"net/http"
"sync"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/docker/spdystream"
"github.com/golang/glog"
)
// connection maintains state about a spdystream.Connection and its associated
// streams.
type connection struct {
conn *spdystream.Connection
streams []httpstream.Stream
streamLock sync.Mutex
newStreamHandler httpstream.NewStreamHandler
}
// NewClientConnection creates a new SPDY client connection.
func NewClientConnection(conn net.Conn) (httpstream.Connection, error) {
spdyConn, err := spdystream.NewConnection(conn, false)
if err != nil {
defer conn.Close()
return nil, err
}
return newConnection(spdyConn, httpstream.NoOpNewStreamHandler), nil
}
// NewServerConnection creates a new SPDY server connection. newStreamHandler
// will be invoked when the server receives a newly created stream from the
// client.
func NewServerConnection(conn net.Conn, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, error) {
spdyConn, err := spdystream.NewConnection(conn, true)
if err != nil {
defer conn.Close()
return nil, err
}
return newConnection(spdyConn, newStreamHandler), nil
}
// newConnection returns a new connection wrapping conn. newStreamHandler
// will be invoked when the server receives a newly created stream from the
// client.
func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
c := &connection{conn: conn, newStreamHandler: newStreamHandler}
go conn.Serve(c.newSpdyStream)
return c
}
// createStreamResponseTimeout indicates how long to wait for the other side to
// acknowledge the new stream before timing out.
const createStreamResponseTimeout = 2 * time.Second
// Close first sends a reset for all of the connection's streams, and then
// closes the underlying spdystream.Connection.
func (c *connection) Close() error {
c.streamLock.Lock()
for _, s := range c.streams {
s.Reset()
}
c.streams = make([]httpstream.Stream, 0)
c.streamLock.Unlock()
return c.conn.Close()
}
// CreateStream creates a new stream with the specified headers and registers
// it with the connection.
func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) {
stream, err := c.conn.CreateStream(headers, nil, false)
if err != nil {
return nil, err
}
if err = stream.WaitTimeout(createStreamResponseTimeout); err != nil {
return nil, err
}
c.registerStream(stream)
return stream, nil
}
// registerStream adds the stream s to the connection's list of streams that
// it owns.
func (c *connection) registerStream(s httpstream.Stream) {
c.streamLock.Lock()
c.streams = append(c.streams, s)
c.streamLock.Unlock()
}
// CloseChan returns a channel that, when closed, indicates that the underlying
// spdystream.Connection has been closed.
func (c *connection) CloseChan() <-chan bool {
return c.conn.CloseChan()
}
// newSpdyStream is the internal new stream handler used by spdystream.Connection.Serve.
// It calls connection's newStreamHandler, giving it the opportunity to accept or reject
// the stream. If newStreamHandler returns an error, the stream is rejected. If not, the
// stream is accepted and registered with the connection.
func (c *connection) newSpdyStream(stream *spdystream.Stream) {
err := c.newStreamHandler(stream)
rejectStream := (err != nil)
if rejectStream {
glog.Warningf("Stream rejected: %v", err)
stream.Reset()
return
}
c.registerStream(stream)
stream.SendReply(http.Header{}, rejectStream)
}
// SetIdleTimeout sets the amount of time the connection may remain idle before
// it is automatically closed.
func (c *connection) SetIdleTimeout(timeout time.Duration) {
c.conn.SetIdleTimeout(timeout)
}

View File

@@ -0,0 +1,130 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"bufio"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
)
// SpdyRoundTripper knows how to upgrade an HTTP request to one that supports
// multiplexed streams. After RoundTrip() is invoked, Conn will be set
// and usable. SpdyRoundTripper implements the UpgradeRoundTripper interface.
type SpdyRoundTripper struct {
//tlsConfig holds the TLS configuration settings to use when connecting
//to the remote server.
tlsConfig *tls.Config
/* TODO according to http://golang.org/pkg/net/http/#RoundTripper, a RoundTripper
must be safe for use by multiple concurrent goroutines. If this is absolutely
necessary, we could keep a map from http.Request to net.Conn. In practice,
a client will create an http.Client, set the transport to a new insteace of
SpdyRoundTripper, and use it a single time, so this hopefully won't be an issue.
*/
// conn is the underlying network connection to the remote server.
conn net.Conn
}
// NewSpdyRoundTripper creates a new SpdyRoundTripper that will use
// the specified tlsConfig.
func NewRoundTripper(tlsConfig *tls.Config) httpstream.UpgradeRoundTripper {
return &SpdyRoundTripper{tlsConfig: tlsConfig}
}
// dial dials the host specified by req, using TLS if appropriate.
func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
dialAddr := util.CanonicalAddr(req.URL)
if req.URL.Scheme == "http" {
return net.Dial("tcp", dialAddr)
}
// TODO validate the TLSClientConfig is set up?
conn, err := tls.Dial("tcp", dialAddr, s.tlsConfig)
if err != nil {
return nil, err
}
host, _, err := net.SplitHostPort(dialAddr)
if err != nil {
return nil, err
}
err = conn.VerifyHostname(host)
if err != nil {
return nil, err
}
return conn, nil
}
// RoundTrip executes the Request and upgrades it. After a successful upgrade,
// clients may call SpdyRoundTripper.Connection() to retrieve the upgraded
// connection.
func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO what's the best way to clone the request?
r := *req
req = &r
req.Header.Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
req.Header.Add(httpstream.HeaderUpgrade, HeaderSpdy31)
conn, err := s.dial(req)
if err != nil {
return nil, err
}
err = req.Write(conn)
if err != nil {
return nil, err
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return nil, err
}
s.conn = conn
return resp, nil
}
// NewConnection validates the upgrade response, creating and returning a new
// httpstream.Connection if there were no errors.
func (s *SpdyRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) {
connectionHeader := strings.ToLower(resp.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(resp.Header.Get(httpstream.HeaderUpgrade))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
responseError := ""
responseErrorBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
responseError = "Unable to read error from server response"
} else {
responseError = string(responseErrorBytes)
}
return nil, fmt.Errorf("Unable to upgrade connection: %s", responseError)
}
return NewClientConnection(s.conn)
}

View File

@@ -0,0 +1,226 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
)
func TestRoundTripAndNewConnection(t *testing.T) {
testCases := []struct {
serverConnectionHeader string
serverUpgradeHeader string
useTLS bool
shouldError bool
}{
{
serverConnectionHeader: "",
serverUpgradeHeader: "",
shouldError: true,
},
{
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "",
shouldError: true,
},
{
serverConnectionHeader: "",
serverUpgradeHeader: "SPDY/3.1",
shouldError: true,
},
{
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
shouldError: false,
},
{
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
useTLS: true,
shouldError: false,
},
}
for i, testCase := range testCases {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if testCase.shouldError {
if e, a := httpstream.HeaderUpgrade, req.Header.Get(httpstream.HeaderConnection); e != a {
t.Fatalf("%d: Expected connection=upgrade header, got '%s", i, a)
}
w.Header().Set(httpstream.HeaderConnection, testCase.serverConnectionHeader)
w.Header().Set(httpstream.HeaderUpgrade, testCase.serverUpgradeHeader)
w.WriteHeader(http.StatusSwitchingProtocols)
return
}
streamCh := make(chan httpstream.Stream)
responseUpgrader := NewResponseUpgrader()
spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error {
streamCh <- s
return nil
})
if spdyConn == nil {
t.Fatalf("%d: unexpected nil spdyConn", i)
}
defer spdyConn.Close()
stream := <-streamCh
io.Copy(stream, stream)
}))
clientTLS := &tls.Config{}
if testCase.useTLS {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("%d: error generating keypair: %s", i, err)
}
notBefore := time.Now()
notAfter := notBefore.Add(1 * time.Hour)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Localhost Co"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
host := "127.0.0.1"
if ip := net.ParseIP(host); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
}
template.DNSNames = append(template.DNSNames, host)
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("%d: error creating cert: %s", i, err)
}
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
t.Fatalf("%d: error parsing cert: %s", i, err)
}
roots := x509.NewCertPool()
roots.AddCert(cert)
server.TLS = &tls.Config{
RootCAs: roots,
}
clientTLS.RootCAs = roots
certBuf := bytes.Buffer{}
err = pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if err != nil {
t.Fatalf("%d: error encoding cert: %s", i, err)
}
keyBuf := bytes.Buffer{}
err = pem.Encode(&keyBuf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
if err != nil {
t.Fatalf("%d: error encoding key: %s", i, err)
}
tlsCert, err := tls.X509KeyPair(certBuf.Bytes(), keyBuf.Bytes())
if err != nil {
t.Fatalf("%d: error calling tls.X509KeyPair: %s", i, err)
}
server.TLS.Certificates = []tls.Certificate{tlsCert}
clientTLS.Certificates = []tls.Certificate{tlsCert}
server.StartTLS()
} else {
server.Start()
}
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("%d: Error creating request: %s", i, err)
}
spdyTransport := NewRoundTripper(clientTLS)
client := &http.Client{Transport: spdyTransport}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("%d: unexpected error from client.Do: %s", i, err)
}
conn, err := spdyTransport.NewConnection(resp)
haveErr := err != nil
if e, a := testCase.shouldError, haveErr; e != a {
t.Fatalf("%d: shouldError=%t, got %t: %v", i, e, a, err)
}
if testCase.shouldError {
continue
}
defer conn.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Fatalf("%d: expected http 101 switching protocols, got %d", i, resp.StatusCode)
}
stream, err := conn.CreateStream(http.Header{})
if err != nil {
t.Fatalf("%d: error creating client stream: %s", i, err)
}
n, err := stream.Write([]byte("hello"))
if err != nil {
t.Fatalf("%d: error writing to stream: %s", i, err)
}
if n != 5 {
t.Fatalf("%d: Expected to write 5 bytes, but actually wrote %d", i, n)
}
b := make([]byte, 5)
n, err = stream.Read(b)
if err != nil {
t.Fatalf("%d: error reading from stream: %s", i, err)
}
if n != 5 {
t.Fatalf("%d: Expected to read 5 bytes, but actually read %d", i, n)
}
if e, a := "hello", string(b[0:n]); e != a {
t.Fatalf("%d: expected '%s', got '%s'", i, e, a)
}
}
}

View File

@@ -0,0 +1,78 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"fmt"
"net/http"
"strings"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/golang/glog"
)
const HeaderSpdy31 = "SPDY/3.1"
// responseUpgrader knows how to upgrade HTTP responses. It
// implements the httpstream.ResponseUpgrader interface.
type responseUpgrader struct {
}
// NewResponseUpgrader returns a new httpstream.ResponseUpgrader that is
// capable of upgrading HTTP responses using SPDY/3.1 via the
// spdystream package.
func NewResponseUpgrader() httpstream.ResponseUpgrader {
return responseUpgrader{}
}
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called synchronously whenever the
// other end of the upgraded connection creates a new stream.
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header)))
w.WriteHeader(http.StatusBadRequest)
return nil
}
hijacker, ok := w.(http.Hijacker)
if !ok {
w.Write([]byte("Unable to upgrade: unable to hijack response"))
w.WriteHeader(http.StatusInternalServerError)
return nil
}
w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
w.Header().Add(httpstream.HeaderUpgrade, HeaderSpdy31)
w.WriteHeader(http.StatusSwitchingProtocols)
conn, _, err := hijacker.Hijack()
if err != nil {
glog.Errorf("Unable to upgrade: error hijacking response: %v", err)
return nil
}
spdyConn, err := NewServerConnection(conn, newStreamHandler)
if err != nil {
glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err)
return nil
}
return spdyConn
}

View File

@@ -0,0 +1,93 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestUpgradeResponse(t *testing.T) {
testCases := []struct {
connectionHeader string
upgradeHeader string
shouldError bool
}{
{
connectionHeader: "",
upgradeHeader: "",
shouldError: true,
},
{
connectionHeader: "Upgrade",
upgradeHeader: "",
shouldError: true,
},
{
connectionHeader: "",
upgradeHeader: "SPDY/3.1",
shouldError: true,
},
{
connectionHeader: "Upgrade",
upgradeHeader: "SPDY/3.1",
shouldError: false,
},
}
for i, testCase := range testCases {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
upgrader := NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, nil)
haveErr := conn == nil
if e, a := testCase.shouldError, haveErr; e != a {
t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)
}
if haveErr {
return
}
if conn == nil {
t.Fatalf("%d: unexpected nil conn", i)
}
defer conn.Close()
}))
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("%d: error creating request: %s", i, err)
}
req.Header.Set("Connection", testCase.connectionHeader)
req.Header.Set("Upgrade", testCase.upgradeHeader)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("%d: unexpected non-nil err from client.Do: %s", i, err)
}
if testCase.shouldError {
continue
}
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Fatalf("%d: expected status 101 switching protocols, got %d", i, resp.StatusCode)
}
}
}

View File

@@ -19,6 +19,7 @@ package util
import (
"fmt"
"net"
"net/url"
"strings"
)
@@ -61,3 +62,24 @@ func (ipnet *IPNet) Set(value string) error {
func (*IPNet) Type() string {
return "ipNet"
}
// FROM: http://golang.org/src/net/http/client.go
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
// return true if the string includes a port.
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
// FROM: http://golang.org/src/net/http/transport.go
var portMap = map[string]string{
"http": "80",
"https": "443",
}
// FROM: http://golang.org/src/net/http/transport.go
// canonicalAddr returns url.Host but always with a ":port" suffix
func CanonicalAddr(url *url.URL) string {
addr := url.Host
if !hasPort(addr) {
return addr + ":" + portMap[url.Scheme]
}
return addr
}