Refactor CRI packages

Signed-off-by: Maksym Pavlenko <pavlenko.maksym@gmail.com>
This commit is contained in:
Maksym Pavlenko
2020-10-07 14:30:19 -07:00
parent 944e9b70e2
commit 3508ddd3dd
34 changed files with 22 additions and 22 deletions

View File

@@ -0,0 +1,40 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package portforward contains server-side logic for handling port forwarding requests.
package portforward
// ProtocolV1Name is the name of the subprotocol used for port forwarding.
const ProtocolV1Name = "portforward.k8s.io"
// SupportedProtocols are the supported port forwarding protocols.
var SupportedProtocols = []string{ProtocolV1Name}

View File

@@ -0,0 +1,315 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"errors"
"fmt"
"net/http"
"strconv"
"sync"
"time"
api "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/klog/v2"
)
func handleHTTPStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
// negotiated protocol isn't currently used server side, but could be in the future
if err != nil {
// Handshake writes the error to the client
return err
}
streamChan := make(chan httpstream.Stream, 1)
klog.V(5).Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan))
if conn == nil {
return errors.New("unable to upgrade httpstream connection")
}
defer conn.Close()
klog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)
h := &httpStreamHandler{
conn: conn,
streamChan: streamChan,
streamPairs: make(map[string]*httpStreamPair),
streamCreationTimeout: streamCreationTimeout,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
return nil
}
// httpStreamReceived is the httpstream.NewStreamHandler for port
// forward streams. It checks each stream's port and stream type headers,
// rejecting any streams that with missing or invalid values. Each valid
// stream is sent to the streams channel.
func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
// make sure it has a valid port header
portString := stream.Headers().Get(api.PortHeader)
if len(portString) == 0 {
return fmt.Errorf("%q header is required", api.PortHeader)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("port %q must be > 0", portString)
}
// make sure it has a valid stream type header
streamType := stream.Headers().Get(api.StreamType)
if len(streamType) == 0 {
return fmt.Errorf("%q header is required", api.StreamType)
}
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
return fmt.Errorf("invalid stream type %q", streamType)
}
streams <- stream
return nil
}
}
// httpStreamHandler is capable of processing multiple port forward
// requests over a single httpstream.Connection.
type httpStreamHandler struct {
conn httpstream.Connection
streamChan chan httpstream.Stream
streamPairsLock sync.RWMutex
streamPairs map[string]*httpStreamPair
streamCreationTimeout time.Duration
pod string
uid types.UID
forwarder PortForwarder
}
// getStreamPair returns a httpStreamPair for requestID. This creates a
// new pair if one does not yet exist for the requestID. The returned bool is
// true if the pair was created.
func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
if p, ok := h.streamPairs[requestID]; ok {
klog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
return p, false
}
klog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
p := newPortForwardPair(requestID)
h.streamPairs[requestID] = p
return p, true
}
// monitorStreamPair waits for the pair to receive both its error and data
// streams, or for the timeout to expire (whichever happens first), and then
// removes the pair.
func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) {
select {
case <-timeout:
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
utilruntime.HandleError(err)
p.printError(err.Error())
case <-p.complete:
klog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
}
h.removeStreamPair(p.requestID)
}
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
func (h *httpStreamHandler) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
delete(h.streamPairs, requestID)
}
// requestID returns the request id for stream.
func (h *httpStreamHandler) requestID(stream httpstream.Stream) string {
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
if len(requestID) == 0 {
klog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
// If we get here, it's because the connection came from an older client
// that isn't generating the request id header
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
//
// This is a best-effort attempt at supporting older clients.
//
// When there aren't concurrent new forwarded connections, each connection
// will have a pair of streams (data, error), and the stream IDs will be
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
// the stream ID into a pseudo-request id by taking the stream type and
// using id = stream.Identifier() when the stream type is error,
// and id = stream.Identifier() - 2 when it's data.
//
// NOTE: this only works when there are not concurrent new streams from
// multiple forwarded connections; it's a best-effort attempt at supporting
// old clients that don't generate request ids. If there are concurrent
// new connections, it's possible that 1 connection gets streams whose IDs
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
requestID = strconv.Itoa(int(stream.Identifier()))
case api.StreamTypeData:
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
}
klog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
}
return requestID
}
// run is the main loop for the httpStreamHandler. It processes new
// streams, invoking portForward for each complete stream pair. The loop exits
// when the httpstream.Connection is closed.
func (h *httpStreamHandler) run() {
klog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
Loop:
for {
select {
case <-h.conn.CloseChan():
klog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
break Loop
case stream := <-h.streamChan:
requestID := h.requestID(stream)
streamType := stream.Headers().Get(api.StreamType)
klog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
p, created := h.getStreamPair(requestID)
if created {
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
}
if complete, err := p.add(stream); err != nil {
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
utilruntime.HandleError(errors.New(msg))
p.printError(msg)
} else if complete {
go h.portForward(p)
}
}
}
}
// portForward invokes the httpStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *httpStreamHandler) portForward(p *httpStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
portString := p.dataStream.Headers().Get(api.PortHeader)
port, _ := strconv.ParseInt(portString, 10, 32)
klog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
err := h.forwarder.PortForward(h.pod, h.uid, int32(port), p.dataStream)
klog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
utilruntime.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}
// httpStreamPair represents the error and data streams for a port
// forwarding request.
type httpStreamPair struct {
lock sync.RWMutex
requestID string
dataStream httpstream.Stream
errorStream httpstream.Stream
complete chan struct{}
}
// newPortForwardPair creates a new httpStreamPair.
func newPortForwardPair(requestID string) *httpStreamPair {
return &httpStreamPair{
requestID: requestID,
complete: make(chan struct{}),
}
}
// add adds the stream to the httpStreamPair. If the pair already
// contains a stream for the new stream's type, an error is returned. add
// returns true if both the data and error streams for this pair have been
// received.
func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) {
p.lock.Lock()
defer p.lock.Unlock()
switch stream.Headers().Get(api.StreamType) {
case api.StreamTypeError:
if p.errorStream != nil {
return false, errors.New("error stream already assigned")
}
p.errorStream = stream
case api.StreamTypeData:
if p.dataStream != nil {
return false, errors.New("data stream already assigned")
}
p.dataStream = stream
}
complete := p.errorStream != nil && p.dataStream != nil
if complete {
close(p.complete)
}
return complete, nil
}
// printError writes s to p.errorStream if p.errorStream has been set.
func (p *httpStreamPair) printError(s string) {
p.lock.RLock()
defer p.lock.RUnlock()
if p.errorStream != nil {
fmt.Fprint(p.errorStream, s)
}
}

View File

@@ -0,0 +1,69 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"io"
"net/http"
"time"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apiserver/pkg/util/wsstream"
)
// PortForwarder knows how to forward content from a data stream to/from a port
// in a pod.
type PortForwarder interface {
// PortForwarder copies data between a data stream and a port in a pod.
PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error
}
// ServePortForward handles a port forwarding request. A single request is
// kept alive as long as the client is still alive and the connection has not
// been timed out due to idleness. This function handles multiple forwarded
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
// handled by a single invocation of ServePortForward.
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, portForwardOptions *V4Options, idleTimeout time.Duration, streamCreationTimeout time.Duration, supportedProtocols []string) {
var err error
if wsstream.IsWebSocketRequest(req) {
err = handleWebSocketStreams(req, w, portForwarder, podName, uid, portForwardOptions, supportedProtocols, idleTimeout, streamCreationTimeout)
} else {
err = handleHTTPStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout)
}
if err != nil {
runtime.HandleError(err)
return
}
}

View File

@@ -0,0 +1,213 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"encoding/binary"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"k8s.io/klog/v2"
api "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apiserver/pkg/server/httplog"
"k8s.io/apiserver/pkg/util/wsstream"
)
const (
dataChannel = iota
errorChannel
v4BinaryWebsocketProtocol = "v4." + wsstream.ChannelWebSocketProtocol
v4Base64WebsocketProtocol = "v4." + wsstream.Base64ChannelWebSocketProtocol
)
// V4Options contains details about which streams are required for port
// forwarding.
// All fields included in V4Options need to be expressed explicitly in the
// CRI (k8s.io/cri-api/pkg/apis/{version}/api.proto) PortForwardRequest.
type V4Options struct {
Ports []int32
}
// NewV4Options creates a new options from the Request.
func NewV4Options(req *http.Request) (*V4Options, error) {
if !wsstream.IsWebSocketRequest(req) {
return &V4Options{}, nil
}
portStrings := req.URL.Query()[api.PortHeader]
if len(portStrings) == 0 {
return nil, fmt.Errorf("query parameter %q is required", api.PortHeader)
}
ports := make([]int32, 0, len(portStrings))
for _, portString := range portStrings {
if len(portString) == 0 {
return nil, fmt.Errorf("query parameter %q cannot be empty", api.PortHeader)
}
for _, p := range strings.Split(portString, ",") {
port, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return nil, fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return nil, fmt.Errorf("port %q must be > 0", portString)
}
ports = append(ports, int32(port))
}
}
return &V4Options{
Ports: ports,
}, nil
}
// BuildV4Options returns a V4Options based on the given information.
func BuildV4Options(ports []int32) (*V4Options, error) {
return &V4Options{Ports: ports}, nil
}
// handleWebSocketStreams handles requests to forward ports to a pod via
// a PortForwarder. A pair of streams are created per port (DATA n,
// ERROR n+1). The associated port is written to each stream as a unsigned 16
// bit integer in little endian format.
func handleWebSocketStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, opts *V4Options, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
channels := make([]wsstream.ChannelType, 0, len(opts.Ports)*2)
for i := 0; i < len(opts.Ports); i++ {
channels = append(channels, wsstream.ReadWriteChannel, wsstream.WriteChannel)
}
conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
"": {
Binary: true,
Channels: channels,
},
v4BinaryWebsocketProtocol: {
Binary: true,
Channels: channels,
},
v4Base64WebsocketProtocol: {
Binary: false,
Channels: channels,
},
})
conn.SetIdleTimeout(idleTimeout)
_, streams, err := conn.Open(httplog.Unlogged(req, w), req)
if err != nil {
err = fmt.Errorf("unable to upgrade websocket connection: %v", err)
return err
}
defer conn.Close()
streamPairs := make([]*websocketStreamPair, len(opts.Ports))
for i := range streamPairs {
streamPair := websocketStreamPair{
port: opts.Ports[i],
dataStream: streams[i*2+dataChannel],
errorStream: streams[i*2+errorChannel],
}
streamPairs[i] = &streamPair
portBytes := make([]byte, 2)
// port is always positive so conversion is allowable
binary.LittleEndian.PutUint16(portBytes, uint16(streamPair.port))
streamPair.dataStream.Write(portBytes)
streamPair.errorStream.Write(portBytes)
}
h := &websocketStreamHandler{
conn: conn,
streamPairs: streamPairs,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
return nil
}
// websocketStreamPair represents the error and data streams for a port
// forwarding request.
type websocketStreamPair struct {
port int32
dataStream io.ReadWriteCloser
errorStream io.WriteCloser
}
// websocketStreamHandler is capable of processing a single port forward
// request over a websocket connection
type websocketStreamHandler struct {
conn *wsstream.Conn
streamPairs []*websocketStreamPair
pod string
uid types.UID
forwarder PortForwarder
}
// run invokes the websocketStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *websocketStreamHandler) run() {
wg := sync.WaitGroup{}
wg.Add(len(h.streamPairs))
for _, pair := range h.streamPairs {
p := pair
go func() {
defer wg.Done()
h.portForward(p)
}()
}
wg.Wait()
}
func (h *websocketStreamHandler) portForward(p *websocketStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
klog.V(5).Infof("(conn=%p) invoking forwarder.PortForward for port %d", h.conn, p.port)
err := h.forwarder.PortForward(h.pod, h.uid, p.port, p.dataStream)
klog.V(5).Infof("(conn=%p) done invoking forwarder.PortForward for port %d", h.conn, p.port)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err)
runtime.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}