diff --git a/cli/cli.go b/cli/cli.go index 103b1c229..e43180cad 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -49,16 +49,16 @@ var loadCommand = cli.Command{ timeout = context.GlobalDuration("timeout") cancel gocontext.CancelFunc ) - cl, err := client.NewCRIPluginClient(address, timeout) - if err != nil { - return errors.Wrap(err, "failed to create grpc client") - } if timeout > 0 { ctx, cancel = gocontext.WithTimeout(gocontext.Background(), timeout) } else { ctx, cancel = gocontext.WithCancel(ctx) } defer cancel() + cl, err := client.NewCRIPluginClient(ctx, address) + if err != nil { + return errors.Wrap(err, "failed to create grpc client") + } for _, path := range context.Args() { absPath, err := filepath.Abs(path) if err != nil { diff --git a/integration/test_utils.go b/integration/test_utils.go index 8ba08f6c7..ebcca5c9f 100644 --- a/integration/test_utils.go +++ b/integration/test_utils.go @@ -17,6 +17,7 @@ limitations under the License. package integration import ( + "context" "flag" "fmt" "os/exec" @@ -87,7 +88,9 @@ func ConnectDaemons() error { if err != nil { return errors.Wrap(err, "failed to connect containerd") } - criPluginClient, err = client.NewCRIPluginClient(*criEndpoint, timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + criPluginClient, err = client.NewCRIPluginClient(ctx, *criEndpoint) if err != nil { return errors.Wrap(err, "failed to connect cri plugin") } diff --git a/pkg/client/client.go b/pkg/client/client.go index c856f5cb3..83d605ff9 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -18,7 +18,6 @@ package client import ( "context" - "time" "github.com/pkg/errors" "google.golang.org/grpc" @@ -29,13 +28,11 @@ import ( // NewCRIPluginClient creates grpc client of cri plugin // TODO(random-liu): Wrap grpc functions. -func NewCRIPluginClient(endpoint string, timeout time.Duration) (api.CRIPluginServiceClient, error) { +func NewCRIPluginClient(ctx context.Context, endpoint string) (api.CRIPluginServiceClient, error) { addr, dialer, err := util.GetAddressAndDialer(endpoint) if err != nil { return nil, errors.Wrap(err, "failed to get dialer") } - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() conn, err := grpc.DialContext(ctx, addr, grpc.WithBlock(), grpc.WithInsecure(),