Merge pull request #99839 from saschagrunert/portforward-stream-cleanup
Cleanup portforward streams after their usage
This commit is contained in:
commit
4fae6ae5d2
@ -163,6 +163,10 @@ func (h *httpStreamHandler) removeStreamPair(requestID string) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
if h.conn != nil {
|
||||
pair := h.streamPairs[requestID]
|
||||
h.conn.RemoveStreams(pair.dataStream, pair.errorStream)
|
||||
}
|
||||
delete(h.streamPairs, requestID)
|
||||
}
|
||||
|
||||
|
@ -92,11 +92,23 @@ func TestHTTPStreamReceived(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type fakeConn struct {
|
||||
removeStreamsCalled bool
|
||||
}
|
||||
|
||||
func (*fakeConn) CreateStream(headers http.Header) (httpstream.Stream, error) { return nil, nil }
|
||||
func (*fakeConn) Close() error { return nil }
|
||||
func (*fakeConn) CloseChan() <-chan bool { return nil }
|
||||
func (*fakeConn) SetIdleTimeout(timeout time.Duration) {}
|
||||
func (f *fakeConn) RemoveStreams(streams ...httpstream.Stream) { f.removeStreamsCalled = true }
|
||||
|
||||
func TestGetStreamPair(t *testing.T) {
|
||||
timeout := make(chan time.Time)
|
||||
|
||||
conn := &fakeConn{}
|
||||
h := &httpStreamHandler{
|
||||
streamPairs: make(map[string]*httpStreamPair),
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
// test adding a new entry
|
||||
@ -158,6 +170,11 @@ func TestGetStreamPair(t *testing.T) {
|
||||
// make sure monitorStreamPair completed
|
||||
<-monitorDone
|
||||
|
||||
if !conn.removeStreamsCalled {
|
||||
t.Fatalf("connection remove stream not called")
|
||||
}
|
||||
conn.removeStreamsCalled = false
|
||||
|
||||
// make sure the pair was removed
|
||||
if h.hasStreamPair("1") {
|
||||
t.Fatal("expected removal of pair after both data and error streams received")
|
||||
@ -171,6 +188,7 @@ func TestGetStreamPair(t *testing.T) {
|
||||
if p == nil {
|
||||
t.Fatal("expected p not to be nil")
|
||||
}
|
||||
|
||||
monitorDone = make(chan struct{})
|
||||
go func() {
|
||||
h.monitorStreamPair(p, timeout)
|
||||
@ -183,6 +201,9 @@ func TestGetStreamPair(t *testing.T) {
|
||||
if h.hasStreamPair("2") {
|
||||
t.Fatal("expected stream pair to be removed")
|
||||
}
|
||||
if !conn.removeStreamsCalled {
|
||||
t.Fatalf("connection remove stream not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
|
@ -78,6 +78,8 @@ type Connection interface {
|
||||
// SetIdleTimeout sets the amount of time the connection may remain idle before
|
||||
// it is automatically closed.
|
||||
SetIdleTimeout(timeout time.Duration)
|
||||
// RemoveStreams can be used to remove a set of streams from the Connection.
|
||||
RemoveStreams(streams ...Stream)
|
||||
}
|
||||
|
||||
// Stream represents a bidirectional communications channel that is part of an
|
||||
|
@ -31,7 +31,7 @@ import (
|
||||
// streams.
|
||||
type connection struct {
|
||||
conn *spdystream.Connection
|
||||
streams []httpstream.Stream
|
||||
streams map[uint32]httpstream.Stream
|
||||
streamLock sync.Mutex
|
||||
newStreamHandler httpstream.NewStreamHandler
|
||||
ping func() (time.Duration, error)
|
||||
@ -85,7 +85,12 @@ func NewServerConnectionWithPings(conn net.Conn, newStreamHandler httpstream.New
|
||||
// will be invoked when the server receives a newly created stream from the
|
||||
// client.
|
||||
func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler, pingPeriod time.Duration, pingFn func() (time.Duration, error)) httpstream.Connection {
|
||||
c := &connection{conn: conn, newStreamHandler: newStreamHandler, ping: pingFn}
|
||||
c := &connection{
|
||||
conn: conn,
|
||||
newStreamHandler: newStreamHandler,
|
||||
ping: pingFn,
|
||||
streams: make(map[uint32]httpstream.Stream),
|
||||
}
|
||||
go conn.Serve(c.newSpdyStream)
|
||||
if pingPeriod > 0 && pingFn != nil {
|
||||
go c.sendPings(pingPeriod)
|
||||
@ -105,7 +110,7 @@ func (c *connection) Close() error {
|
||||
// calling Reset instead of Close ensures that all streams are fully torn down
|
||||
s.Reset()
|
||||
}
|
||||
c.streams = make([]httpstream.Stream, 0)
|
||||
c.streams = make(map[uint32]httpstream.Stream, 0)
|
||||
c.streamLock.Unlock()
|
||||
|
||||
// now that all streams are fully torn down, it's safe to call close on the underlying connection,
|
||||
@ -114,6 +119,15 @@ func (c *connection) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// RemoveStreams can be used to removes a set of streams from the Connection.
|
||||
func (c *connection) RemoveStreams(streams ...httpstream.Stream) {
|
||||
c.streamLock.Lock()
|
||||
for _, stream := range streams {
|
||||
delete(c.streams, stream.Identifier())
|
||||
}
|
||||
c.streamLock.Unlock()
|
||||
}
|
||||
|
||||
// CreateStream creates a new stream with the specified headers and registers
|
||||
// it with the connection.
|
||||
func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
@ -133,7 +147,7 @@ func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error
|
||||
// it owns.
|
||||
func (c *connection) registerStream(s httpstream.Stream) {
|
||||
c.streamLock.Lock()
|
||||
c.streams = append(c.streams, s)
|
||||
c.streams[s.Identifier()] = s
|
||||
c.streamLock.Unlock()
|
||||
}
|
||||
|
||||
|
@ -290,3 +290,41 @@ func TestConnectionPings(t *testing.T) {
|
||||
t.Errorf("timed out waiting for server to exit")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeStream struct{ id uint32 }
|
||||
|
||||
func (*fakeStream) Read(p []byte) (int, error) { return 0, nil }
|
||||
func (*fakeStream) Write(p []byte) (int, error) { return 0, nil }
|
||||
func (*fakeStream) Close() error { return nil }
|
||||
func (*fakeStream) Reset() error { return nil }
|
||||
func (*fakeStream) Headers() http.Header { return nil }
|
||||
func (f *fakeStream) Identifier() uint32 { return f.id }
|
||||
|
||||
func TestConnectionRemoveStreams(t *testing.T) {
|
||||
c := &connection{streams: make(map[uint32]httpstream.Stream)}
|
||||
stream0 := &fakeStream{id: 0}
|
||||
stream1 := &fakeStream{id: 1}
|
||||
stream2 := &fakeStream{id: 2}
|
||||
|
||||
c.registerStream(stream0)
|
||||
c.registerStream(stream1)
|
||||
|
||||
if len(c.streams) != 2 {
|
||||
t.Fatalf("should have two streams, has %d", len(c.streams))
|
||||
}
|
||||
|
||||
// not exists
|
||||
c.RemoveStreams(stream2)
|
||||
|
||||
if len(c.streams) != 2 {
|
||||
t.Fatalf("should have two streams, has %d", len(c.streams))
|
||||
}
|
||||
|
||||
// remove all existing
|
||||
c.RemoveStreams(stream0, stream1)
|
||||
|
||||
if len(c.streams) != 0 {
|
||||
t.Fatalf("should not have any streams, has %d", len(c.streams))
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -69,6 +69,9 @@ func (c *fakeConnection) CloseChan() <-chan bool {
|
||||
return c.closeChan
|
||||
}
|
||||
|
||||
func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) {
|
||||
}
|
||||
|
||||
func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
// no-op
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user