Use Dial with context

This commit is contained in:
Mikhail Mazurskiy
2018-05-19 08:14:37 +10:00
parent 77a08ee2d7
commit 5e8e570dbd
25 changed files with 111 additions and 110 deletions

View File

@@ -18,6 +18,7 @@ package ssh
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@@ -121,10 +122,11 @@ func (s *SSHTunnel) Open() error {
return err
}
func (s *SSHTunnel) Dial(network, address string) (net.Conn, error) {
func (s *SSHTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if s.client == nil {
return nil, errors.New("tunnel is not opened.")
}
// This Dial method does not allow to pass a context unfortunately
return s.client.Dial(network, address)
}
@@ -294,7 +296,7 @@ func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
type tunnel interface {
Open() error
Close() error
Dial(network, address string) (net.Conn, error)
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
type sshTunnelEntry struct {
@@ -361,7 +363,7 @@ func (l *SSHTunnelList) delayedHealthCheck(e sshTunnelEntry, delay time.Duration
func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error {
// GET the healthcheck path using the provided tunnel's dial function.
transport := utilnet.SetTransportDefaults(&http.Transport{
Dial: e.Tunnel.Dial,
DialContext: e.Tunnel.Dial,
// TODO(cjcullen): Plumb real TLS options through.
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
// We don't reuse the clients, so disable the keep-alive to properly
@@ -394,7 +396,7 @@ func (l *SSHTunnelList) removeAndReAdd(e sshTunnelEntry) {
go l.createAndAddTunnel(e.Address)
}
func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) {
func (l *SSHTunnelList) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
start := time.Now()
id := mathrand.Int63() // So you can match begins/ends in the log.
glog.Infof("[%x: %v] Dialing...", id, addr)
@@ -405,7 +407,7 @@ func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) {
if err != nil {
return nil, err
}
return tunnel.Dial(net, addr)
return tunnel.Dial(ctx, net, addr)
}
func (l *SSHTunnelList) pickTunnel(addr string) (tunnel, error) {

View File

@@ -17,6 +17,7 @@ limitations under the License.
package ssh
import (
"context"
"fmt"
"io"
"net"
@@ -145,7 +146,7 @@ func TestSSHTunnel(t *testing.T) {
t.FailNow()
}
_, err = tunnel.Dial("tcp", "127.0.0.1:8080")
_, err = tunnel.Dial(context.Background(), "tcp", "127.0.0.1:8080")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
@@ -176,7 +177,7 @@ func (*fakeTunnel) Close() error {
return nil
}
func (*fakeTunnel) Dial(network, address string) (net.Conn, error) {
func (*fakeTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil
}