sandbox: support vsock connection to task api
Signed-off-by: Abel Feng <fshb1988@gmail.com>
This commit is contained in:
@@ -19,17 +19,24 @@
|
||||
package shim
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/containerd/log"
|
||||
"github.com/mdlayher/vsock"
|
||||
|
||||
"github.com/containerd/containerd/v2/defaults"
|
||||
"github.com/containerd/containerd/v2/pkg/namespaces"
|
||||
"github.com/containerd/containerd/v2/pkg/sys"
|
||||
@@ -38,6 +45,9 @@ import (
|
||||
const (
|
||||
shimBinaryFormat = "containerd-shim-%s-%s"
|
||||
socketPathLimit = 106
|
||||
protoVsock = "vsock"
|
||||
protoHybridVsock = "hvsock"
|
||||
protoUnix = "unix"
|
||||
)
|
||||
|
||||
func getSysProcAttr() *syscall.SysProcAttr {
|
||||
@@ -76,7 +86,21 @@ func SocketAddress(ctx context.Context, socketPath, id string) (string, error) {
|
||||
|
||||
// AnonDialer returns a dialer for a socket
|
||||
func AnonDialer(address string, timeout time.Duration) (net.Conn, error) {
|
||||
return net.DialTimeout("unix", socket(address).path(), timeout)
|
||||
proto, addr, ok := strings.Cut(address, "://")
|
||||
if !ok {
|
||||
return net.DialTimeout("unix", socket(address).path(), timeout)
|
||||
}
|
||||
switch proto {
|
||||
case protoVsock:
|
||||
// vsock dialer can not set timeout
|
||||
return dialVsock(addr)
|
||||
case protoHybridVsock:
|
||||
return dialHybridVsock(addr, timeout)
|
||||
case protoUnix:
|
||||
return net.DialTimeout("unix", socket(address).path(), timeout)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
}
|
||||
|
||||
// AnonReconnectDialer returns a dialer for an existing socket on reconnection
|
||||
@@ -177,3 +201,88 @@ func CanConnect(address string) bool {
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func hybridVsockDialer(addr string, port uint64, timeout time.Duration) (net.Conn, error) {
|
||||
timeoutCh := time.After(timeout)
|
||||
// Do 10 retries before timeout
|
||||
retryInterval := timeout / 10
|
||||
for {
|
||||
conn, err := net.DialTimeout("unix", addr, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err = conn.Write([]byte(fmt.Sprintf("CONNECT %d\n", port))); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
reader := bufio.NewReader(conn)
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
if strings.Contains(response, "OK") {
|
||||
errChan <- nil
|
||||
} else {
|
||||
errChan <- fmt.Errorf("hybrid vsock handshake response error: %s", response)
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case err = <-errChan:
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
// When it is EOF, maybe the server side is not ready.
|
||||
if err == io.EOF {
|
||||
log.G(context.Background()).Warnf("Read hybrid vsock got EOF, server may not ready")
|
||||
time.Sleep(retryInterval)
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
case <-timeoutCh:
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("timeout waiting for hybrid vsocket handshake of %s:%d", addr, port)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func dialVsock(address string) (net.Conn, error) {
|
||||
contextIDString, portString, ok := strings.Cut(address, ":")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid vsock address %s", address)
|
||||
}
|
||||
contextID, err := strconv.ParseUint(contextIDString, 10, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse vsock context id %s, %v", contextIDString, err)
|
||||
}
|
||||
if contextID > math.MaxUint32 {
|
||||
return nil, fmt.Errorf("vsock context id %d is invalid", contextID)
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse vsock port %s, %v", portString, err)
|
||||
}
|
||||
if port > math.MaxUint32 {
|
||||
return nil, fmt.Errorf("vsock port %d is invalid", port)
|
||||
}
|
||||
return vsock.Dial(uint32(contextID), uint32(port), &vsock.Config{})
|
||||
}
|
||||
|
||||
func dialHybridVsock(address string, timeout time.Duration) (net.Conn, error) {
|
||||
addr, portString, ok := strings.Cut(address, ":")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid hybrid vsock address %s", address)
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse hybrid vsock port %s, %v", portString, err)
|
||||
}
|
||||
if port > math.MaxUint32 {
|
||||
return nil, fmt.Errorf("hybrid vsock port %d is invalid", port)
|
||||
}
|
||||
return hybridVsockDialer(addr, port, timeout)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user