Keep streams from being set up after closeAllStreamReaders is called
This commit is contained in:
		| @@ -187,6 +187,9 @@ type wsStreamCreator struct { | |||||||
| 	// map of stream id to stream; multiple streams read/write the connection | 	// map of stream id to stream; multiple streams read/write the connection | ||||||
| 	streams   map[byte]*stream | 	streams   map[byte]*stream | ||||||
| 	streamsMu sync.Mutex | 	streamsMu sync.Mutex | ||||||
|  | 	// setStreamErr holds the error to return to anyone calling setStreams. | ||||||
|  | 	// this is populated in closeAllStreamReaders | ||||||
|  | 	setStreamErr error | ||||||
| } | } | ||||||
|  |  | ||||||
| func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { | func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { | ||||||
| @@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream { | |||||||
| 	return c.streams[id] | 	return c.streams[id] | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *wsStreamCreator) setStream(id byte, s *stream) { | func (c *wsStreamCreator) setStream(id byte, s *stream) error { | ||||||
| 	c.streamsMu.Lock() | 	c.streamsMu.Lock() | ||||||
| 	defer c.streamsMu.Unlock() | 	defer c.streamsMu.Unlock() | ||||||
|  | 	if c.setStreamErr != nil { | ||||||
|  | 		return c.setStreamErr | ||||||
|  | 	} | ||||||
| 	c.streams[id] = s | 	c.streams[id] = s | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // CreateStream uses id from passed headers to create a stream over "c.conn" connection. | // CreateStream uses id from passed headers to create a stream over "c.conn" connection. | ||||||
| @@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, | |||||||
| 		connWriteLock: &c.connWriteLock, | 		connWriteLock: &c.connWriteLock, | ||||||
| 		id:            id, | 		id:            id, | ||||||
| 	} | 	} | ||||||
| 	c.setStream(id, s) | 	if err := c.setStream(id, s); err != nil { | ||||||
|  | 		_ = s.writePipe.Close() | ||||||
|  | 		_ = s.readPipe.Close() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 	return s, nil | 	return s, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -312,7 +323,7 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de | |||||||
| } | } | ||||||
|  |  | ||||||
| // closeAllStreamReaders closes readers in all streams. | // closeAllStreamReaders closes readers in all streams. | ||||||
| // This unblocks all stream.Read() calls. | // This unblocks all stream.Read() calls, and keeps any future streams from being created. | ||||||
| func (c *wsStreamCreator) closeAllStreamReaders(err error) { | func (c *wsStreamCreator) closeAllStreamReaders(err error) { | ||||||
| 	c.streamsMu.Lock() | 	c.streamsMu.Lock() | ||||||
| 	defer c.streamsMu.Unlock() | 	defer c.streamsMu.Unlock() | ||||||
| @@ -320,6 +331,12 @@ func (c *wsStreamCreator) closeAllStreamReaders(err error) { | |||||||
| 		// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. | 		// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. | ||||||
| 		_ = s.writePipe.CloseWithError(err) | 		_ = s.writePipe.CloseWithError(err) | ||||||
| 	} | 	} | ||||||
|  | 	// ensure callers to setStreams receive an error after this point | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.setStreamErr = err | ||||||
|  | 	} else { | ||||||
|  | 		c.setStreamErr = fmt.Errorf("closed all streams") | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| type stream struct { | type stream struct { | ||||||
|   | |||||||
| @@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) { | |||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestLateStreamCreation(t *testing.T) { | ||||||
|  | 	c := newWSStreamCreator(nil) | ||||||
|  | 	c.closeAllStreamReaders(nil) | ||||||
|  | 	if err := c.setStream(0, nil); err == nil { | ||||||
|  | 		t.Fatal("expected error adding stream after closeAllStreamReaders") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { | func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { | ||||||
| 	// Validate Stream functions. | 	// Validate Stream functions. | ||||||
| 	c := newWSStreamCreator(nil) | 	c := newWSStreamCreator(nil) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jordan Liggitt
					Jordan Liggitt