diff --git a/integration/shim_dial_unix_test.go b/integration/shim_dial_unix_test.go new file mode 100644 index 000000000..861c2a8bf --- /dev/null +++ b/integration/shim_dial_unix_test.go @@ -0,0 +1,177 @@ +//go:build !windows +// +build !windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package integration + +import ( + "context" + "io/ioutil" + "net" + "os" + "path/filepath" + "strings" + "syscall" + "testing" + "time" + + v1shimcli "github.com/containerd/containerd/runtime/v1/shim/client" + v2shimcli "github.com/containerd/containerd/runtime/v2/shim" + "github.com/containerd/ttrpc" + "github.com/pkg/errors" +) + +const abstractSocketPrefix = "\x00" + +// TestFailFastWhenConnectShim is to test that the containerd task manager +// should not tolerate ENOENT during restarting. In linux, the containerd shim +// always listens on socket before task manager dial. If there is ENOENT or +// ECONNREFUSED error, the task manager should clean up because that socket file +// is gone or shim doesn't listen on the socket anymore. +func TestFailFastWhenConnectShim(t *testing.T) { + t.Parallel() + + t.Run("abstract-unix-socket-v1", testFailFastWhenConnectShim(true, v1shimcli.AnonDialer)) + t.Run("abstract-unix-socket-v2", testFailFastWhenConnectShim(true, v2shimcli.AnonDialer)) + t.Run("normal-unix-socket-v1", testFailFastWhenConnectShim(false, v1shimcli.AnonDialer)) + t.Run("normal-unix-socket-v2", testFailFastWhenConnectShim(false, v2shimcli.AnonDialer)) +} + +type dialFunc func(address string, timeout time.Duration) (net.Conn, error) + +func testFailFastWhenConnectShim(abstract bool, dialFn dialFunc) func(*testing.T) { + return func(t *testing.T) { + var ( + ctx = context.Background() + addr, listener, cleanup = newTestListener(t, abstract) + errCh = make(chan error, 1) + + checkDialErr = func(addr string, errCh chan error, expected error) { + go func() { + _, err := dialFn(addr, 1*time.Hour) + errCh <- err + }() + + select { + case <-time.After(10 * time.Second): + t.Fatalf("expected fail fast, but got timeout") + case err := <-errCh: + t.Helper() + if !errors.Is(err, expected) { + t.Fatalf("expected error %v, but got %v", expected, err) + } + } + } + ) + defer cleanup() + defer listener.Close() + + ttrpcSrv, err := ttrpc.NewServer() + if err != nil { + t.Fatalf("failed to new ttrpc server: %v", err) + } + go func() { + ttrpcSrv.Serve(ctx, listener) + }() + + // ttrpcSrv starts in other goroutine so that we need to retry AnonDialer + // here until ttrpcSrv receives the request. + go func() { + to := time.After(10 * time.Second) + + for { + select { + case <-to: + errCh <- errors.New("timeout") + return + default: + } + + conn, err := dialFn(addr, 1*time.Hour) + if err != nil { + if errors.Is(err, syscall.ECONNREFUSED) { + time.Sleep(10 * time.Millisecond) + continue + } + errCh <- err + return + } + + conn.Close() + errCh <- nil + return + } + }() + + // it should be successful + if err := <-errCh; err != nil { + t.Fatalf("failed to dial: %v", err) + } + + // NOTE(fuweid): + // + // UnixListener will unlink that the socket file when call Close. + // Disable unlink when close to keep the socket file. + listener.(*net.UnixListener).SetUnlinkOnClose(false) + + listener.Close() + ttrpcSrv.Shutdown(ctx) + + checkDialErr(addr, errCh, syscall.ECONNREFUSED) + + // remove the socket file + cleanup() + + if abstract { + checkDialErr(addr, errCh, syscall.ECONNREFUSED) + } else { + // should not wait for the socket file show up again. + checkDialErr(addr, errCh, syscall.ENOENT) + } + } +} + +func newTestListener(t testing.TB, abstract bool) (string, net.Listener, func()) { + tmpDir, err := ioutil.TempDir("", "shim-ut-XX") + if err != nil { + t.Fatalf("failed to create tmp directory: %v", err) + } + + // NOTE(fuweid): + // + // Before patch https://github.com/containerd/containerd/commit/bd908acabd1a31c8329570b5283e8fdca0b39906, + // The shim stores the abstract socket file without abstract socket + // prefix and `unix://`. For the existing shim, if the socket file + // only contains the path, it will indicate that it is abstract socket. + // Otherwise, it will be normal socket file formated in `unix:///xyz'. + addr := filepath.Join(tmpDir, "uds.socket") + if abstract { + addr = abstractSocketPrefix + addr + } else { + addr = "unix://" + addr + } + + listener, err := net.Listen("unix", strings.TrimPrefix(addr, "unix://")) + if err != nil { + t.Fatalf("failed to listen on %s: %v", addr, err) + } + + return strings.TrimPrefix(addr, abstractSocketPrefix), listener, func() { + os.RemoveAll(tmpDir) + } +} diff --git a/runtime/v1/shim/client/client.go b/runtime/v1/shim/client/client.go index 869f1f941..ac38c33fe 100644 --- a/runtime/v1/shim/client/client.go +++ b/runtime/v1/shim/client/client.go @@ -34,7 +34,6 @@ import ( "github.com/containerd/containerd/events" "github.com/containerd/containerd/log" - "github.com/containerd/containerd/pkg/dialer" v1 "github.com/containerd/containerd/runtime/v1" "github.com/containerd/containerd/runtime/v1/shim" shimapi "github.com/containerd/containerd/runtime/v1/shim/v1" @@ -298,12 +297,19 @@ func RemoveSocket(address string) error { return nil } +// AnonDialer returns a dialer for a socket +// +// NOTE: It is only used for testing. +func AnonDialer(address string, timeout time.Duration) (net.Conn, error) { + return anonDialer(address, timeout) +} + func connect(address string, d func(string, time.Duration) (net.Conn, error)) (net.Conn, error) { return d(address, 100*time.Second) } func anonDialer(address string, timeout time.Duration) (net.Conn, error) { - return dialer.Dialer(socket(address).path(), timeout) + return net.DialTimeout("unix", socket(address).path(), timeout) } // WithConnect connects to an existing shim diff --git a/runtime/v2/shim/util_unix.go b/runtime/v2/shim/util_unix.go index 128896dff..cc33d300d 100644 --- a/runtime/v2/shim/util_unix.go +++ b/runtime/v2/shim/util_unix.go @@ -32,7 +32,6 @@ import ( "github.com/containerd/containerd/defaults" "github.com/containerd/containerd/namespaces" - "github.com/containerd/containerd/pkg/dialer" "github.com/containerd/containerd/sys" "github.com/pkg/errors" ) @@ -78,9 +77,7 @@ func SocketAddress(ctx context.Context, socketPath, id string) (string, error) { // AnonDialer returns a dialer for a socket func AnonDialer(address string, timeout time.Duration) (net.Conn, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return dialer.ContextDialer(ctx, socket(address).path()) + return net.DialTimeout("unix", socket(address).path(), timeout) } // AnonReconnectDialer returns a dialer for an existing socket on reconnection