Add streaming command execution & port forwarding
Add streaming command execution & port forwarding via HTTP connection upgrades (currently using SPDY).
This commit is contained in:
@@ -1349,3 +1349,32 @@ type SecretList struct {
|
||||
|
||||
Items []Secret `json:"items"`
|
||||
}
|
||||
|
||||
// These constants are for remote command execution and port forwarding and are
|
||||
// used by both the client side and server side components.
|
||||
//
|
||||
// This is probably not the ideal place for them, but it didn't seem worth it
|
||||
// to create pkg/exec and pkg/portforward just to contain a single file with
|
||||
// constants in it. Suggestions for more appropriate alternatives are
|
||||
// definitely welcome!
|
||||
const (
|
||||
// Enable stdin for remote command execution
|
||||
ExecStdinParam = "input"
|
||||
// Enable stdout for remote command execution
|
||||
ExecStdoutParam = "output"
|
||||
// Enable stderr for remote command execution
|
||||
ExecStderrParam = "error"
|
||||
// Enable TTY for remote command execution
|
||||
ExecTTYParam = "tty"
|
||||
// Command to run for remote command execution
|
||||
ExecCommandParamm = "command"
|
||||
|
||||
StreamType = "streamType"
|
||||
StreamTypeStdin = "stdin"
|
||||
StreamTypeStdout = "stdout"
|
||||
StreamTypeStderr = "stderr"
|
||||
StreamTypeData = "data"
|
||||
StreamTypeError = "error"
|
||||
|
||||
PortHeader = "port"
|
||||
)
|
||||
|
@@ -99,6 +99,7 @@ func RecoverPanics(handler http.Handler) http.Handler {
|
||||
http.StatusConflict,
|
||||
http.StatusNotFound,
|
||||
errors.StatusUnprocessableEntity,
|
||||
http.StatusSwitchingProtocols,
|
||||
),
|
||||
).Log()
|
||||
|
||||
|
@@ -22,6 +22,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
@@ -34,6 +35,7 @@ import (
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
|
||||
"github.com/golang/glog"
|
||||
"golang.org/x/net/html"
|
||||
@@ -176,14 +178,67 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
httpCode = http.StatusOK
|
||||
newReq.Header = req.Header
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: "http", Host: destURL.Host})
|
||||
proxy.Transport = &proxyTransport{
|
||||
proxyScheme: req.URL.Scheme,
|
||||
proxyHost: req.URL.Host,
|
||||
proxyPathPrepend: path.Join(r.prefix, "ns", namespace, resource, id),
|
||||
// TODO convert this entire proxy to an UpgradeAwareProxy similar to
|
||||
// https://github.com/openshift/origin/blob/master/pkg/util/httpproxy/upgradeawareproxy.go.
|
||||
// That proxy needs to be modified to support multiple backends, not just 1.
|
||||
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
|
||||
if strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) && len(req.Header.Get(httpstream.HeaderUpgrade)) > 0 {
|
||||
//TODO support TLS? Doesn't look like proxyTransport does anything special ...
|
||||
dialAddr := util.CanonicalAddr(destURL)
|
||||
backendConn, err := net.Dial("tcp", dialAddr)
|
||||
if err != nil {
|
||||
status := errToAPIStatus(err)
|
||||
writeJSON(status.Code, r.codec, status, w)
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
// TODO should we use _ (a bufio.ReadWriter) instead of requestHijackedConn
|
||||
// when copying between the client and the backend? Docker doesn't when they
|
||||
// hijack, just for reference...
|
||||
requestHijackedConn, _, err := w.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
status := errToAPIStatus(err)
|
||||
writeJSON(status.Code, r.codec, status, w)
|
||||
return
|
||||
}
|
||||
defer requestHijackedConn.Close()
|
||||
|
||||
if err = newReq.Write(backendConn); err != nil {
|
||||
status := errToAPIStatus(err)
|
||||
writeJSON(status.Code, r.codec, status, w)
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(backendConn, requestHijackedConn)
|
||||
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
glog.Errorf("Error proxying data from client to backend: %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(requestHijackedConn, backendConn)
|
||||
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
glog.Errorf("Error proxying data from backend to client: %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
} else {
|
||||
proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: "http", Host: destURL.Host})
|
||||
proxy.Transport = &proxyTransport{
|
||||
proxyScheme: req.URL.Scheme,
|
||||
proxyHost: req.URL.Host,
|
||||
proxyPathPrepend: path.Join(r.prefix, "ns", namespace, resource, id),
|
||||
}
|
||||
proxy.FlushInterval = 200 * time.Millisecond
|
||||
proxy.ServeHTTP(w, newReq)
|
||||
}
|
||||
proxy.FlushInterval = 200 * time.Millisecond
|
||||
proxy.ServeHTTP(w, newReq)
|
||||
}
|
||||
|
||||
type proxyTransport struct {
|
||||
|
@@ -29,6 +29,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func parseURLOrDie(inURL string) *url.URL {
|
||||
@@ -327,3 +328,45 @@ func TestProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyUpgrade(t *testing.T) {
|
||||
backendServer := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
|
||||
defer ws.Close()
|
||||
body := make([]byte, 5)
|
||||
ws.Read(body)
|
||||
ws.Write([]byte("hello " + string(body)))
|
||||
}))
|
||||
defer backendServer.Close()
|
||||
|
||||
simpleStorage := &SimpleRESTStorage{
|
||||
errors: map[string]error{},
|
||||
resourceLocation: backendServer.URL,
|
||||
expectedResourceNamespace: "myns",
|
||||
}
|
||||
|
||||
namespaceHandler := Handle(map[string]RESTStorage{
|
||||
"foo": simpleStorage,
|
||||
}, codec, "/prefix", "version", selfLinker, admissionControl, requestContextMapper, namespaceMapper)
|
||||
|
||||
server := httptest.NewServer(namespaceHandler)
|
||||
defer server.Close()
|
||||
|
||||
ws, err := websocket.Dial("ws://"+server.Listener.Addr().String()+"/prefix/version/proxy/namespaces/myns/foo/123", "", "http://127.0.0.1/")
|
||||
if err != nil {
|
||||
t.Fatalf("websocket dial err: %s", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
if _, err := ws.Write([]byte("world")); err != nil {
|
||||
t.Fatalf("write err: %s", err)
|
||||
}
|
||||
|
||||
response := make([]byte, 20)
|
||||
n, err := ws.Read(response)
|
||||
if err != nil {
|
||||
t.Fatalf("read err: %s", err)
|
||||
}
|
||||
if e, a := "hello world", string(response[0:n]); e != a {
|
||||
t.Fatalf("expected '%#v', got '%#v'", e, a)
|
||||
}
|
||||
}
|
||||
|
19
pkg/client/portforward/doc.go
Normal file
19
pkg/client/portforward/doc.go
Normal file
@@ -0,0 +1,19 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package portforward adds support for SSH-like port forwarding from the client's
|
||||
// local host to remote containers.
|
||||
package portforward
|
300
pkg/client/portforward/portforward.go
Normal file
300
pkg/client/portforward/portforward.go
Normal file
@@ -0,0 +1,300 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
type upgrader interface {
|
||||
upgrade(*client.Request, *client.Config) (httpstream.Connection, error)
|
||||
}
|
||||
|
||||
type defaultUpgrader struct{}
|
||||
|
||||
func (u *defaultUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return req.Upgrade(config, spdy.NewRoundTripper)
|
||||
}
|
||||
|
||||
// PortForwarder knows how to listen for local connections and forward them to
|
||||
// a remote pod via an upgraded HTTP request.
|
||||
type PortForwarder struct {
|
||||
req *client.Request
|
||||
config *client.Config
|
||||
ports []ForwardedPort
|
||||
stopChan <-chan struct{}
|
||||
|
||||
streamConn httpstream.Connection
|
||||
listeners []io.Closer
|
||||
upgrader upgrader
|
||||
Ready chan struct{}
|
||||
}
|
||||
|
||||
// ForwardedPort contains a Local:Remote port pairing.
|
||||
type ForwardedPort struct {
|
||||
Local uint16
|
||||
Remote uint16
|
||||
}
|
||||
|
||||
/*
|
||||
valid port specifications:
|
||||
|
||||
5000
|
||||
- forwards from localhost:5000 to pod:5000
|
||||
|
||||
8888:5000
|
||||
- forwards from localhost:8888 to pod:5000
|
||||
|
||||
0:5000
|
||||
:5000
|
||||
- selects a random available local port,
|
||||
forwards from localhost:<random port> to pod:5000
|
||||
*/
|
||||
func parsePorts(ports []string) ([]ForwardedPort, error) {
|
||||
var forwards []ForwardedPort
|
||||
for _, portString := range ports {
|
||||
parts := strings.Split(portString, ":")
|
||||
var localString, remoteString string
|
||||
if len(parts) == 1 {
|
||||
localString = parts[0]
|
||||
remoteString = parts[0]
|
||||
} else if len(parts) == 2 {
|
||||
localString = parts[0]
|
||||
if localString == "" {
|
||||
// support :5000
|
||||
localString = "0"
|
||||
}
|
||||
remoteString = parts[1]
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid port format '%s'", portString)
|
||||
}
|
||||
|
||||
localPort, err := strconv.ParseUint(localString, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error parsing local port '%s': %s", localString, err)
|
||||
}
|
||||
|
||||
remotePort, err := strconv.ParseUint(remoteString, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error parsing remote port '%s': %s", remoteString, err)
|
||||
}
|
||||
if remotePort == 0 {
|
||||
return nil, fmt.Errorf("Remote port must be > 0")
|
||||
}
|
||||
|
||||
forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
|
||||
}
|
||||
|
||||
return forwards, nil
|
||||
}
|
||||
|
||||
// New creates a new PortForwarder.
|
||||
func New(req *client.Request, config *client.Config, ports []string, stopChan <-chan struct{}) (*PortForwarder, error) {
|
||||
if len(ports) == 0 {
|
||||
return nil, errors.New("You must specify at least 1 port")
|
||||
}
|
||||
parsedPorts, err := parsePorts(ports)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PortForwarder{
|
||||
req: req,
|
||||
config: config,
|
||||
ports: parsedPorts,
|
||||
stopChan: stopChan,
|
||||
Ready: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardPorts formats and executes a port forwarding request. The connection will remain
|
||||
// open until stopChan is closed.
|
||||
func (pf *PortForwarder) ForwardPorts() error {
|
||||
defer pf.Close()
|
||||
|
||||
if pf.upgrader == nil {
|
||||
pf.upgrader = &defaultUpgrader{}
|
||||
}
|
||||
var err error
|
||||
pf.streamConn, err = pf.upgrader.upgrade(pf.req, pf.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error upgrading connection: %s", err)
|
||||
}
|
||||
defer pf.streamConn.Close()
|
||||
|
||||
return pf.forward()
|
||||
}
|
||||
|
||||
// forward dials the remote host specific in req, upgrades the request, starts
|
||||
// listeners for each port specified in ports, and forwards local connections
|
||||
// to the remote host via streams.
|
||||
func (pf *PortForwarder) forward() error {
|
||||
var err error
|
||||
|
||||
listenSuccess := false
|
||||
for _, port := range pf.ports {
|
||||
err = pf.listenOnPort(&port)
|
||||
if err != nil {
|
||||
glog.Warningf("Unable to listen on port %d: %v", port, err)
|
||||
}
|
||||
listenSuccess = true
|
||||
}
|
||||
|
||||
if !listenSuccess {
|
||||
return fmt.Errorf("Unable to listen on any of the requested ports: %v", pf.ports)
|
||||
}
|
||||
|
||||
close(pf.Ready)
|
||||
|
||||
// wait for interrupt or conn closure
|
||||
select {
|
||||
case <-pf.stopChan:
|
||||
case <-pf.streamConn.CloseChan():
|
||||
glog.Errorf("Lost connection to pod")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// listenOnPort creates a new listener on port and waits for new connections
|
||||
// in the background.
|
||||
func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port.Local))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parts := strings.Split(listener.Addr().String(), ":")
|
||||
localPort, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing local part: %s", err)
|
||||
}
|
||||
port.Local = uint16(localPort)
|
||||
glog.Infof("Forwarding from %d -> %d", localPort, port.Remote)
|
||||
|
||||
pf.listeners = append(pf.listeners, listener)
|
||||
|
||||
go pf.waitForConnection(listener, *port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForConnection waits for new connections to listener and handles them in
|
||||
// the background.
|
||||
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
|
||||
glog.Errorf("Error accepting connection on port %d: %v", port.Local, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
go pf.handleConnection(conn, port)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection copies data between the local connection and the stream to
|
||||
// the remote server.
|
||||
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
|
||||
defer conn.Close()
|
||||
|
||||
glog.Infof("Handling connection for %d", port.Local)
|
||||
|
||||
errorChan := make(chan error)
|
||||
doneChan := make(chan struct{}, 2)
|
||||
|
||||
// create error stream
|
||||
headers := http.Header{}
|
||||
headers.Set(api.StreamType, api.StreamTypeError)
|
||||
headers.Set(api.PortHeader, fmt.Sprintf("%d", port.Remote))
|
||||
errorStream, err := pf.streamConn.CreateStream(headers)
|
||||
if err != nil {
|
||||
glog.Errorf("Error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
return
|
||||
}
|
||||
defer errorStream.Reset()
|
||||
go func() {
|
||||
message, err := ioutil.ReadAll(errorStream)
|
||||
if err != nil && err != io.EOF {
|
||||
errorChan <- fmt.Errorf("Error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
}
|
||||
if len(message) > 0 {
|
||||
errorChan <- fmt.Errorf("An error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
|
||||
}
|
||||
}()
|
||||
|
||||
// create data stream
|
||||
headers.Set(api.StreamType, api.StreamTypeData)
|
||||
dataStream, err := pf.streamConn.CreateStream(headers)
|
||||
if err != nil {
|
||||
glog.Errorf("Error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
return
|
||||
}
|
||||
// Send a Reset when this function exits to completely tear down the stream here
|
||||
// and in the remote server.
|
||||
defer dataStream.Reset()
|
||||
|
||||
go func() {
|
||||
// Copy from the remote side to the local port. We won't get an EOF from
|
||||
// the server as it has no way of knowing when to close the stream. We'll
|
||||
// take care of closing both ends of the stream with the call to
|
||||
// stream.Reset() when this function exits.
|
||||
if _, err := io.Copy(conn, dataStream); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
glog.Errorf("Error copying from remote stream to local connection: %v", err)
|
||||
}
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// Copy from the local port to the remote side. Here we will be able to know
|
||||
// when the Copy gets an EOF from conn, as that will happen as soon as conn is
|
||||
// closed (i.e. client disconnected).
|
||||
if _, err := io.Copy(dataStream, conn); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
glog.Errorf("Error copying from local connection to remote stream: %v", err)
|
||||
}
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errorChan:
|
||||
glog.Error(err)
|
||||
case <-doneChan:
|
||||
}
|
||||
}
|
||||
|
||||
func (pf *PortForwarder) Close() {
|
||||
// stop all listeners
|
||||
for _, l := range pf.listeners {
|
||||
if err := l.Close(); err != nil {
|
||||
glog.Errorf("Error closing listener: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
321
pkg/client/portforward/portforward_test.go
Normal file
321
pkg/client/portforward/portforward_test.go
Normal file
@@ -0,0 +1,321 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
)
|
||||
|
||||
func TestParsePortsAndNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
input []string
|
||||
expected []ForwardedPort
|
||||
expectParseError bool
|
||||
expectNewError bool
|
||||
}{
|
||||
{input: []string{}, expectNewError: true},
|
||||
{input: []string{"a"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{":a"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"-1"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"65536"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"0"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"0:0"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"a:5000"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"5000:a"}, expectParseError: true, expectNewError: true},
|
||||
{
|
||||
input: []string{"5000", "5000:5000", "8888:5000", "5000:8888", ":5000", "0:5000"},
|
||||
expected: []ForwardedPort{
|
||||
{5000, 5000},
|
||||
{5000, 5000},
|
||||
{8888, 5000},
|
||||
{5000, 8888},
|
||||
{0, 5000},
|
||||
{0, 5000},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
parsed, err := parsePorts(test.input)
|
||||
haveError := err != nil
|
||||
if e, a := test.expectParseError, haveError; e != a {
|
||||
t.Fatalf("%d: parsePorts: error expected=%t, got %t: %s", i, e, a, err)
|
||||
}
|
||||
|
||||
expectedRequest := &client.Request{}
|
||||
expectedConfig := &client.Config{}
|
||||
expectedStopChan := make(chan struct{})
|
||||
pf, err := New(expectedRequest, expectedConfig, test.input, expectedStopChan)
|
||||
haveError = err != nil
|
||||
if e, a := test.expectNewError, haveError; e != a {
|
||||
t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err)
|
||||
}
|
||||
|
||||
if test.expectParseError || test.expectNewError {
|
||||
continue
|
||||
}
|
||||
|
||||
for pi, expectedPort := range test.expected {
|
||||
if e, a := expectedPort.Local, parsed[pi].Local; e != a {
|
||||
t.Fatalf("%d: local expected: %d, got: %d", i, e, a)
|
||||
}
|
||||
if e, a := expectedPort.Remote, parsed[pi].Remote; e != a {
|
||||
t.Fatalf("%d: remote expected: %d, got: %d", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if e, a := expectedRequest, pf.req; e != a {
|
||||
t.Fatalf("%d: req: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
if e, a := expectedConfig, pf.config; e != a {
|
||||
t.Fatalf("%d: config: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
if e, a := test.expected, pf.ports; !reflect.DeepEqual(e, a) {
|
||||
t.Fatalf("%d: ports: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
if e, a := expectedStopChan, pf.stopChan; e != a {
|
||||
t.Fatalf("%d: stopChan: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
if pf.Ready == nil {
|
||||
t.Fatalf("%d: Ready should be non-nil", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeUpgrader struct {
|
||||
conn *fakeUpgradeConnection
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return u.conn, u.err
|
||||
}
|
||||
|
||||
type fakeUpgradeConnection struct {
|
||||
closeCalled bool
|
||||
lock sync.Mutex
|
||||
streams map[string]*fakeUpgradeStream
|
||||
portData map[string]string
|
||||
}
|
||||
|
||||
func newFakeUpgradeConnection() *fakeUpgradeConnection {
|
||||
return &fakeUpgradeConnection{
|
||||
streams: make(map[string]*fakeUpgradeStream),
|
||||
portData: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
stream := &fakeUpgradeStream{}
|
||||
c.streams[headers.Get(api.PortHeader)] = stream
|
||||
stream.data = c.portData[headers.Get(api.PortHeader)]
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) Close() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
|
||||
return make(chan bool)
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
}
|
||||
|
||||
type fakeUpgradeStream struct {
|
||||
readCalled bool
|
||||
writeCalled bool
|
||||
dataWritten []byte
|
||||
closeCalled bool
|
||||
resetCalled bool
|
||||
data string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.readCalled = true
|
||||
b := []byte(s.data)
|
||||
n := copy(p, b)
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.writeCalled = true
|
||||
s.dataWritten = make([]byte, len(p))
|
||||
copy(s.dataWritten, p)
|
||||
return len(p), io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Close() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Reset() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.resetCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Headers() http.Header {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func TestForwardPorts(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Upgrader *fakeUpgrader
|
||||
Ports []string
|
||||
Send map[uint16]string
|
||||
Receive map[uint16]string
|
||||
Err bool
|
||||
}{
|
||||
{
|
||||
Upgrader: &fakeUpgrader{err: errors.New("bail")},
|
||||
Err: true,
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Ports: []string{"5000"},
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Ports: []string{"5000", "6000"},
|
||||
Send: map[uint16]string{
|
||||
5000: "abcd",
|
||||
6000: "ghij",
|
||||
},
|
||||
Receive: map[uint16]string{
|
||||
5000: "1234",
|
||||
6000: "5678",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
stopChan := make(chan struct{}, 1)
|
||||
|
||||
pf, err := New(&client.Request{}, &client.Config{}, testCase.Ports, stopChan)
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.Err {
|
||||
t.Fatalf("%d: New: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
|
||||
}
|
||||
if pf == nil {
|
||||
continue
|
||||
}
|
||||
pf.upgrader = testCase.Upgrader
|
||||
if testCase.Upgrader.err != nil {
|
||||
err := pf.ForwardPorts()
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.Err {
|
||||
t.Fatalf("%d: ForwardPorts: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
doneChan := make(chan error)
|
||||
go func() {
|
||||
doneChan <- pf.ForwardPorts()
|
||||
}()
|
||||
select {
|
||||
case <-pf.Ready:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatalf("%d: timed out waiting for listeners", i)
|
||||
}
|
||||
|
||||
conn := testCase.Upgrader.conn
|
||||
|
||||
for port, data := range testCase.Send {
|
||||
conn.lock.Lock()
|
||||
conn.portData[fmt.Sprintf("%d", port)] = testCase.Receive[port]
|
||||
conn.lock.Unlock()
|
||||
|
||||
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error dialing %d: %s", i, port, err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
n, err := clientConn.Write([]byte(data))
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("%d: Error sending data '%s': %s", i, data, err)
|
||||
}
|
||||
if n == 0 {
|
||||
t.Fatalf("%d: unexpected write of 0 bytes", i)
|
||||
}
|
||||
b := make([]byte, 4)
|
||||
n, err = clientConn.Read(b)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("%d: Error reading data: %s", i, err)
|
||||
}
|
||||
if !bytes.Equal([]byte(testCase.Receive[port]), b) {
|
||||
t.Fatalf("%d: expected to read '%s', got '%s'", i, testCase.Receive[port], b)
|
||||
}
|
||||
}
|
||||
|
||||
// tell r.ForwardPorts to stop
|
||||
close(stopChan)
|
||||
|
||||
// wait for r.ForwardPorts to actually return
|
||||
select {
|
||||
case err := <-doneChan:
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error: %s", err)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatalf("%d: timeout waiting for ForwardPorts to finish")
|
||||
}
|
||||
|
||||
if e, a := len(testCase.Send), len(conn.streams); e != a {
|
||||
t.Fatalf("%d: expected %d streams to be created, got %d", e, a)
|
||||
}
|
||||
|
||||
if !conn.closeCalled {
|
||||
t.Fatalf("%d: expected conn closure", i)
|
||||
}
|
||||
}
|
||||
}
|
20
pkg/client/remotecommand/doc.go
Normal file
20
pkg/client/remotecommand/doc.go
Normal file
@@ -0,0 +1,20 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package remotecommand adds support for executing commands in containers,
|
||||
// with support for separate stdin, stdout, and stderr streams, as well as
|
||||
// TTY.
|
||||
package remotecommand
|
186
pkg/client/remotecommand/remotecommand.go
Normal file
186
pkg/client/remotecommand/remotecommand.go
Normal file
@@ -0,0 +1,186 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
type upgrader interface {
|
||||
upgrade(*client.Request, *client.Config) (httpstream.Connection, error)
|
||||
}
|
||||
|
||||
type defaultUpgrader struct{}
|
||||
|
||||
func (u *defaultUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return req.Upgrade(config, spdy.NewRoundTripper)
|
||||
}
|
||||
|
||||
type RemoteCommandExecutor struct {
|
||||
req *client.Request
|
||||
config *client.Config
|
||||
command []string
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
tty bool
|
||||
|
||||
upgrader upgrader
|
||||
}
|
||||
|
||||
func New(req *client.Request, config *client.Config, command []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) *RemoteCommandExecutor {
|
||||
return &RemoteCommandExecutor{
|
||||
req: req,
|
||||
config: config,
|
||||
command: command,
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
tty: tty,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute sends a remote command execution request, upgrading the
|
||||
// connection and creating streams to represent stdin/stdout/stderr. Data is
|
||||
// copied between these streams and the supplied stdin/stdout/stderr parameters.
|
||||
func (e *RemoteCommandExecutor) Execute() error {
|
||||
doStdin := (e.stdin != nil)
|
||||
doStdout := (e.stdout != nil)
|
||||
doStderr := (!e.tty && e.stderr != nil)
|
||||
|
||||
if doStdin {
|
||||
e.req.Param(api.ExecStdinParam, "1")
|
||||
}
|
||||
if doStdout {
|
||||
e.req.Param(api.ExecStdoutParam, "1")
|
||||
}
|
||||
if doStderr {
|
||||
e.req.Param(api.ExecStderrParam, "1")
|
||||
}
|
||||
if e.tty {
|
||||
e.req.Param(api.ExecTTYParam, "1")
|
||||
}
|
||||
|
||||
for _, s := range e.command {
|
||||
e.req.Param(api.ExecCommandParamm, s)
|
||||
}
|
||||
|
||||
if e.upgrader == nil {
|
||||
e.upgrader = &defaultUpgrader{}
|
||||
}
|
||||
conn, err := e.upgrader.upgrade(e.req, e.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
doneChan := make(chan struct{}, 2)
|
||||
errorChan := make(chan error)
|
||||
|
||||
cp := func(s string, dst io.Writer, src io.Reader) {
|
||||
glog.V(4).Infof("Copying %s", s)
|
||||
defer glog.V(4).Infof("Done copying %s", s)
|
||||
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
|
||||
glog.Errorf("Error copying %s: %v", s, err)
|
||||
}
|
||||
if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
|
||||
doneChan <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set(api.StreamType, api.StreamTypeError)
|
||||
errorStream, err := conn.CreateStream(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
message, err := ioutil.ReadAll(errorStream)
|
||||
if err != nil && err != io.EOF {
|
||||
errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
|
||||
return
|
||||
}
|
||||
if len(message) > 0 {
|
||||
errorChan <- fmt.Errorf("Error executing remote command: %s", message)
|
||||
return
|
||||
}
|
||||
}()
|
||||
defer errorStream.Reset()
|
||||
|
||||
if doStdin {
|
||||
headers.Set(api.StreamType, api.StreamTypeStdin)
|
||||
remoteStdin, err := conn.CreateStream(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer remoteStdin.Reset()
|
||||
// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
|
||||
// because stdin is not closed until the process exits. If we try to call
|
||||
// stdin.Close(), it returns no error but doesn't unblock the copy. It will
|
||||
// exit when the process exits, instead.
|
||||
go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
|
||||
}
|
||||
|
||||
waitCount := 0
|
||||
completedStreams := 0
|
||||
|
||||
if doStdout {
|
||||
waitCount++
|
||||
headers.Set(api.StreamType, api.StreamTypeStdout)
|
||||
remoteStdout, err := conn.CreateStream(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer remoteStdout.Reset()
|
||||
go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
|
||||
}
|
||||
|
||||
if doStderr && !e.tty {
|
||||
waitCount++
|
||||
headers.Set(api.StreamType, api.StreamTypeStderr)
|
||||
remoteStderr, err := conn.CreateStream(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer remoteStderr.Reset()
|
||||
go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
|
||||
}
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-doneChan:
|
||||
completedStreams++
|
||||
if completedStreams == waitCount {
|
||||
break Loop
|
||||
}
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
288
pkg/client/remotecommand/remotecommand_test.go
Normal file
288
pkg/client/remotecommand/remotecommand_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
)
|
||||
|
||||
type fakeUpgrader struct {
|
||||
conn *fakeUpgradeConnection
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return u.conn, u.err
|
||||
}
|
||||
|
||||
type fakeUpgradeConnection struct {
|
||||
closeCalled bool
|
||||
lock sync.Mutex
|
||||
|
||||
stdin *fakeUpgradeStream
|
||||
stdout *fakeUpgradeStream
|
||||
stdoutData string
|
||||
stderr *fakeUpgradeStream
|
||||
stderrData string
|
||||
errorStream *fakeUpgradeStream
|
||||
errorData string
|
||||
unexpectedStreamCreated bool
|
||||
}
|
||||
|
||||
func newFakeUpgradeConnection() *fakeUpgradeConnection {
|
||||
return &fakeUpgradeConnection{}
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
stream := &fakeUpgradeStream{}
|
||||
switch headers.Get(api.StreamType) {
|
||||
case api.StreamTypeStdin:
|
||||
c.stdin = stream
|
||||
case api.StreamTypeStdout:
|
||||
c.stdout = stream
|
||||
stream.data = c.stdoutData
|
||||
case api.StreamTypeStderr:
|
||||
c.stderr = stream
|
||||
stream.data = c.stderrData
|
||||
case api.StreamTypeError:
|
||||
c.errorStream = stream
|
||||
stream.data = c.errorData
|
||||
default:
|
||||
c.unexpectedStreamCreated = true
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) Close() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
|
||||
return make(chan bool)
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
}
|
||||
|
||||
type fakeUpgradeStream struct {
|
||||
readCalled bool
|
||||
writeCalled bool
|
||||
dataWritten []byte
|
||||
closeCalled bool
|
||||
resetCalled bool
|
||||
data string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.readCalled = true
|
||||
b := []byte(s.data)
|
||||
n := copy(p, b)
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.writeCalled = true
|
||||
s.dataWritten = make([]byte, len(p))
|
||||
copy(s.dataWritten, p)
|
||||
return len(p), io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Close() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Reset() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.resetCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Headers() http.Header {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func TestRequestExecuteRemoteCommand(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Upgrader *fakeUpgrader
|
||||
Stdin string
|
||||
Stdout string
|
||||
Stderr string
|
||||
Error string
|
||||
Tty bool
|
||||
ShouldError bool
|
||||
}{
|
||||
{
|
||||
Upgrader: &fakeUpgrader{err: errors.New("bail")},
|
||||
ShouldError: true,
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
Error: "bail",
|
||||
ShouldError: true,
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
Tty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
if testCase.Error != "" {
|
||||
testCase.Upgrader.conn.errorData = testCase.Error
|
||||
}
|
||||
if testCase.Stdout != "" {
|
||||
testCase.Upgrader.conn.stdoutData = testCase.Stdout
|
||||
}
|
||||
if testCase.Stderr != "" {
|
||||
testCase.Upgrader.conn.stderrData = testCase.Stderr
|
||||
}
|
||||
var localOut, localErr *bytes.Buffer
|
||||
if testCase.Stdout != "" {
|
||||
localOut = &bytes.Buffer{}
|
||||
}
|
||||
if testCase.Stderr != "" {
|
||||
localErr = &bytes.Buffer{}
|
||||
}
|
||||
e := New(&client.Request{}, &client.Config{}, []string{"ls", "/"}, strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty)
|
||||
e.upgrader = testCase.Upgrader
|
||||
err := e.Execute()
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.ShouldError {
|
||||
t.Fatalf("%d: expected %t, got %t: %v", i, testCase.ShouldError, hasErr, err)
|
||||
}
|
||||
|
||||
conn := testCase.Upgrader.conn
|
||||
if testCase.Error != "" {
|
||||
if conn.errorStream == nil {
|
||||
t.Fatalf("%d: expected error stream creation", i)
|
||||
}
|
||||
if !conn.errorStream.readCalled {
|
||||
t.Fatalf("%d: expected error stream read", i)
|
||||
}
|
||||
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
|
||||
t.Fatalf("%d: expected error stream read '%v', got '%v'", i, e, a)
|
||||
}
|
||||
if !conn.errorStream.resetCalled {
|
||||
t.Fatalf("%d: expected error reset", i)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.ShouldError {
|
||||
continue
|
||||
}
|
||||
|
||||
if testCase.Stdin != "" {
|
||||
if conn.stdin == nil {
|
||||
t.Fatalf("%d: expected stdin stream creation", i)
|
||||
}
|
||||
if !conn.stdin.writeCalled {
|
||||
t.Fatalf("%d: expected stdin stream write", i)
|
||||
}
|
||||
if e, a := testCase.Stdin, string(conn.stdin.dataWritten); e != a {
|
||||
t.Fatalf("%d: expected stdin write %v, got %v", i, e, a)
|
||||
}
|
||||
if !conn.stdin.resetCalled {
|
||||
t.Fatalf("%d: expected stdin reset", i)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.Stdout != "" {
|
||||
if conn.stdout == nil {
|
||||
t.Fatalf("%d: expected stdout stream creation", i)
|
||||
}
|
||||
if !conn.stdout.readCalled {
|
||||
t.Fatalf("%d: expected stdout stream read", i)
|
||||
}
|
||||
if e, a := testCase.Stdout, localOut; e != a.String() {
|
||||
t.Fatalf("%d: expected stdout data '%s', got '%s'", i, e, a)
|
||||
}
|
||||
if !conn.stdout.resetCalled {
|
||||
t.Fatalf("%d: expected stdout reset", i)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.Stderr != "" {
|
||||
if testCase.Tty {
|
||||
if conn.stderr != nil {
|
||||
t.Fatalf("%d: unexpected stderr stream creation", i)
|
||||
}
|
||||
if localErr.String() != "" {
|
||||
t.Fatalf("%d: unexpected stderr data '%s'", i, localErr)
|
||||
}
|
||||
} else {
|
||||
if conn.stderr == nil {
|
||||
t.Fatalf("%d: expected stderr stream creation", i)
|
||||
}
|
||||
if !conn.stderr.readCalled {
|
||||
t.Fatalf("%d: expected stderr stream read", i)
|
||||
}
|
||||
if e, a := testCase.Stderr, localErr; e != a.String() {
|
||||
t.Fatalf("%d: expected stderr data '%s', got '%s'", i, e, a)
|
||||
}
|
||||
if !conn.stderr.resetCalled {
|
||||
t.Fatalf("%d: expected stderr reset", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !conn.closeCalled {
|
||||
t.Fatalf("%d: expected upgraded connection to get closed")
|
||||
}
|
||||
}
|
||||
}
|
@@ -18,6 +18,7 @@ package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@@ -33,6 +34,7 @@ import (
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
|
||||
watchjson "github.com/GoogleCloudPlatform/kubernetes/pkg/watch/json"
|
||||
"github.com/golang/glog"
|
||||
@@ -277,7 +279,7 @@ func (r *Request) setParam(paramName, value string) *Request {
|
||||
if r.params == nil {
|
||||
r.params = make(url.Values)
|
||||
}
|
||||
r.params[paramName] = []string{value}
|
||||
r.params[paramName] = append(r.params[paramName], value)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -347,8 +349,10 @@ func (r *Request) finalURL() string {
|
||||
finalURL.Path = p
|
||||
|
||||
query := url.Values{}
|
||||
for key, value := range r.params {
|
||||
query[key] = value
|
||||
for key, values := range r.params {
|
||||
for _, value := range values {
|
||||
query.Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if r.namespaceSet && r.namespaceInQuery {
|
||||
@@ -434,6 +438,41 @@ func (r *Request) Stream() (io.ReadCloser, error) {
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// Upgrade upgrades the request so that it supports multiplexed bidirectional
|
||||
// streams. The current implementation uses SPDY, but this could be replaced
|
||||
// with HTTP/2 once it's available, or something else.
|
||||
func (r *Request) Upgrade(config *Config, newRoundTripperFunc func(*tls.Config) httpstream.UpgradeRoundTripper) (httpstream.Connection, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
|
||||
tlsConfig, err := TLSConfigFor(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upgradeRoundTripper := newRoundTripperFunc(tlsConfig)
|
||||
wrapper, err := HTTPWrappersForConfig(config, upgradeRoundTripper)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.client = &http.Client{Transport: wrapper}
|
||||
|
||||
req, err := http.NewRequest(r.verb, r.finalURL(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error creating request: %s", err)
|
||||
}
|
||||
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error sending request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return upgradeRoundTripper.NewConnection(resp)
|
||||
}
|
||||
|
||||
// Do formats and executes the request. Returns a Result object for easy response
|
||||
// processing.
|
||||
//
|
||||
@@ -513,6 +552,8 @@ func (r *Request) transformResponse(resp *http.Response, req *http.Request) ([]b
|
||||
}
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusSwitchingProtocols:
|
||||
// no-op, we've been upgraded
|
||||
case resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusPartialContent:
|
||||
if !isStatusResponse {
|
||||
var err error = &UnexpectedStatusError{
|
||||
|
@@ -18,6 +18,7 @@ package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -40,6 +41,7 @@ import (
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
|
||||
watchjson "github.com/GoogleCloudPlatform/kubernetes/pkg/watch/json"
|
||||
)
|
||||
@@ -151,16 +153,22 @@ func TestRequestParam(t *testing.T) {
|
||||
if !api.Semantic.DeepDerivative(r.params, url.Values{"foo": []string{"a"}}) {
|
||||
t.Errorf("should have set a param: %#v", r)
|
||||
}
|
||||
|
||||
r.Param("bar", "1")
|
||||
r.Param("bar", "2")
|
||||
if !api.Semantic.DeepDerivative(r.params, url.Values{"foo": []string{"a"}, "bar": []string{"1", "2"}}) {
|
||||
t.Errorf("should have set a param: %#v", r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestURI(t *testing.T) {
|
||||
r := (&Request{}).Param("foo", "a")
|
||||
r.Prefix("other")
|
||||
r.RequestURI("/test?foo=b&a=b")
|
||||
r.RequestURI("/test?foo=b&a=b&c=1&c=2")
|
||||
if r.path != "/test" {
|
||||
t.Errorf("path is wrong: %#v", r)
|
||||
}
|
||||
if !api.Semantic.DeepDerivative(r.params, url.Values{"a": []string{"b"}, "foo": []string{"b"}}) {
|
||||
if !api.Semantic.DeepDerivative(r.params, url.Values{"a": []string{"b"}, "foo": []string{"b"}, "c": []string{"1", "2"}}) {
|
||||
t.Errorf("should have set a param: %#v", r)
|
||||
}
|
||||
}
|
||||
@@ -443,6 +451,122 @@ func TestRequestStream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type fakeUpgradeConnection struct{}
|
||||
|
||||
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *fakeUpgradeConnection) Close() error {
|
||||
return nil
|
||||
}
|
||||
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
|
||||
return make(chan bool)
|
||||
}
|
||||
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
}
|
||||
|
||||
type fakeUpgradeRoundTripper struct {
|
||||
req *http.Request
|
||||
conn httpstream.Connection
|
||||
}
|
||||
|
||||
func (f *fakeUpgradeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
f.req = req
|
||||
b := []byte{}
|
||||
body := ioutil.NopCloser(bytes.NewReader(b))
|
||||
resp := &http.Response{
|
||||
StatusCode: 101,
|
||||
Body: body,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (f *fakeUpgradeRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) {
|
||||
return f.conn, nil
|
||||
}
|
||||
|
||||
func TestRequestUpgrade(t *testing.T) {
|
||||
uri, _ := url.Parse("http://localhost/")
|
||||
testCases := []struct {
|
||||
Request *Request
|
||||
Config *Config
|
||||
RoundTripper *fakeUpgradeRoundTripper
|
||||
Err bool
|
||||
AuthBasicHeader bool
|
||||
AuthBearerHeader bool
|
||||
}{
|
||||
{
|
||||
Request: &Request{err: errors.New("bail")},
|
||||
Err: true,
|
||||
},
|
||||
{
|
||||
Request: &Request{},
|
||||
Config: &Config{
|
||||
TLSClientConfig: TLSClientConfig{
|
||||
CAFile: "foo",
|
||||
},
|
||||
Insecure: true,
|
||||
},
|
||||
Err: true,
|
||||
},
|
||||
{
|
||||
Request: &Request{},
|
||||
Config: &Config{
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
BearerToken: "b",
|
||||
},
|
||||
Err: true,
|
||||
},
|
||||
{
|
||||
Request: NewRequest(nil, "", uri, testapi.Codec(), true, true),
|
||||
Config: &Config{
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
},
|
||||
AuthBasicHeader: true,
|
||||
Err: false,
|
||||
},
|
||||
{
|
||||
Request: NewRequest(nil, "", uri, testapi.Codec(), true, true),
|
||||
Config: &Config{
|
||||
BearerToken: "b",
|
||||
},
|
||||
AuthBearerHeader: true,
|
||||
Err: false,
|
||||
},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
r := testCase.Request
|
||||
rt := &fakeUpgradeRoundTripper{}
|
||||
expectedConn := &fakeUpgradeConnection{}
|
||||
conn, err := r.Upgrade(testCase.Config, func(config *tls.Config) httpstream.UpgradeRoundTripper {
|
||||
rt.conn = expectedConn
|
||||
return rt
|
||||
})
|
||||
_ = conn
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.Err {
|
||||
t.Errorf("%d: expected %t, got %t: %v", i, testCase.Err, hasErr, r.err)
|
||||
}
|
||||
if testCase.Err {
|
||||
continue
|
||||
}
|
||||
|
||||
if testCase.AuthBasicHeader && !strings.Contains(rt.req.Header.Get("Authorization"), "Basic") {
|
||||
t.Errorf("%d: expected basic auth header, got: %s", rt.req.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
if testCase.AuthBearerHeader && !strings.Contains(rt.req.Header.Get("Authorization"), "Bearer") {
|
||||
t.Errorf("%d: expected bearer auth header, got: %s", rt.req.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
if e, a := expectedConn, conn; e != a {
|
||||
t.Errorf("%d: conn: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestDo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Request *Request
|
||||
|
@@ -355,7 +355,7 @@ func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool
|
||||
}
|
||||
|
||||
// DeepDerivative is similar to DeepEqual except that unset fields in a1 are
|
||||
// ignored (not compared). This allows we to focus on the fields that matter to
|
||||
// ignored (not compared). This allows us to focus on the fields that matter to
|
||||
// the semantic comparison.
|
||||
//
|
||||
// The unset fields include a nil pointer and an empty string.
|
||||
|
@@ -17,7 +17,9 @@ limitations under the License.
|
||||
package httplog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
@@ -46,6 +48,10 @@ type logger interface {
|
||||
|
||||
// Add a layer on top of ResponseWriter, so we can track latency and error
|
||||
// message sources.
|
||||
//
|
||||
// TODO now that we're using go-restful, we shouldn't need to be wrapping
|
||||
// the http.ResponseWriter. We can recover panics from go-restful, and
|
||||
// the logging value is questionable.
|
||||
type respLogger struct {
|
||||
status int
|
||||
statusStack string
|
||||
@@ -68,7 +74,7 @@ func (passthroughLogger) Addf(format string, data ...interface{}) {
|
||||
|
||||
// DefaultStacktracePred is the default implementation of StacktracePred.
|
||||
func DefaultStacktracePred(status int) bool {
|
||||
return status < http.StatusOK || status >= http.StatusBadRequest
|
||||
return (status < http.StatusOK || status >= http.StatusBadRequest) && status != http.StatusSwitchingProtocols
|
||||
}
|
||||
|
||||
// NewLogged turns a normal response writer into a logged response writer.
|
||||
@@ -186,3 +192,8 @@ func (rl *respLogger) WriteHeader(status int) {
|
||||
}
|
||||
rl.w.WriteHeader(status)
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker.
|
||||
func (rl *respLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return rl.w.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
@@ -183,7 +183,7 @@ func (f *Factory) BindFlags(flags *pflag.FlagSet) {
|
||||
}
|
||||
|
||||
// NewKubectlCommand creates the `kubectl` command and its nested children.
|
||||
func (f *Factory) NewKubectlCommand(out io.Writer) *cobra.Command {
|
||||
func (f *Factory) NewKubectlCommand(in io.Reader, out, err io.Writer) *cobra.Command {
|
||||
// Parent command to which all subcommands are added.
|
||||
cmds := &cobra.Command{
|
||||
Use: "kubectl",
|
||||
@@ -211,6 +211,9 @@ Find more information at https://github.com/GoogleCloudPlatform/kubernetes.`,
|
||||
cmds.AddCommand(f.NewCmdRollingUpdate(out))
|
||||
cmds.AddCommand(f.NewCmdResize(out))
|
||||
|
||||
cmds.AddCommand(f.NewCmdExec(in, out, err))
|
||||
cmds.AddCommand(f.NewCmdPortForward())
|
||||
|
||||
cmds.AddCommand(f.NewCmdRunContainer(out))
|
||||
cmds.AddCommand(f.NewCmdStop(out))
|
||||
cmds.AddCommand(f.NewCmdExposeService(out))
|
||||
|
133
pkg/kubectl/cmd/exec.go
Normal file
133
pkg/kubectl/cmd/exec.go
Normal file
@@ -0,0 +1,133 @@
|
||||
/*
|
||||
Copyright 2014 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client/remotecommand"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/kubectl/cmd/util"
|
||||
"github.com/docker/docker/pkg/term"
|
||||
"github.com/golang/glog"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func (f *Factory) NewCmdExec(cmdIn io.Reader, cmdOut, cmdErr io.Writer) *cobra.Command {
|
||||
flags := &struct {
|
||||
pod string
|
||||
container string
|
||||
stdin bool
|
||||
tty bool
|
||||
}{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "exec -p <pod> -c <container> -- <command> [<args...>]",
|
||||
Short: "Execute a command in a container.",
|
||||
Long: `Execute a command in a container.
|
||||
Examples:
|
||||
$ kubectl exec -p 123456-7890 -c ruby-container date
|
||||
<returns output from running 'date' in ruby-container from pod 123456-7890>
|
||||
|
||||
$ kubectl exec -p 123456-7890 -c ruby-container -i -t -- bash -il
|
||||
<switches to raw terminal mode, sends stdin to 'bash' in ruby-container from
|
||||
pod 123456-780 and sends stdout/stderr from 'bash' back to the client`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(flags.pod) == 0 {
|
||||
usageError(cmd, "<pod> is required for exec")
|
||||
}
|
||||
|
||||
if len(args) < 1 {
|
||||
usageError(cmd, "<command> is required for exec")
|
||||
}
|
||||
|
||||
namespace, err := f.DefaultNamespace(cmd)
|
||||
checkErr(err)
|
||||
|
||||
client, err := f.Client(cmd)
|
||||
checkErr(err)
|
||||
|
||||
pod, err := client.Pods(namespace).Get(flags.pod)
|
||||
checkErr(err)
|
||||
|
||||
if pod.Status.Phase != api.PodRunning {
|
||||
glog.Fatalf("Unable to execute command because pod is not running. Current status=%v", pod.Status.Phase)
|
||||
}
|
||||
|
||||
if len(flags.container) == 0 {
|
||||
flags.container = pod.Spec.Containers[0].Name
|
||||
}
|
||||
|
||||
var stdin io.Reader
|
||||
if util.GetFlagBool(cmd, "stdin") {
|
||||
stdin = cmdIn
|
||||
if flags.tty {
|
||||
if file, ok := cmdIn.(*os.File); ok {
|
||||
inFd := file.Fd()
|
||||
if term.IsTerminal(inFd) {
|
||||
oldState, err := term.SetRawTerminal(inFd)
|
||||
if err != nil {
|
||||
glog.Fatal(err)
|
||||
}
|
||||
// this handles a clean exit, where the command finished
|
||||
defer term.RestoreTerminal(inFd, oldState)
|
||||
|
||||
// SIGINT is handled by term.SetRawTerminal (it runs a goroutine that listens
|
||||
// for SIGINT and restores the terminal before exiting)
|
||||
|
||||
// this handles SIGTERM
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigChan
|
||||
term.RestoreTerminal(inFd, oldState)
|
||||
os.Exit(0)
|
||||
}()
|
||||
} else {
|
||||
glog.Warning("Stdin is not a terminal")
|
||||
}
|
||||
} else {
|
||||
flags.tty = false
|
||||
glog.Warning("Unable to use a TTY")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config, err := f.ClientConfig(cmd)
|
||||
checkErr(err)
|
||||
|
||||
req := client.RESTClient.Get().
|
||||
Prefix("proxy").
|
||||
Resource("minions").
|
||||
Name(pod.Status.Host).
|
||||
Suffix("exec", namespace, flags.pod, flags.container)
|
||||
|
||||
e := remotecommand.New(req, config, args, stdin, cmdOut, cmdErr, flags.tty)
|
||||
err = e.Execute()
|
||||
checkErr(err)
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVarP(&flags.pod, "pod", "p", "", "Pod name")
|
||||
// TODO support UID
|
||||
cmd.Flags().StringVarP(&flags.container, "container", "c", "", "Container name")
|
||||
cmd.Flags().BoolVarP(&flags.stdin, "stdin", "i", false, "Pass stdin to the container")
|
||||
cmd.Flags().BoolVarP(&flags.tty, "tty", "t", false, "Stdin is a TTY")
|
||||
return cmd
|
||||
}
|
104
pkg/kubectl/cmd/portforward.go
Normal file
104
pkg/kubectl/cmd/portforward.go
Normal file
@@ -0,0 +1,104 @@
|
||||
/*
|
||||
Copyright 2014 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client/portforward"
|
||||
"github.com/golang/glog"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func (f *Factory) NewCmdPortForward() *cobra.Command {
|
||||
flags := &struct {
|
||||
pod string
|
||||
container string
|
||||
}{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "port-forward -p <pod> [<local port>:]<remote port> [<port>...]",
|
||||
Short: "Forward 1 or more local ports to a pod.",
|
||||
Long: `Forward 1 or more local ports to a pod.
|
||||
Examples:
|
||||
$ kubectl port-forward -p mypod 5000 6000
|
||||
<listens on ports 5000 and 6000 locally, forwarding data to/from ports 5000
|
||||
and 6000 in the pod>
|
||||
|
||||
$ kubectl port-forward -p mypod 8888:5000
|
||||
<listens on port 8888 locally, forwarding to 5000 in the pod>
|
||||
|
||||
$ kubectl port-forward -p mypod :5000
|
||||
<listens on a random port locally, forwarding to 5000 in the pod>
|
||||
|
||||
$ kubectl port-forward -p mypod 0:5000
|
||||
<listens on a random port locally, forwarding to 5000 in the pod>
|
||||
`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(flags.pod) == 0 {
|
||||
usageError(cmd, "<pod> is required for exec")
|
||||
}
|
||||
|
||||
if len(args) < 1 {
|
||||
usageError(cmd, "at least 1 <port> is required for port-forward")
|
||||
}
|
||||
|
||||
namespace, err := f.DefaultNamespace(cmd)
|
||||
checkErr(err)
|
||||
|
||||
client, err := f.Client(cmd)
|
||||
checkErr(err)
|
||||
|
||||
pod, err := client.Pods(namespace).Get(flags.pod)
|
||||
checkErr(err)
|
||||
|
||||
if pod.Status.Phase != api.PodRunning {
|
||||
glog.Fatalf("Unable to execute command because pod is not running. Current status=%v", pod.Status.Phase)
|
||||
}
|
||||
|
||||
config, err := f.ClientConfig(cmd)
|
||||
checkErr(err)
|
||||
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, os.Interrupt)
|
||||
defer signal.Stop(signals)
|
||||
|
||||
stopCh := make(chan struct{}, 1)
|
||||
go func() {
|
||||
<-signals
|
||||
close(stopCh)
|
||||
}()
|
||||
|
||||
req := client.RESTClient.Get().
|
||||
Prefix("proxy").
|
||||
Resource("minions").
|
||||
Name(pod.Status.Host).
|
||||
Suffix("portForward", namespace, flags.pod)
|
||||
|
||||
pf, err := portforward.New(req, config, args, stopCh)
|
||||
checkErr(err)
|
||||
|
||||
err = pf.ForwardPorts()
|
||||
checkErr(err)
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVarP(&flags.pod, "pod", "p", "", "Pod name")
|
||||
// TODO support UID
|
||||
return cmd
|
||||
}
|
@@ -198,6 +198,127 @@ func (d *dockerContainerCommandRunner) RunInContainer(containerID string, cmd []
|
||||
return buf.Bytes(), <-errChan
|
||||
}
|
||||
|
||||
// ExecInContainer uses nsenter to run the command inside the container identified by containerID.
|
||||
//
|
||||
// TODO:
|
||||
// - match cgroups of container
|
||||
// - should we support `docker exec`?
|
||||
// - should we support nsenter in a container, running with elevated privs and --pid=host?
|
||||
func (d *dockerContainerCommandRunner) ExecInContainer(containerId string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool) error {
|
||||
container, err := d.client.InspectContainer(containerId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !container.State.Running {
|
||||
return fmt.Errorf("container not running (%s)", container)
|
||||
}
|
||||
|
||||
containerPid := container.State.Pid
|
||||
|
||||
// TODO what if the container doesn't have `env`???
|
||||
args := []string{"-t", fmt.Sprintf("%d", containerPid), "-m", "-i", "-u", "-n", "-p", "--", "env", "-i"}
|
||||
args = append(args, fmt.Sprintf("HOSTNAME=%s", container.Config.Hostname))
|
||||
args = append(args, container.Config.Env...)
|
||||
args = append(args, cmd...)
|
||||
glog.Infof("ARGS %#v", args)
|
||||
command := exec.Command("nsenter", args...)
|
||||
// TODO use exec.LookPath
|
||||
if tty {
|
||||
p, err := StartPty(command)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
// make sure to close the stdout stream
|
||||
defer stdout.Close()
|
||||
|
||||
if stdin != nil {
|
||||
go io.Copy(p, stdin)
|
||||
}
|
||||
|
||||
if stdout != nil {
|
||||
go io.Copy(stdout, p)
|
||||
}
|
||||
|
||||
return command.Wait()
|
||||
} else {
|
||||
cp := func(dst io.WriteCloser, src io.Reader, closeDst bool) {
|
||||
defer func() {
|
||||
if closeDst {
|
||||
dst.Close()
|
||||
}
|
||||
}()
|
||||
io.Copy(dst, src)
|
||||
}
|
||||
if stdin != nil {
|
||||
inPipe, err := command.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
cp(inPipe, stdin, false)
|
||||
inPipe.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
if stdout != nil {
|
||||
outPipe, err := command.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go cp(stdout, outPipe, true)
|
||||
}
|
||||
|
||||
if stderr != nil {
|
||||
errPipe, err := command.StderrPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go cp(stderr, errPipe, true)
|
||||
}
|
||||
|
||||
return command.Run()
|
||||
}
|
||||
}
|
||||
|
||||
// PortForward executes socat in the pod's network namespace and copies
|
||||
// data between stream (representing the user's local connection on their
|
||||
// computer) and the specified port in the container.
|
||||
//
|
||||
// TODO:
|
||||
// - match cgroups of container
|
||||
// - should we support nsenter + socat on the host? (current impl)
|
||||
// - should we support nsenter + socat in a container, running with elevated privs and --pid=host?
|
||||
func (d *dockerContainerCommandRunner) PortForward(podInfraContainerID string, port uint16, stream io.ReadWriteCloser) error {
|
||||
container, err := d.client.InspectContainer(podInfraContainerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !container.State.Running {
|
||||
return fmt.Errorf("container not running (%s)", container)
|
||||
}
|
||||
|
||||
containerPid := container.State.Pid
|
||||
// TODO use exec.LookPath for socat / what if the host doesn't have it???
|
||||
args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)}
|
||||
// TODO use exec.LookPath
|
||||
command := exec.Command("nsenter", args...)
|
||||
in, err := command.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out, err := command.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go io.Copy(in, stream)
|
||||
go io.Copy(stream, out)
|
||||
return command.Run()
|
||||
}
|
||||
|
||||
// NewDockerContainerCommandRunner creates a ContainerCommandRunner which uses nsinit to run a command
|
||||
// inside a container.
|
||||
func NewDockerContainerCommandRunner(client DockerInterface) ContainerCommandRunner {
|
||||
@@ -690,4 +811,6 @@ func ConnectToDockerOrDie(dockerEndpoint string) DockerInterface {
|
||||
type ContainerCommandRunner interface {
|
||||
RunInContainer(containerID string, cmd []string) ([]byte, error)
|
||||
GetDockerServerVersion() ([]uint, error)
|
||||
ExecInContainer(containerID string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
|
||||
PortForward(podInfraContainerID string, port uint16, stream io.ReadWriteCloser) error
|
||||
}
|
||||
|
30
pkg/kubelet/dockertools/pty_linux.go
Normal file
30
pkg/kubelet/dockertools/pty_linux.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// +build linux
|
||||
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package dockertools
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/kr/pty"
|
||||
)
|
||||
|
||||
func StartPty(c *exec.Cmd) (*os.File, error) {
|
||||
return pty.Start(c)
|
||||
}
|
28
pkg/kubelet/dockertools/pty_unsupported.go
Normal file
28
pkg/kubelet/dockertools/pty_unsupported.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// +build !linux
|
||||
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package dockertools
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func StartPty(c *exec.Cmd) (pty *os.File, err error) {
|
||||
return nil, nil
|
||||
}
|
@@ -86,7 +86,8 @@ func NewMainKubelet(
|
||||
clusterDomain string,
|
||||
clusterDNS net.IP,
|
||||
masterServiceNamespace string,
|
||||
volumePlugins []volume.Plugin) (*Kubelet, error) {
|
||||
volumePlugins []volume.Plugin,
|
||||
streamingConnectionIdleTimeout time.Duration) (*Kubelet, error) {
|
||||
if rootDirectory == "" {
|
||||
return nil, fmt.Errorf("invalid root directory %q", rootDirectory)
|
||||
}
|
||||
@@ -104,28 +105,29 @@ func NewMainKubelet(
|
||||
serviceLister := &cache.StoreToServiceLister{serviceStore}
|
||||
|
||||
klet := &Kubelet{
|
||||
hostname: hostname,
|
||||
dockerClient: dockerClient,
|
||||
etcdClient: etcdClient,
|
||||
kubeClient: kubeClient,
|
||||
rootDirectory: rootDirectory,
|
||||
resyncInterval: resyncInterval,
|
||||
podInfraContainerImage: podInfraContainerImage,
|
||||
podWorkers: newPodWorkers(),
|
||||
dockerIDToRef: map[dockertools.DockerID]*api.ObjectReference{},
|
||||
runner: dockertools.NewDockerContainerCommandRunner(dockerClient),
|
||||
httpClient: &http.Client{},
|
||||
pullQPS: pullQPS,
|
||||
pullBurst: pullBurst,
|
||||
minimumGCAge: minimumGCAge,
|
||||
maxContainerCount: maxContainerCount,
|
||||
sourceReady: sourceReady,
|
||||
clusterDomain: clusterDomain,
|
||||
clusterDNS: clusterDNS,
|
||||
serviceLister: serviceLister,
|
||||
masterServiceNamespace: masterServiceNamespace,
|
||||
prober: newProbeHolder(),
|
||||
readiness: newReadinessStates(),
|
||||
hostname: hostname,
|
||||
dockerClient: dockerClient,
|
||||
etcdClient: etcdClient,
|
||||
kubeClient: kubeClient,
|
||||
rootDirectory: rootDirectory,
|
||||
resyncInterval: resyncInterval,
|
||||
podInfraContainerImage: podInfraContainerImage,
|
||||
podWorkers: newPodWorkers(),
|
||||
dockerIDToRef: map[dockertools.DockerID]*api.ObjectReference{},
|
||||
runner: dockertools.NewDockerContainerCommandRunner(dockerClient),
|
||||
httpClient: &http.Client{},
|
||||
pullQPS: pullQPS,
|
||||
pullBurst: pullBurst,
|
||||
minimumGCAge: minimumGCAge,
|
||||
maxContainerCount: maxContainerCount,
|
||||
sourceReady: sourceReady,
|
||||
clusterDomain: clusterDomain,
|
||||
clusterDNS: clusterDNS,
|
||||
serviceLister: serviceLister,
|
||||
masterServiceNamespace: masterServiceNamespace,
|
||||
prober: newProbeHolder(),
|
||||
readiness: newReadinessStates(),
|
||||
streamingConnectionIdleTimeout: streamingConnectionIdleTimeout,
|
||||
}
|
||||
|
||||
if err := klet.setupDataDirs(); err != nil {
|
||||
@@ -207,6 +209,10 @@ type Kubelet struct {
|
||||
prober probeHolder
|
||||
// container readiness state holder
|
||||
readiness *readinessStates
|
||||
|
||||
// how long to keep idle streaming command execution/port forwarding
|
||||
// connections open before terminating them
|
||||
streamingConnectionIdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// getRootDir returns the full path to the directory under which kubelet can
|
||||
@@ -1686,6 +1692,40 @@ func (kl *Kubelet) RunInContainer(podFullName string, uid types.UID, container s
|
||||
return kl.runner.RunInContainer(dockerContainer.ID, cmd)
|
||||
}
|
||||
|
||||
// ExecInContainer executes a command in a container, connecting the supplied
|
||||
// stdin/stdout/stderr to the command's IO streams.
|
||||
func (kl *Kubelet) ExecInContainer(podFullName string, uid types.UID, container string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool) error {
|
||||
if kl.runner == nil {
|
||||
return fmt.Errorf("no runner specified.")
|
||||
}
|
||||
dockerContainers, err := dockertools.GetKubeletDockerContainers(kl.dockerClient, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dockerContainer, found, _ := dockerContainers.FindPodContainer(podFullName, uid, container)
|
||||
if !found {
|
||||
return fmt.Errorf("container not found (%q)", container)
|
||||
}
|
||||
return kl.runner.ExecInContainer(dockerContainer.ID, cmd, stdin, stdout, stderr, tty)
|
||||
}
|
||||
|
||||
// PortForward connects to the pod's port and copies data between the port
|
||||
// and the stream.
|
||||
func (kl *Kubelet) PortForward(podFullName string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||
if kl.runner == nil {
|
||||
return fmt.Errorf("no runner specified.")
|
||||
}
|
||||
dockerContainers, err := dockertools.GetKubeletDockerContainers(kl.dockerClient, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
podInfraContainer, found, _ := dockerContainers.FindPodContainer(podFullName, uid, dockertools.PodInfraContainerName)
|
||||
if !found {
|
||||
return fmt.Errorf("Unable to find pod infra container for pod %s, uid %v", podFullName, uid)
|
||||
}
|
||||
return kl.runner.PortForward(podInfraContainer.ID, port, stream)
|
||||
}
|
||||
|
||||
// BirthCry sends an event that the kubelet has started up.
|
||||
func (kl *Kubelet) BirthCry() {
|
||||
// Make an event that kubelet restarted.
|
||||
@@ -1699,3 +1739,7 @@ func (kl *Kubelet) BirthCry() {
|
||||
}
|
||||
record.Eventf(ref, "starting", "Starting kubelet.")
|
||||
}
|
||||
|
||||
func (kl *Kubelet) StreamingConnectionIdleTimeout() time.Duration {
|
||||
return kl.streamingConnectionIdleTimeout
|
||||
}
|
||||
|
@@ -17,7 +17,9 @@ limitations under the License.
|
||||
package kubelet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -1486,9 +1488,15 @@ func TestGetContainerInfoWithNoMatchingContainers(t *testing.T) {
|
||||
}
|
||||
|
||||
type fakeContainerCommandRunner struct {
|
||||
Cmd []string
|
||||
ID string
|
||||
E error
|
||||
Cmd []string
|
||||
ID string
|
||||
E error
|
||||
Stdin io.Reader
|
||||
Stdout io.WriteCloser
|
||||
Stderr io.WriteCloser
|
||||
TTY bool
|
||||
Port uint16
|
||||
Stream io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (f *fakeContainerCommandRunner) RunInContainer(id string, cmd []string) ([]byte, error) {
|
||||
@@ -1501,6 +1509,23 @@ func (f *fakeContainerCommandRunner) GetDockerServerVersion() ([]uint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeContainerCommandRunner) ExecInContainer(id string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
|
||||
f.Cmd = cmd
|
||||
f.ID = id
|
||||
f.Stdin = in
|
||||
f.Stdout = out
|
||||
f.Stderr = err
|
||||
f.TTY = tty
|
||||
return f.E
|
||||
}
|
||||
|
||||
func (f *fakeContainerCommandRunner) PortForward(podInfraContainerID string, port uint16, stream io.ReadWriteCloser) error {
|
||||
f.ID = podInfraContainerID
|
||||
f.Port = port
|
||||
f.Stream = stream
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRunInContainerNoSuchPod(t *testing.T) {
|
||||
fakeCommandRunner := fakeContainerCommandRunner{}
|
||||
kubelet, fakeDocker := newTestKubelet(t)
|
||||
@@ -2805,5 +2830,252 @@ func TestGetPodReadyCondition(t *testing.T) {
|
||||
t.Errorf("On test case %v, expected:\n%+v\ngot\n%+v\n", i, test.expected, condition)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestExecInContainerNoSuchPod(t *testing.T) {
|
||||
fakeCommandRunner := fakeContainerCommandRunner{}
|
||||
kubelet, fakeDocker := newTestKubelet(t)
|
||||
fakeDocker.ContainerList = []docker.APIContainers{}
|
||||
kubelet.runner = &fakeCommandRunner
|
||||
|
||||
podName := "podFoo"
|
||||
podNamespace := "etcd"
|
||||
containerName := "containerFoo"
|
||||
err := kubelet.ExecInContainer(
|
||||
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{Name: podName, Namespace: podNamespace}}),
|
||||
"",
|
||||
containerName,
|
||||
[]string{"ls"},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("unexpected non-error")
|
||||
}
|
||||
if fakeCommandRunner.ID != "" {
|
||||
t.Fatal("unexpected invocation of runner.ExecInContainer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecInContainerNoSuchContainer(t *testing.T) {
|
||||
fakeCommandRunner := fakeContainerCommandRunner{}
|
||||
kubelet, fakeDocker := newTestKubelet(t)
|
||||
kubelet.runner = &fakeCommandRunner
|
||||
|
||||
podName := "podFoo"
|
||||
podNamespace := "etcd"
|
||||
containerID := "containerFoo"
|
||||
|
||||
fakeDocker.ContainerList = []docker.APIContainers{
|
||||
{
|
||||
ID: "notfound",
|
||||
Names: []string{"/k8s_notfound_" + podName + "." + podNamespace + ".test_12345678_42"},
|
||||
},
|
||||
}
|
||||
|
||||
err := kubelet.ExecInContainer(
|
||||
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{
|
||||
UID: "12345678",
|
||||
Name: podName,
|
||||
Namespace: podNamespace,
|
||||
Annotations: map[string]string{ConfigSourceAnnotationKey: "test"},
|
||||
}}),
|
||||
"",
|
||||
containerID,
|
||||
[]string{"ls"},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("unexpected non-error")
|
||||
}
|
||||
if fakeCommandRunner.ID != "" {
|
||||
t.Fatal("unexpected invocation of runner.ExecInContainer")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeReadWriteCloser struct{}
|
||||
|
||||
func (f *fakeReadWriteCloser) Write(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (f *fakeReadWriteCloser) Read(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (f *fakeReadWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestExecInContainer(t *testing.T) {
|
||||
fakeCommandRunner := fakeContainerCommandRunner{}
|
||||
kubelet, fakeDocker := newTestKubelet(t)
|
||||
kubelet.runner = &fakeCommandRunner
|
||||
|
||||
podName := "podFoo"
|
||||
podNamespace := "etcd"
|
||||
containerID := "containerFoo"
|
||||
command := []string{"ls"}
|
||||
stdin := &bytes.Buffer{}
|
||||
stdout := &fakeReadWriteCloser{}
|
||||
stderr := &fakeReadWriteCloser{}
|
||||
tty := true
|
||||
|
||||
fakeDocker.ContainerList = []docker.APIContainers{
|
||||
{
|
||||
ID: containerID,
|
||||
Names: []string{"/k8s_" + containerID + "_" + podName + "." + podNamespace + ".test_12345678_42"},
|
||||
},
|
||||
}
|
||||
|
||||
err := kubelet.ExecInContainer(
|
||||
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{
|
||||
UID: "12345678",
|
||||
Name: podName,
|
||||
Namespace: podNamespace,
|
||||
Annotations: map[string]string{ConfigSourceAnnotationKey: "test"},
|
||||
}}),
|
||||
"",
|
||||
containerID,
|
||||
[]string{"ls"},
|
||||
stdin,
|
||||
stdout,
|
||||
stderr,
|
||||
tty,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if e, a := containerID, fakeCommandRunner.ID; e != a {
|
||||
t.Fatalf("container id: expected %s, got %s", e, a)
|
||||
}
|
||||
if e, a := command, fakeCommandRunner.Cmd; !reflect.DeepEqual(e, a) {
|
||||
t.Fatalf("command: expected '%v', got '%v'", e, a)
|
||||
}
|
||||
if e, a := stdin, fakeCommandRunner.Stdin; e != a {
|
||||
t.Fatalf("stdin: expected %#v, got %#v", e, a)
|
||||
}
|
||||
if e, a := stdout, fakeCommandRunner.Stdout; e != a {
|
||||
t.Fatalf("stdout: expected %#v, got %#v", e, a)
|
||||
}
|
||||
if e, a := stderr, fakeCommandRunner.Stderr; e != a {
|
||||
t.Fatalf("stderr: expected %#v, got %#v", e, a)
|
||||
}
|
||||
if e, a := tty, fakeCommandRunner.TTY; e != a {
|
||||
t.Fatalf("tty: expected %t, got %t", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwardNoSuchPod(t *testing.T) {
|
||||
fakeCommandRunner := fakeContainerCommandRunner{}
|
||||
kubelet, fakeDocker := newTestKubelet(t)
|
||||
fakeDocker.ContainerList = []docker.APIContainers{}
|
||||
kubelet.runner = &fakeCommandRunner
|
||||
|
||||
podName := "podFoo"
|
||||
podNamespace := "etcd"
|
||||
var port uint16 = 5000
|
||||
|
||||
err := kubelet.PortForward(
|
||||
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{Name: podName, Namespace: podNamespace}}),
|
||||
"",
|
||||
port,
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("unexpected non-error")
|
||||
}
|
||||
if fakeCommandRunner.ID != "" {
|
||||
t.Fatal("unexpected invocation of runner.PortForward")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwardNoSuchContainer(t *testing.T) {
|
||||
fakeCommandRunner := fakeContainerCommandRunner{}
|
||||
kubelet, fakeDocker := newTestKubelet(t)
|
||||
kubelet.runner = &fakeCommandRunner
|
||||
|
||||
podName := "podFoo"
|
||||
podNamespace := "etcd"
|
||||
var port uint16 = 5000
|
||||
|
||||
fakeDocker.ContainerList = []docker.APIContainers{
|
||||
{
|
||||
ID: "notfound",
|
||||
Names: []string{"/k8s_notfound_" + podName + "." + podNamespace + ".test_12345678_42"},
|
||||
},
|
||||
}
|
||||
|
||||
err := kubelet.PortForward(
|
||||
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{
|
||||
UID: "12345678",
|
||||
Name: podName,
|
||||
Namespace: podNamespace,
|
||||
Annotations: map[string]string{ConfigSourceAnnotationKey: "test"},
|
||||
}}),
|
||||
"",
|
||||
port,
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("unexpected non-error")
|
||||
}
|
||||
if fakeCommandRunner.ID != "" {
|
||||
t.Fatal("unexpected invocation of runner.PortForward")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForward(t *testing.T) {
|
||||
fakeCommandRunner := fakeContainerCommandRunner{}
|
||||
kubelet, fakeDocker := newTestKubelet(t)
|
||||
kubelet.runner = &fakeCommandRunner
|
||||
|
||||
podName := "podFoo"
|
||||
podNamespace := "etcd"
|
||||
containerID := "containerFoo"
|
||||
var port uint16 = 5000
|
||||
stream := &fakeReadWriteCloser{}
|
||||
|
||||
infraContainerID := "infra"
|
||||
kubelet.podInfraContainerImage = "POD"
|
||||
|
||||
fakeDocker.ContainerList = []docker.APIContainers{
|
||||
{
|
||||
ID: infraContainerID,
|
||||
Names: []string{"/k8s_" + kubelet.podInfraContainerImage + "_" + podName + "." + podNamespace + ".test_12345678_42"},
|
||||
},
|
||||
{
|
||||
ID: containerID,
|
||||
Names: []string{"/k8s_" + containerID + "_" + podName + "." + podNamespace + ".test_12345678_42"},
|
||||
},
|
||||
}
|
||||
|
||||
err := kubelet.PortForward(
|
||||
GetPodFullName(&api.BoundPod{ObjectMeta: api.ObjectMeta{
|
||||
UID: "12345678",
|
||||
Name: podName,
|
||||
Namespace: podNamespace,
|
||||
Annotations: map[string]string{ConfigSourceAnnotationKey: "test"},
|
||||
}}),
|
||||
"",
|
||||
port,
|
||||
stream,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if e, a := infraContainerID, fakeCommandRunner.ID; e != a {
|
||||
t.Fatalf("container id: expected %s, got %s", e, a)
|
||||
}
|
||||
if e, a := port, fakeCommandRunner.Port; e != a {
|
||||
t.Fatalf("port: expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := stream, fakeCommandRunner.Stream; e != a {
|
||||
t.Fatalf("stream: expected %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
@@ -27,6 +27,7 @@ import (
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
@@ -34,6 +35,8 @@ import (
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/types"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
|
||||
"github.com/golang/glog"
|
||||
"github.com/google/cadvisor/info"
|
||||
)
|
||||
@@ -69,8 +72,11 @@ type HostInterface interface {
|
||||
GetPodByName(namespace, name string) (*api.BoundPod, bool)
|
||||
GetPodStatus(name string, uid types.UID) (api.PodStatus, error)
|
||||
RunInContainer(name string, uid types.UID, container string, cmd []string) ([]byte, error)
|
||||
ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
|
||||
GetKubeletContainerLogs(podFullName, containerName, tail string, follow bool, stdout, stderr io.Writer) error
|
||||
ServeLogs(w http.ResponseWriter, req *http.Request)
|
||||
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
|
||||
StreamingConnectionIdleTimeout() time.Duration
|
||||
}
|
||||
|
||||
// NewServer initializes and configures a kubelet.Server object to handle HTTP requests.
|
||||
@@ -99,6 +105,8 @@ func (s *Server) InstallDefaultHandlers() {
|
||||
// InstallDeguggingHandlers registers the HTTP request patterns that serve logs or run commands/containers
|
||||
func (s *Server) InstallDebuggingHandlers() {
|
||||
s.mux.HandleFunc("/run/", s.handleRun)
|
||||
s.mux.HandleFunc("/exec/", s.handleExec)
|
||||
s.mux.HandleFunc("/portForward/", s.handlePortForward)
|
||||
|
||||
s.mux.HandleFunc("/logs/", s.handleLogs)
|
||||
s.mux.HandleFunc("/containerLogs/", s.handleContainerLogs)
|
||||
@@ -301,6 +309,28 @@ func (s *Server) handleSpec(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func parseContainerCoordinates(path string) (namespace, pod string, uid types.UID, container string, err error) {
|
||||
parts := strings.Split(path, "/")
|
||||
|
||||
if len(parts) == 5 {
|
||||
namespace = parts[2]
|
||||
pod = parts[3]
|
||||
container = parts[4]
|
||||
return
|
||||
}
|
||||
|
||||
if len(parts) == 6 {
|
||||
namespace = parts[2]
|
||||
pod = parts[3]
|
||||
uid = types.UID(parts[4])
|
||||
container = parts[5]
|
||||
return
|
||||
}
|
||||
|
||||
err = fmt.Errorf("Unexpected path %s. Expected /.../.../<namespace>/<pod>/<container> or /.../.../<namespace>/<pod>/<uid>/<container>", path)
|
||||
return
|
||||
}
|
||||
|
||||
// handleRun handles requests to run a command inside a container.
|
||||
func (s *Server) handleRun(w http.ResponseWriter, req *http.Request) {
|
||||
u, err := url.ParseRequestURI(req.RequestURI)
|
||||
@@ -308,20 +338,9 @@ func (s *Server) handleRun(w http.ResponseWriter, req *http.Request) {
|
||||
s.error(w, err)
|
||||
return
|
||||
}
|
||||
parts := strings.Split(u.Path, "/")
|
||||
var podNamespace, podID, container string
|
||||
var uid types.UID
|
||||
if len(parts) == 5 {
|
||||
podNamespace = parts[2]
|
||||
podID = parts[3]
|
||||
container = parts[4]
|
||||
} else if len(parts) == 6 {
|
||||
podNamespace = parts[2]
|
||||
podID = parts[3]
|
||||
uid = types.UID(parts[4])
|
||||
container = parts[5]
|
||||
} else {
|
||||
http.Error(w, "Unexpected path for command running", http.StatusBadRequest)
|
||||
podNamespace, podID, uid, container, err := parseContainerCoordinates(u.Path)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
pod, ok := s.host.GetPodByName(podNamespace, podID)
|
||||
@@ -339,6 +358,227 @@ func (s *Server) handleRun(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// handleExec handles requests to run a command inside a container.
|
||||
func (s *Server) handleExec(w http.ResponseWriter, req *http.Request) {
|
||||
u, err := url.ParseRequestURI(req.RequestURI)
|
||||
if err != nil {
|
||||
s.error(w, err)
|
||||
return
|
||||
}
|
||||
podNamespace, podID, uid, container, err := parseContainerCoordinates(u.Path)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
pod, ok := s.host.GetPodByName(podNamespace, podID)
|
||||
if !ok {
|
||||
http.Error(w, "Pod does not exist", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
req.ParseForm()
|
||||
// start at 1 for error stream
|
||||
expectedStreams := 1
|
||||
if req.FormValue(api.ExecStdinParam) == "1" {
|
||||
expectedStreams++
|
||||
}
|
||||
if req.FormValue(api.ExecStdoutParam) == "1" {
|
||||
expectedStreams++
|
||||
}
|
||||
tty := req.FormValue(api.ExecTTYParam) == "1"
|
||||
if !tty && req.FormValue(api.ExecStderrParam) == "1" {
|
||||
expectedStreams++
|
||||
}
|
||||
|
||||
if expectedStreams == 1 {
|
||||
http.Error(w, "You must specify at least 1 of stdin, stdout, stderr", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
streamCh := make(chan httpstream.Stream)
|
||||
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error {
|
||||
streamCh <- stream
|
||||
return nil
|
||||
})
|
||||
// from this point on, we can no longer call methods on w
|
||||
if conn == nil {
|
||||
// The upgrader is responsible for notifying the client of any errors that
|
||||
// occurred during upgrading. All we can do is return here at this point
|
||||
// if we weren't successful in upgrading.
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
|
||||
|
||||
// TODO find a good default timeout value
|
||||
// TODO make it configurable?
|
||||
expired := time.NewTimer(2 * time.Second)
|
||||
|
||||
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
|
||||
receivedStreams := 0
|
||||
WaitForStreams:
|
||||
for {
|
||||
select {
|
||||
case stream := <-streamCh:
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
errorStream = stream
|
||||
defer errorStream.Reset()
|
||||
receivedStreams++
|
||||
case api.StreamTypeStdin:
|
||||
stdinStream = stream
|
||||
receivedStreams++
|
||||
case api.StreamTypeStdout:
|
||||
stdoutStream = stream
|
||||
receivedStreams++
|
||||
case api.StreamTypeStderr:
|
||||
stderrStream = stream
|
||||
receivedStreams++
|
||||
default:
|
||||
glog.Errorf("Unexpected stream type: '%s'", streamType)
|
||||
}
|
||||
if receivedStreams == expectedStreams {
|
||||
break WaitForStreams
|
||||
}
|
||||
case <-expired.C:
|
||||
// TODO find a way to return the error to the user. Maybe use a separate
|
||||
// stream to report errors?
|
||||
glog.Error("Timed out waiting for client to create streams")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if stdinStream != nil {
|
||||
// close our half of the input stream, since we won't be writing to it
|
||||
stdinStream.Close()
|
||||
}
|
||||
|
||||
err = s.host.ExecInContainer(GetPodFullName(pod), uid, container, u.Query()[api.ExecCommandParamm], stdinStream, stdoutStream, stderrStream, tty)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("Error executing command in container: %v", err)
|
||||
glog.Error(msg)
|
||||
errorStream.Write([]byte(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func parsePodCoordinates(path string) (namespace, pod string, uid types.UID, err error) {
|
||||
parts := strings.Split(path, "/")
|
||||
|
||||
if len(parts) == 4 {
|
||||
namespace = parts[2]
|
||||
pod = parts[3]
|
||||
return
|
||||
}
|
||||
|
||||
if len(parts) == 5 {
|
||||
namespace = parts[2]
|
||||
pod = parts[3]
|
||||
uid = types.UID(parts[4])
|
||||
return
|
||||
}
|
||||
|
||||
err = fmt.Errorf("Unexpected path %s. Expected /.../.../<namespace>/<pod> or /.../.../<namespace>/<pod>/<uid>", path)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) handlePortForward(w http.ResponseWriter, req *http.Request) {
|
||||
u, err := url.ParseRequestURI(req.RequestURI)
|
||||
if err != nil {
|
||||
s.error(w, err)
|
||||
return
|
||||
}
|
||||
podNamespace, podID, uid, err := parsePodCoordinates(u.Path)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
pod, ok := s.host.GetPodByName(podNamespace, podID)
|
||||
if !ok {
|
||||
http.Error(w, "Pod does not exist", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
streamChan := make(chan httpstream.Stream, 1)
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error {
|
||||
portString := stream.Headers().Get(api.PortHeader)
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to parse '%s' as a port: %v", portString, err)
|
||||
}
|
||||
if port < 1 {
|
||||
return fmt.Errorf("Port '%d' must be greater than 0", port)
|
||||
}
|
||||
streamChan <- stream
|
||||
return nil
|
||||
})
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
|
||||
|
||||
var dataStreamLock sync.Mutex
|
||||
dataStreamChans := make(map[string]chan httpstream.Stream)
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-conn.CloseChan():
|
||||
break Loop
|
||||
case stream := <-streamChan:
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
port := stream.Headers().Get(api.PortHeader)
|
||||
dataStreamLock.Lock()
|
||||
switch streamType {
|
||||
case "error":
|
||||
ch := make(chan httpstream.Stream)
|
||||
dataStreamChans[port] = ch
|
||||
go waitForPortForwardDataStreamAndRun(GetPodFullName(pod), uid, stream, ch, s.host)
|
||||
case "data":
|
||||
ch, ok := dataStreamChans[port]
|
||||
if ok {
|
||||
ch <- stream
|
||||
delete(dataStreamChans, port)
|
||||
} else {
|
||||
glog.Errorf("Unable to locate data stream channel for port %s", port)
|
||||
}
|
||||
default:
|
||||
glog.Errorf("streamType header must be 'error' or 'data', got: '%s'", streamType)
|
||||
stream.Reset()
|
||||
}
|
||||
dataStreamLock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForPortForwardDataStreamAndRun(pod string, uid types.UID, errorStream httpstream.Stream, dataStreamChan chan httpstream.Stream, host HostInterface) {
|
||||
defer errorStream.Reset()
|
||||
|
||||
var dataStream httpstream.Stream
|
||||
|
||||
select {
|
||||
case dataStream = <-dataStreamChan:
|
||||
case <-time.After(1 * time.Second):
|
||||
errorStream.Write([]byte("Timed out waiting for data stream"))
|
||||
//TODO delete from dataStreamChans[port]
|
||||
return
|
||||
}
|
||||
|
||||
portString := dataStream.Headers().Get(api.PortHeader)
|
||||
port, _ := strconv.ParseUint(portString, 10, 16)
|
||||
err := host.PortForward(pod, uid, uint16(port), dataStream)
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("Error forwarding port %d to pod %s, uid %v: %v", port, pod, uid, err)
|
||||
glog.Error(msg)
|
||||
errorStream.Write([]byte(msg.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP responds to HTTP requests on the Kubelet.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
defer httplog.NewLogged(req, &w).StacktraceWhen(
|
||||
@@ -347,6 +587,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
http.StatusMovedPermanently,
|
||||
http.StatusTemporaryRedirect,
|
||||
http.StatusNotFound,
|
||||
http.StatusSwitchingProtocols,
|
||||
),
|
||||
).Log()
|
||||
s.mux.ServeHTTP(w, req)
|
||||
|
@@ -46,35 +46,36 @@ const defaultRootDir = "/var/lib/kubelet"
|
||||
// KubeletServer encapsulates all of the parameters necessary for starting up
|
||||
// a kubelet. These can either be set via command line or directly.
|
||||
type KubeletServer struct {
|
||||
Config string
|
||||
SyncFrequency time.Duration
|
||||
FileCheckFrequency time.Duration
|
||||
HTTPCheckFrequency time.Duration
|
||||
ManifestURL string
|
||||
EnableServer bool
|
||||
Address util.IP
|
||||
Port uint
|
||||
HostnameOverride string
|
||||
PodInfraContainerImage string
|
||||
DockerEndpoint string
|
||||
EtcdServerList util.StringList
|
||||
EtcdConfigFile string
|
||||
RootDirectory string
|
||||
AllowPrivileged bool
|
||||
RegistryPullQPS float64
|
||||
RegistryBurst int
|
||||
RunOnce bool
|
||||
EnableDebuggingHandlers bool
|
||||
MinimumGCAge time.Duration
|
||||
MaxContainerCount int
|
||||
AuthPath string
|
||||
CAdvisorPort uint
|
||||
OOMScoreAdj int
|
||||
APIServerList util.StringList
|
||||
ClusterDomain string
|
||||
MasterServiceNamespace string
|
||||
ClusterDNS util.IP
|
||||
ReallyCrashForTesting bool
|
||||
Config string
|
||||
SyncFrequency time.Duration
|
||||
FileCheckFrequency time.Duration
|
||||
HTTPCheckFrequency time.Duration
|
||||
ManifestURL string
|
||||
EnableServer bool
|
||||
Address util.IP
|
||||
Port uint
|
||||
HostnameOverride string
|
||||
PodInfraContainerImage string
|
||||
DockerEndpoint string
|
||||
EtcdServerList util.StringList
|
||||
EtcdConfigFile string
|
||||
RootDirectory string
|
||||
AllowPrivileged bool
|
||||
RegistryPullQPS float64
|
||||
RegistryBurst int
|
||||
RunOnce bool
|
||||
EnableDebuggingHandlers bool
|
||||
MinimumGCAge time.Duration
|
||||
MaxContainerCount int
|
||||
AuthPath string
|
||||
CAdvisorPort uint
|
||||
OOMScoreAdj int
|
||||
APIServerList util.StringList
|
||||
ClusterDomain string
|
||||
MasterServiceNamespace string
|
||||
ClusterDNS util.IP
|
||||
ReallyCrashForTesting bool
|
||||
StreamingConnectionIdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewKubeletServer will create a new KubeletServer with default values.
|
||||
@@ -149,6 +150,7 @@ func (s *KubeletServer) AddFlags(fs *pflag.FlagSet) {
|
||||
fs.StringVar(&s.MasterServiceNamespace, "master_service_namespace", s.MasterServiceNamespace, "The namespace from which the kubernetes master services should be injected into pods")
|
||||
fs.Var(&s.ClusterDNS, "cluster_dns", "IP address for a cluster DNS server. If set, kubelet will configure all containers to use this for DNS resolution in addition to the host's DNS servers")
|
||||
fs.BoolVar(&s.ReallyCrashForTesting, "really_crash_for_testing", s.ReallyCrashForTesting, "If true, crash with panics more often.")
|
||||
fs.DurationVar(&s.StreamingConnectionIdleTimeout, "streaming_connection_idle_timeout", 0, "Maximum time a streaming connection can be idle before the connection is automatically closed. Example: '5m'")
|
||||
}
|
||||
|
||||
// Run runs the specified KubeletServer. This should never exit.
|
||||
@@ -184,32 +186,33 @@ func (s *KubeletServer) Run(_ []string) error {
|
||||
credentialprovider.SetPreferredDockercfgPath(s.RootDirectory)
|
||||
|
||||
kcfg := KubeletConfig{
|
||||
Address: s.Address,
|
||||
AllowPrivileged: s.AllowPrivileged,
|
||||
HostnameOverride: s.HostnameOverride,
|
||||
RootDirectory: s.RootDirectory,
|
||||
ConfigFile: s.Config,
|
||||
ManifestURL: s.ManifestURL,
|
||||
FileCheckFrequency: s.FileCheckFrequency,
|
||||
HTTPCheckFrequency: s.HTTPCheckFrequency,
|
||||
PodInfraContainerImage: s.PodInfraContainerImage,
|
||||
SyncFrequency: s.SyncFrequency,
|
||||
RegistryPullQPS: s.RegistryPullQPS,
|
||||
RegistryBurst: s.RegistryBurst,
|
||||
MinimumGCAge: s.MinimumGCAge,
|
||||
MaxContainerCount: s.MaxContainerCount,
|
||||
ClusterDomain: s.ClusterDomain,
|
||||
ClusterDNS: s.ClusterDNS,
|
||||
Runonce: s.RunOnce,
|
||||
Port: s.Port,
|
||||
CAdvisorPort: s.CAdvisorPort,
|
||||
EnableServer: s.EnableServer,
|
||||
EnableDebuggingHandlers: s.EnableDebuggingHandlers,
|
||||
DockerClient: dockertools.ConnectToDockerOrDie(s.DockerEndpoint),
|
||||
KubeClient: client,
|
||||
EtcdClient: kubelet.EtcdClientOrDie(s.EtcdServerList, s.EtcdConfigFile),
|
||||
MasterServiceNamespace: s.MasterServiceNamespace,
|
||||
VolumePlugins: ProbeVolumePlugins(),
|
||||
Address: s.Address,
|
||||
AllowPrivileged: s.AllowPrivileged,
|
||||
HostnameOverride: s.HostnameOverride,
|
||||
RootDirectory: s.RootDirectory,
|
||||
ConfigFile: s.Config,
|
||||
ManifestURL: s.ManifestURL,
|
||||
FileCheckFrequency: s.FileCheckFrequency,
|
||||
HTTPCheckFrequency: s.HTTPCheckFrequency,
|
||||
PodInfraContainerImage: s.PodInfraContainerImage,
|
||||
SyncFrequency: s.SyncFrequency,
|
||||
RegistryPullQPS: s.RegistryPullQPS,
|
||||
RegistryBurst: s.RegistryBurst,
|
||||
MinimumGCAge: s.MinimumGCAge,
|
||||
MaxContainerCount: s.MaxContainerCount,
|
||||
ClusterDomain: s.ClusterDomain,
|
||||
ClusterDNS: s.ClusterDNS,
|
||||
Runonce: s.RunOnce,
|
||||
Port: s.Port,
|
||||
CAdvisorPort: s.CAdvisorPort,
|
||||
EnableServer: s.EnableServer,
|
||||
EnableDebuggingHandlers: s.EnableDebuggingHandlers,
|
||||
DockerClient: dockertools.ConnectToDockerOrDie(s.DockerEndpoint),
|
||||
KubeClient: client,
|
||||
EtcdClient: kubelet.EtcdClientOrDie(s.EtcdServerList, s.EtcdConfigFile),
|
||||
MasterServiceNamespace: s.MasterServiceNamespace,
|
||||
VolumePlugins: ProbeVolumePlugins(),
|
||||
StreamingConnectionIdleTimeout: s.StreamingConnectionIdleTimeout,
|
||||
}
|
||||
|
||||
RunKubelet(&kcfg)
|
||||
@@ -368,33 +371,34 @@ func makePodSourceConfig(kc *KubeletConfig) *config.PodConfig {
|
||||
// KubeletConfig is all of the parameters necessary for running a kubelet.
|
||||
// TODO: This should probably be merged with KubeletServer. The extra object is a consequence of refactoring.
|
||||
type KubeletConfig struct {
|
||||
EtcdClient tools.EtcdClient
|
||||
KubeClient *client.Client
|
||||
DockerClient dockertools.DockerInterface
|
||||
CAdvisorPort uint
|
||||
Address util.IP
|
||||
AllowPrivileged bool
|
||||
HostnameOverride string
|
||||
RootDirectory string
|
||||
ConfigFile string
|
||||
ManifestURL string
|
||||
FileCheckFrequency time.Duration
|
||||
HTTPCheckFrequency time.Duration
|
||||
Hostname string
|
||||
PodInfraContainerImage string
|
||||
SyncFrequency time.Duration
|
||||
RegistryPullQPS float64
|
||||
RegistryBurst int
|
||||
MinimumGCAge time.Duration
|
||||
MaxContainerCount int
|
||||
ClusterDomain string
|
||||
ClusterDNS util.IP
|
||||
EnableServer bool
|
||||
EnableDebuggingHandlers bool
|
||||
Port uint
|
||||
Runonce bool
|
||||
MasterServiceNamespace string
|
||||
VolumePlugins []volume.Plugin
|
||||
EtcdClient tools.EtcdClient
|
||||
KubeClient *client.Client
|
||||
DockerClient dockertools.DockerInterface
|
||||
CAdvisorPort uint
|
||||
Address util.IP
|
||||
AllowPrivileged bool
|
||||
HostnameOverride string
|
||||
RootDirectory string
|
||||
ConfigFile string
|
||||
ManifestURL string
|
||||
FileCheckFrequency time.Duration
|
||||
HTTPCheckFrequency time.Duration
|
||||
Hostname string
|
||||
PodInfraContainerImage string
|
||||
SyncFrequency time.Duration
|
||||
RegistryPullQPS float64
|
||||
RegistryBurst int
|
||||
MinimumGCAge time.Duration
|
||||
MaxContainerCount int
|
||||
ClusterDomain string
|
||||
ClusterDNS util.IP
|
||||
EnableServer bool
|
||||
EnableDebuggingHandlers bool
|
||||
Port uint
|
||||
Runonce bool
|
||||
MasterServiceNamespace string
|
||||
VolumePlugins []volume.Plugin
|
||||
StreamingConnectionIdleTimeout time.Duration
|
||||
}
|
||||
|
||||
func createAndInitKubelet(kc *KubeletConfig, pc *config.PodConfig) (*kubelet.Kubelet, error) {
|
||||
@@ -417,7 +421,8 @@ func createAndInitKubelet(kc *KubeletConfig, pc *config.PodConfig) (*kubelet.Kub
|
||||
kc.ClusterDomain,
|
||||
net.IP(kc.ClusterDNS),
|
||||
kc.MasterServiceNamespace,
|
||||
kc.VolumePlugins)
|
||||
kc.VolumePlugins,
|
||||
kc.StreamingConnectionIdleTimeout)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -25,25 +25,32 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/types"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
|
||||
"github.com/google/cadvisor/info"
|
||||
)
|
||||
|
||||
type fakeKubelet struct {
|
||||
podByNameFunc func(namespace, name string) (*api.BoundPod, bool)
|
||||
statusFunc func(name string) (api.PodStatus, error)
|
||||
containerInfoFunc func(podFullName string, uid types.UID, containerName string, req *info.ContainerInfoRequest) (*info.ContainerInfo, error)
|
||||
rootInfoFunc func(query *info.ContainerInfoRequest) (*info.ContainerInfo, error)
|
||||
machineInfoFunc func() (*info.MachineInfo, error)
|
||||
boundPodsFunc func() ([]api.BoundPod, error)
|
||||
logFunc func(w http.ResponseWriter, req *http.Request)
|
||||
runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error)
|
||||
dockerVersionFunc func() ([]uint, error)
|
||||
containerLogsFunc func(podFullName, containerName, tail string, follow bool, stdout, stderr io.Writer) error
|
||||
podByNameFunc func(namespace, name string) (*api.BoundPod, bool)
|
||||
statusFunc func(name string) (api.PodStatus, error)
|
||||
containerInfoFunc func(podFullName string, uid types.UID, containerName string, req *info.ContainerInfoRequest) (*info.ContainerInfo, error)
|
||||
rootInfoFunc func(query *info.ContainerInfoRequest) (*info.ContainerInfo, error)
|
||||
machineInfoFunc func() (*info.MachineInfo, error)
|
||||
boundPodsFunc func() ([]api.BoundPod, error)
|
||||
logFunc func(w http.ResponseWriter, req *http.Request)
|
||||
runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error)
|
||||
dockerVersionFunc func() ([]uint, error)
|
||||
execFunc func(pod string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
|
||||
portForwardFunc func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
|
||||
containerLogsFunc func(podFullName, containerName, tail string, follow bool, stdout, stderr io.Writer) error
|
||||
streamingConnectionIdleTimeoutFunc func() time.Duration
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) GetPodByName(namespace, name string) (*api.BoundPod, bool) {
|
||||
@@ -86,6 +93,18 @@ func (fk *fakeKubelet) RunInContainer(podFullName string, uid types.UID, contain
|
||||
return fk.runFunc(podFullName, uid, containerName, cmd)
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
|
||||
return fk.execFunc(name, uid, container, cmd, in, out, err, tty)
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||
return fk.portForwardFunc(name, uid, port, stream)
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) StreamingConnectionIdleTimeout() time.Duration {
|
||||
return fk.streamingConnectionIdleTimeoutFunc()
|
||||
}
|
||||
|
||||
type serverTestFramework struct {
|
||||
updateChan chan interface{}
|
||||
updateReader *channelReader
|
||||
@@ -542,3 +561,503 @@ func TestContainerLogsWithFollow(t *testing.T) {
|
||||
t.Errorf("Expected: '%v', got: '%v'", output, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeExecInContainerIdleTimeout(t *testing.T) {
|
||||
fw := newServerTest()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 100 * time.Millisecond
|
||||
}
|
||||
|
||||
idleSuccess := make(chan struct{})
|
||||
|
||||
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
|
||||
select {
|
||||
case <-idleSuccess:
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
t.Fatalf("execFunc timed out waiting for idle timeout")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedContainerName := "baz"
|
||||
|
||||
url := fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?c=ls&c=-a&" + api.ExecStdinParam + "=1"
|
||||
|
||||
upgradeRoundTripper := spdy.NewRoundTripper(nil)
|
||||
c := &http.Client{Transport: upgradeRoundTripper}
|
||||
|
||||
resp, err := c.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("Got error GETing: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
conn, err := upgradeRoundTripper.NewConnection(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating streaming connection: %s", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("Unexpected nil connection")
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("type", "input")
|
||||
stream, err := conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating input stream: %v", err)
|
||||
}
|
||||
defer stream.Reset()
|
||||
|
||||
select {
|
||||
case <-conn.CloseChan():
|
||||
close(idleSuccess)
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
t.Fatalf("Timed out waiting for connection closure due to idle timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeExecInContainer(t *testing.T) {
|
||||
tests := []struct {
|
||||
stdin bool
|
||||
stdout bool
|
||||
stderr bool
|
||||
tty bool
|
||||
responseStatusCode int
|
||||
uid bool
|
||||
}{
|
||||
{responseStatusCode: http.StatusBadRequest},
|
||||
{stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
fw := newServerTest()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedPodName := podName + "." + podNamespace + ".etcd"
|
||||
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
|
||||
expectedContainerName := "baz"
|
||||
expectedCommand := "ls -a"
|
||||
expectedStdin := "stdin"
|
||||
expectedStdout := "stdout"
|
||||
expectedStderr := "stderr"
|
||||
execFuncDone := make(chan struct{})
|
||||
clientStdoutReadDone := make(chan struct{})
|
||||
clientStderrReadDone := make(chan struct{})
|
||||
|
||||
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
|
||||
defer close(execFuncDone)
|
||||
if podFullName != expectedPodName {
|
||||
t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
|
||||
}
|
||||
if test.uid && string(uid) != expectedUid {
|
||||
t.Fatalf("%d: uid: expected %v, got %v", i, expectedUid, uid)
|
||||
}
|
||||
if containerName != expectedContainerName {
|
||||
t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
|
||||
}
|
||||
if strings.Join(cmd, " ") != expectedCommand {
|
||||
t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd)
|
||||
}
|
||||
|
||||
if test.stdin {
|
||||
if in == nil {
|
||||
t.Fatalf("%d: stdin: expected non-nil", i)
|
||||
}
|
||||
b := make([]byte, 10)
|
||||
n, err := in.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stdin: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStdin, string(b[0:n]); e != a {
|
||||
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
|
||||
}
|
||||
} else if in != nil {
|
||||
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
if out == nil {
|
||||
t.Fatalf("%d: stdout: expected non-nil", i)
|
||||
}
|
||||
_, err := out.Write([]byte(expectedStdout))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stdout: %v", i, err)
|
||||
}
|
||||
out.Close()
|
||||
select {
|
||||
case <-clientStdoutReadDone:
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
t.Fatalf("%d: timed out waiting for client to read stdout", i)
|
||||
}
|
||||
} else if out != nil {
|
||||
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
|
||||
}
|
||||
|
||||
if tty {
|
||||
if stderr != nil {
|
||||
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
|
||||
}
|
||||
} else if test.stderr {
|
||||
if stderr == nil {
|
||||
t.Fatalf("%d: stderr: expected non-nil", i)
|
||||
}
|
||||
_, err := stderr.Write([]byte(expectedStderr))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stderr: %v", i, err)
|
||||
}
|
||||
stderr.Close()
|
||||
select {
|
||||
case <-clientStderrReadDone:
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
t.Fatalf("%d: timed out waiting for client to read stderr", i)
|
||||
}
|
||||
} else if stderr != nil {
|
||||
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var url string
|
||||
if test.uid {
|
||||
url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?command=ls&command=-a"
|
||||
} else {
|
||||
url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?command=ls&command=-a"
|
||||
}
|
||||
if test.stdin {
|
||||
url += "&" + api.ExecStdinParam + "=1"
|
||||
}
|
||||
if test.stdout {
|
||||
url += "&" + api.ExecStdoutParam + "=1"
|
||||
}
|
||||
if test.stderr && !test.tty {
|
||||
url += "&" + api.ExecStderrParam + "=1"
|
||||
}
|
||||
if test.tty {
|
||||
url += "&" + api.ExecTTYParam + "=1"
|
||||
}
|
||||
|
||||
var (
|
||||
resp *http.Response
|
||||
err error
|
||||
upgradeRoundTripper httpstream.UpgradeRoundTripper
|
||||
c *http.Client
|
||||
)
|
||||
|
||||
if test.responseStatusCode != http.StatusSwitchingProtocols {
|
||||
c = &http.Client{}
|
||||
} else {
|
||||
upgradeRoundTripper = spdy.NewRoundTripper(nil)
|
||||
c = &http.Client{Transport: upgradeRoundTripper}
|
||||
}
|
||||
|
||||
resp, err = c.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: Got error GETing: %v", i, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("%d: Error reading response body: %v", i, err)
|
||||
}
|
||||
|
||||
if e, a := test.responseStatusCode, resp.StatusCode; e != a {
|
||||
t.Fatalf("%d: response status: expected %v, got %v", e, a)
|
||||
}
|
||||
|
||||
if test.responseStatusCode != http.StatusSwitchingProtocols {
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := upgradeRoundTripper.NewConnection(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating streaming connection: %s", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatalf("%d: unexpected nil conn", i)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h := http.Header{}
|
||||
h.Set(api.StreamType, api.StreamTypeError)
|
||||
errorStream, err := conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating error stream: %v", i, err)
|
||||
}
|
||||
defer errorStream.Reset()
|
||||
|
||||
if test.stdin {
|
||||
h.Set(api.StreamType, api.StreamTypeStdin)
|
||||
stream, err := conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stdin stream: %v", i, err)
|
||||
}
|
||||
defer stream.Reset()
|
||||
_, err = stream.Write([]byte(expectedStdin))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing to stdin stream: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
var stdoutStream httpstream.Stream
|
||||
if test.stdout {
|
||||
h.Set(api.StreamType, api.StreamTypeStdout)
|
||||
stdoutStream, err = conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stdout stream: %v", i, err)
|
||||
}
|
||||
defer stdoutStream.Reset()
|
||||
}
|
||||
|
||||
var stderrStream httpstream.Stream
|
||||
if test.stderr && !test.tty {
|
||||
h.Set(api.StreamType, api.StreamTypeStderr)
|
||||
stderrStream, err = conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stderr stream: %v", i, err)
|
||||
}
|
||||
defer stderrStream.Reset()
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
output := make([]byte, 10)
|
||||
n, err := stdoutStream.Read(output)
|
||||
close(clientStdoutReadDone)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stdout stream: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStdout, string(output[0:n]); e != a {
|
||||
t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if test.stderr && !test.tty {
|
||||
output := make([]byte, 10)
|
||||
n, err := stderrStream.Read(output)
|
||||
close(clientStderrReadDone)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stderr stream: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStderr, string(output[0:n]); e != a {
|
||||
t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-execFuncDone:
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
t.Fatalf("%d: timed out waiting for execFunc to complete", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServePortForwardIdleTimeout(t *testing.T) {
|
||||
fw := newServerTest()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 100 * time.Millisecond
|
||||
}
|
||||
|
||||
idleSuccess := make(chan struct{})
|
||||
|
||||
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||
select {
|
||||
case <-idleSuccess:
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
t.Fatalf("execFunc timed out waiting for idle timeout")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
|
||||
url := fw.testHTTPServer.URL + "/portForward/" + podNamespace + "/" + podName
|
||||
|
||||
upgradeRoundTripper := spdy.NewRoundTripper(nil)
|
||||
c := &http.Client{Transport: upgradeRoundTripper}
|
||||
|
||||
resp, err := c.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("Got error GETing: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
conn, err := upgradeRoundTripper.NewConnection(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating streaming connection: %s", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("Unexpected nil connection")
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
select {
|
||||
case <-conn.CloseChan():
|
||||
close(idleSuccess)
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
t.Fatalf("Timed out waiting for connection closure due to idle timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServePortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
port string
|
||||
uid bool
|
||||
clientData string
|
||||
containerData string
|
||||
shouldError bool
|
||||
}{
|
||||
{port: "", shouldError: true},
|
||||
{port: "abc", shouldError: true},
|
||||
{port: "-1", shouldError: true},
|
||||
{port: "65536", shouldError: true},
|
||||
{port: "0", shouldError: true},
|
||||
{port: "1", shouldError: false},
|
||||
{port: "8000", shouldError: false},
|
||||
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
|
||||
{port: "65535", shouldError: false},
|
||||
{port: "65535", uid: true, shouldError: false},
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedPodName := podName + "." + podNamespace + ".etcd"
|
||||
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
|
||||
|
||||
for i, test := range tests {
|
||||
fw := newServerTest()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
portForwardFuncDone := make(chan struct{})
|
||||
|
||||
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||
defer close(portForwardFuncDone)
|
||||
|
||||
if e, a := expectedPodName, name; e != a {
|
||||
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
if e, a := expectedUid, uid; test.uid && e != string(a) {
|
||||
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
p, err := strconv.ParseUint(test.port, 10, 16)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error parsing port string '%s': %v", i, port, err)
|
||||
}
|
||||
if e, a := uint16(p), port; e != a {
|
||||
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
if test.clientData != "" {
|
||||
fromClient := make([]byte, 32)
|
||||
n, err := stream.Read(fromClient)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading client data: %v", i, err)
|
||||
}
|
||||
if e, a := test.clientData, string(fromClient[0:n]); e != a {
|
||||
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if test.containerData != "" {
|
||||
_, err := stream.Write([]byte(test.containerData))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing container data: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var url string
|
||||
if test.uid {
|
||||
url = fmt.Sprintf("%s/portForward/%s/%s/%s", fw.testHTTPServer.URL, podNamespace, podName, expectedUid)
|
||||
} else {
|
||||
url = fmt.Sprintf("%s/portForward/%s/%s", fw.testHTTPServer.URL, podNamespace, podName)
|
||||
}
|
||||
|
||||
upgradeRoundTripper := spdy.NewRoundTripper(nil)
|
||||
c := &http.Client{Transport: upgradeRoundTripper}
|
||||
|
||||
resp, err := c.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: Got error GETing: %v", i, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
conn, err := upgradeRoundTripper.NewConnection(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating streaming connection: %s", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("%d: Unexpected nil connection", i)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("streamType", "error")
|
||||
headers.Set("port", test.port)
|
||||
errorStream, err := conn.CreateStream(headers)
|
||||
_ = errorStream
|
||||
haveErr := err != nil
|
||||
if e, a := test.shouldError, haveErr; e != a {
|
||||
t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err)
|
||||
}
|
||||
|
||||
if test.shouldError {
|
||||
continue
|
||||
}
|
||||
|
||||
headers.Set("streamType", "data")
|
||||
headers.Set("port", test.port)
|
||||
dataStream, err := conn.CreateStream(headers)
|
||||
haveErr = err != nil
|
||||
if e, a := test.shouldError, haveErr; e != a {
|
||||
t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err)
|
||||
}
|
||||
|
||||
if test.clientData != "" {
|
||||
_, err := dataStream.Write([]byte(test.clientData))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if test.containerData != "" {
|
||||
fromContainer := make([]byte, 32)
|
||||
n, err := dataStream.Read(fromContainer)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
|
||||
}
|
||||
if e, a := test.containerData, string(fromContainer[0:n]); e != a {
|
||||
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-portForwardFuncDone:
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("%d: timed out waiting for portForwardFuncDone", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
19
pkg/util/httpstream/doc.go
Normal file
19
pkg/util/httpstream/doc.go
Normal file
@@ -0,0 +1,19 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package httpstream adds multiplexed streaming support to HTTP requests and
|
||||
// responses via connection upgrades.
|
||||
package httpstream
|
80
pkg/util/httpstream/httpstream.go
Normal file
80
pkg/util/httpstream/httpstream.go
Normal file
@@ -0,0 +1,80 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package httpstream
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderConnection = "Connection"
|
||||
HeaderUpgrade = "Upgrade"
|
||||
)
|
||||
|
||||
// NewStreamHandler defines a function that is called when a new Stream is
|
||||
// received. If no error is returned, the Stream is accepted; otherwise,
|
||||
// the stream is rejected.
|
||||
type NewStreamHandler func(Stream) error
|
||||
|
||||
// NoOpNewStreamHandler is a stream handler that accepts a new stream and
|
||||
// performs no other logic.
|
||||
func NoOpNewStreamHandler(stream Stream) error { return nil }
|
||||
|
||||
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
|
||||
// HTTP requests to support multiplexed bidirectional streams. After RoundTrip()
|
||||
// is invoked, if the upgrade is successful, clients may retrieve the upgraded
|
||||
// connection by calling UpgradeRoundTripper.Connection().
|
||||
type UpgradeRoundTripper interface {
|
||||
http.RoundTripper
|
||||
// NewConnection validates the response and creates a new Connection.
|
||||
NewConnection(resp *http.Response) (Connection, error)
|
||||
}
|
||||
|
||||
// ResponseUpgrader knows how to upgrade HTTP requests and responses to
|
||||
// add streaming support to them.
|
||||
type ResponseUpgrader interface {
|
||||
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
|
||||
// streams. newStreamHandler will be called synchronously whenever the
|
||||
// other end of the upgraded connection creates a new stream.
|
||||
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
|
||||
}
|
||||
|
||||
// Connection represents an upgraded HTTP connection.
|
||||
type Connection interface {
|
||||
// CreateStream creates a new Stream with the supplied headers.
|
||||
CreateStream(headers http.Header) (Stream, error)
|
||||
// Close resets all streams and closes the connection.
|
||||
Close() error
|
||||
// CloseChan returns a channel that is closed when the underlying connection is closed.
|
||||
CloseChan() <-chan bool
|
||||
// SetIdleTimeout sets the amount of time the connection may remain idle before
|
||||
// it is automatically closed.
|
||||
SetIdleTimeout(timeout time.Duration)
|
||||
}
|
||||
|
||||
// Stream represents a bidirectional communications channel that is part of an
|
||||
// upgraded connection.
|
||||
type Stream interface {
|
||||
io.ReadWriteCloser
|
||||
// Reset closes both directions of the stream, indicating that neither client
|
||||
// or server can use it any more.
|
||||
Reset() error
|
||||
// Headers returns the headers used to create the stream.
|
||||
Headers() http.Header
|
||||
}
|
139
pkg/util/httpstream/spdy/connection.go
Normal file
139
pkg/util/httpstream/spdy/connection.go
Normal file
@@ -0,0 +1,139 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spdy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/docker/spdystream"
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
// connection maintains state about a spdystream.Connection and its associated
|
||||
// streams.
|
||||
type connection struct {
|
||||
conn *spdystream.Connection
|
||||
streams []httpstream.Stream
|
||||
streamLock sync.Mutex
|
||||
newStreamHandler httpstream.NewStreamHandler
|
||||
}
|
||||
|
||||
// NewClientConnection creates a new SPDY client connection.
|
||||
func NewClientConnection(conn net.Conn) (httpstream.Connection, error) {
|
||||
spdyConn, err := spdystream.NewConnection(conn, false)
|
||||
if err != nil {
|
||||
defer conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newConnection(spdyConn, httpstream.NoOpNewStreamHandler), nil
|
||||
}
|
||||
|
||||
// NewServerConnection creates a new SPDY server connection. newStreamHandler
|
||||
// will be invoked when the server receives a newly created stream from the
|
||||
// client.
|
||||
func NewServerConnection(conn net.Conn, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, error) {
|
||||
spdyConn, err := spdystream.NewConnection(conn, true)
|
||||
if err != nil {
|
||||
defer conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newConnection(spdyConn, newStreamHandler), nil
|
||||
}
|
||||
|
||||
// newConnection returns a new connection wrapping conn. newStreamHandler
|
||||
// will be invoked when the server receives a newly created stream from the
|
||||
// client.
|
||||
func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
|
||||
c := &connection{conn: conn, newStreamHandler: newStreamHandler}
|
||||
go conn.Serve(c.newSpdyStream)
|
||||
return c
|
||||
}
|
||||
|
||||
// createStreamResponseTimeout indicates how long to wait for the other side to
|
||||
// acknowledge the new stream before timing out.
|
||||
const createStreamResponseTimeout = 2 * time.Second
|
||||
|
||||
// Close first sends a reset for all of the connection's streams, and then
|
||||
// closes the underlying spdystream.Connection.
|
||||
func (c *connection) Close() error {
|
||||
c.streamLock.Lock()
|
||||
for _, s := range c.streams {
|
||||
s.Reset()
|
||||
}
|
||||
c.streams = make([]httpstream.Stream, 0)
|
||||
c.streamLock.Unlock()
|
||||
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// CreateStream creates a new stream with the specified headers and registers
|
||||
// it with the connection.
|
||||
func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
stream, err := c.conn.CreateStream(headers, nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = stream.WaitTimeout(createStreamResponseTimeout); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.registerStream(stream)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// registerStream adds the stream s to the connection's list of streams that
|
||||
// it owns.
|
||||
func (c *connection) registerStream(s httpstream.Stream) {
|
||||
c.streamLock.Lock()
|
||||
c.streams = append(c.streams, s)
|
||||
c.streamLock.Unlock()
|
||||
}
|
||||
|
||||
// CloseChan returns a channel that, when closed, indicates that the underlying
|
||||
// spdystream.Connection has been closed.
|
||||
func (c *connection) CloseChan() <-chan bool {
|
||||
return c.conn.CloseChan()
|
||||
}
|
||||
|
||||
// newSpdyStream is the internal new stream handler used by spdystream.Connection.Serve.
|
||||
// It calls connection's newStreamHandler, giving it the opportunity to accept or reject
|
||||
// the stream. If newStreamHandler returns an error, the stream is rejected. If not, the
|
||||
// stream is accepted and registered with the connection.
|
||||
func (c *connection) newSpdyStream(stream *spdystream.Stream) {
|
||||
err := c.newStreamHandler(stream)
|
||||
rejectStream := (err != nil)
|
||||
if rejectStream {
|
||||
glog.Warningf("Stream rejected: %v", err)
|
||||
stream.Reset()
|
||||
return
|
||||
}
|
||||
|
||||
c.registerStream(stream)
|
||||
stream.SendReply(http.Header{}, rejectStream)
|
||||
}
|
||||
|
||||
// SetIdleTimeout sets the amount of time the connection may remain idle before
|
||||
// it is automatically closed.
|
||||
func (c *connection) SetIdleTimeout(timeout time.Duration) {
|
||||
c.conn.SetIdleTimeout(timeout)
|
||||
}
|
130
pkg/util/httpstream/spdy/roundtripper.go
Normal file
130
pkg/util/httpstream/spdy/roundtripper.go
Normal file
@@ -0,0 +1,130 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spdy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
)
|
||||
|
||||
// SpdyRoundTripper knows how to upgrade an HTTP request to one that supports
|
||||
// multiplexed streams. After RoundTrip() is invoked, Conn will be set
|
||||
// and usable. SpdyRoundTripper implements the UpgradeRoundTripper interface.
|
||||
type SpdyRoundTripper struct {
|
||||
//tlsConfig holds the TLS configuration settings to use when connecting
|
||||
//to the remote server.
|
||||
tlsConfig *tls.Config
|
||||
|
||||
/* TODO according to http://golang.org/pkg/net/http/#RoundTripper, a RoundTripper
|
||||
must be safe for use by multiple concurrent goroutines. If this is absolutely
|
||||
necessary, we could keep a map from http.Request to net.Conn. In practice,
|
||||
a client will create an http.Client, set the transport to a new insteace of
|
||||
SpdyRoundTripper, and use it a single time, so this hopefully won't be an issue.
|
||||
*/
|
||||
// conn is the underlying network connection to the remote server.
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// NewSpdyRoundTripper creates a new SpdyRoundTripper that will use
|
||||
// the specified tlsConfig.
|
||||
func NewRoundTripper(tlsConfig *tls.Config) httpstream.UpgradeRoundTripper {
|
||||
return &SpdyRoundTripper{tlsConfig: tlsConfig}
|
||||
}
|
||||
|
||||
// dial dials the host specified by req, using TLS if appropriate.
|
||||
func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
|
||||
dialAddr := util.CanonicalAddr(req.URL)
|
||||
|
||||
if req.URL.Scheme == "http" {
|
||||
return net.Dial("tcp", dialAddr)
|
||||
}
|
||||
|
||||
// TODO validate the TLSClientConfig is set up?
|
||||
conn, err := tls.Dial("tcp", dialAddr, s.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(dialAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = conn.VerifyHostname(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// RoundTrip executes the Request and upgrades it. After a successful upgrade,
|
||||
// clients may call SpdyRoundTripper.Connection() to retrieve the upgraded
|
||||
// connection.
|
||||
func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// TODO what's the best way to clone the request?
|
||||
r := *req
|
||||
req = &r
|
||||
req.Header.Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
|
||||
req.Header.Add(httpstream.HeaderUpgrade, HeaderSpdy31)
|
||||
|
||||
conn, err := s.dial(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = req.Write(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.conn = conn
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// NewConnection validates the upgrade response, creating and returning a new
|
||||
// httpstream.Connection if there were no errors.
|
||||
func (s *SpdyRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) {
|
||||
connectionHeader := strings.ToLower(resp.Header.Get(httpstream.HeaderConnection))
|
||||
upgradeHeader := strings.ToLower(resp.Header.Get(httpstream.HeaderUpgrade))
|
||||
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
|
||||
responseError := ""
|
||||
responseErrorBytes, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
responseError = "Unable to read error from server response"
|
||||
} else {
|
||||
responseError = string(responseErrorBytes)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unable to upgrade connection: %s", responseError)
|
||||
}
|
||||
|
||||
return NewClientConnection(s.conn)
|
||||
}
|
226
pkg/util/httpstream/spdy/roundtripper_test.go
Normal file
226
pkg/util/httpstream/spdy/roundtripper_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spdy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
)
|
||||
|
||||
func TestRoundTripAndNewConnection(t *testing.T) {
|
||||
testCases := []struct {
|
||||
serverConnectionHeader string
|
||||
serverUpgradeHeader string
|
||||
useTLS bool
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
serverConnectionHeader: "",
|
||||
serverUpgradeHeader: "",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
serverConnectionHeader: "Upgrade",
|
||||
serverUpgradeHeader: "",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
serverConnectionHeader: "",
|
||||
serverUpgradeHeader: "SPDY/3.1",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
serverConnectionHeader: "Upgrade",
|
||||
serverUpgradeHeader: "SPDY/3.1",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
serverConnectionHeader: "Upgrade",
|
||||
serverUpgradeHeader: "SPDY/3.1",
|
||||
useTLS: true,
|
||||
shouldError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if testCase.shouldError {
|
||||
if e, a := httpstream.HeaderUpgrade, req.Header.Get(httpstream.HeaderConnection); e != a {
|
||||
t.Fatalf("%d: Expected connection=upgrade header, got '%s", i, a)
|
||||
}
|
||||
|
||||
w.Header().Set(httpstream.HeaderConnection, testCase.serverConnectionHeader)
|
||||
w.Header().Set(httpstream.HeaderUpgrade, testCase.serverUpgradeHeader)
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
streamCh := make(chan httpstream.Stream)
|
||||
|
||||
responseUpgrader := NewResponseUpgrader()
|
||||
spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error {
|
||||
streamCh <- s
|
||||
return nil
|
||||
})
|
||||
if spdyConn == nil {
|
||||
t.Fatalf("%d: unexpected nil spdyConn", i)
|
||||
}
|
||||
defer spdyConn.Close()
|
||||
|
||||
stream := <-streamCh
|
||||
io.Copy(stream, stream)
|
||||
}))
|
||||
|
||||
clientTLS := &tls.Config{}
|
||||
|
||||
if testCase.useTLS {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error generating keypair: %s", i, err)
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(1 * time.Hour)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Localhost Co"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
host := "127.0.0.1"
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
}
|
||||
template.DNSNames = append(template.DNSNames, host)
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating cert: %s", i, err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error parsing cert: %s", i, err)
|
||||
}
|
||||
|
||||
roots := x509.NewCertPool()
|
||||
roots.AddCert(cert)
|
||||
server.TLS = &tls.Config{
|
||||
RootCAs: roots,
|
||||
}
|
||||
clientTLS.RootCAs = roots
|
||||
|
||||
certBuf := bytes.Buffer{}
|
||||
err = pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error encoding cert: %s", i, err)
|
||||
}
|
||||
|
||||
keyBuf := bytes.Buffer{}
|
||||
err = pem.Encode(&keyBuf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error encoding key: %s", i, err)
|
||||
}
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certBuf.Bytes(), keyBuf.Bytes())
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error calling tls.X509KeyPair: %s", i, err)
|
||||
}
|
||||
server.TLS.Certificates = []tls.Certificate{tlsCert}
|
||||
clientTLS.Certificates = []tls.Certificate{tlsCert}
|
||||
server.StartTLS()
|
||||
} else {
|
||||
server.Start()
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: Error creating request: %s", i, err)
|
||||
}
|
||||
|
||||
spdyTransport := NewRoundTripper(clientTLS)
|
||||
client := &http.Client{Transport: spdyTransport}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error from client.Do: %s", i, err)
|
||||
}
|
||||
|
||||
conn, err := spdyTransport.NewConnection(resp)
|
||||
haveErr := err != nil
|
||||
if e, a := testCase.shouldError, haveErr; e != a {
|
||||
t.Fatalf("%d: shouldError=%t, got %t: %v", i, e, a, err)
|
||||
}
|
||||
if testCase.shouldError {
|
||||
continue
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("%d: expected http 101 switching protocols, got %d", i, resp.StatusCode)
|
||||
}
|
||||
|
||||
stream, err := conn.CreateStream(http.Header{})
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating client stream: %s", i, err)
|
||||
}
|
||||
|
||||
n, err := stream.Write([]byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing to stream: %s", i, err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Fatalf("%d: Expected to write 5 bytes, but actually wrote %d", i, n)
|
||||
}
|
||||
|
||||
b := make([]byte, 5)
|
||||
n, err = stream.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stream: %s", i, err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Fatalf("%d: Expected to read 5 bytes, but actually read %d", i, n)
|
||||
}
|
||||
if e, a := "hello", string(b[0:n]); e != a {
|
||||
t.Fatalf("%d: expected '%s', got '%s'", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
78
pkg/util/httpstream/spdy/upgrade.go
Normal file
78
pkg/util/httpstream/spdy/upgrade.go
Normal file
@@ -0,0 +1,78 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spdy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
const HeaderSpdy31 = "SPDY/3.1"
|
||||
|
||||
// responseUpgrader knows how to upgrade HTTP responses. It
|
||||
// implements the httpstream.ResponseUpgrader interface.
|
||||
type responseUpgrader struct {
|
||||
}
|
||||
|
||||
// NewResponseUpgrader returns a new httpstream.ResponseUpgrader that is
|
||||
// capable of upgrading HTTP responses using SPDY/3.1 via the
|
||||
// spdystream package.
|
||||
func NewResponseUpgrader() httpstream.ResponseUpgrader {
|
||||
return responseUpgrader{}
|
||||
}
|
||||
|
||||
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
|
||||
// streams. newStreamHandler will be called synchronously whenever the
|
||||
// other end of the upgraded connection creates a new stream.
|
||||
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
|
||||
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
|
||||
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
|
||||
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
|
||||
w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header)))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
w.Write([]byte("Unable to upgrade: unable to hijack response"))
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
|
||||
w.Header().Add(httpstream.HeaderUpgrade, HeaderSpdy31)
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
|
||||
conn, _, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
glog.Errorf("Unable to upgrade: error hijacking response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
spdyConn, err := NewServerConnection(conn, newStreamHandler)
|
||||
if err != nil {
|
||||
glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return spdyConn
|
||||
}
|
93
pkg/util/httpstream/spdy/upgrade_test.go
Normal file
93
pkg/util/httpstream/spdy/upgrade_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spdy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpgradeResponse(t *testing.T) {
|
||||
testCases := []struct {
|
||||
connectionHeader string
|
||||
upgradeHeader string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
connectionHeader: "",
|
||||
upgradeHeader: "",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
connectionHeader: "Upgrade",
|
||||
upgradeHeader: "",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
connectionHeader: "",
|
||||
upgradeHeader: "SPDY/3.1",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
connectionHeader: "Upgrade",
|
||||
upgradeHeader: "SPDY/3.1",
|
||||
shouldError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
upgrader := NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, nil)
|
||||
haveErr := conn == nil
|
||||
if e, a := testCase.shouldError, haveErr; e != a {
|
||||
t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)
|
||||
}
|
||||
if haveErr {
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatalf("%d: unexpected nil conn", i)
|
||||
}
|
||||
defer conn.Close()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating request: %s", i, err)
|
||||
}
|
||||
|
||||
req.Header.Set("Connection", testCase.connectionHeader)
|
||||
req.Header.Set("Upgrade", testCase.upgradeHeader)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected non-nil err from client.Do: %s", i, err)
|
||||
}
|
||||
|
||||
if testCase.shouldError {
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("%d: expected status 101 switching protocols, got %d", i, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
@@ -19,6 +19,7 @@ package util
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -61,3 +62,24 @@ func (ipnet *IPNet) Set(value string) error {
|
||||
func (*IPNet) Type() string {
|
||||
return "ipNet"
|
||||
}
|
||||
|
||||
// FROM: http://golang.org/src/net/http/client.go
|
||||
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
|
||||
// return true if the string includes a port.
|
||||
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
|
||||
|
||||
// FROM: http://golang.org/src/net/http/transport.go
|
||||
var portMap = map[string]string{
|
||||
"http": "80",
|
||||
"https": "443",
|
||||
}
|
||||
|
||||
// FROM: http://golang.org/src/net/http/transport.go
|
||||
// canonicalAddr returns url.Host but always with a ":port" suffix
|
||||
func CanonicalAddr(url *url.URL) string {
|
||||
addr := url.Host
|
||||
if !hasPort(addr) {
|
||||
return addr + ":" + portMap[url.Scheme]
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
Reference in New Issue
Block a user