Keep streams from being set up after closeAllStreamReaders is called
This commit is contained in:
		| @@ -187,6 +187,9 @@ type wsStreamCreator struct { | ||||
| 	// map of stream id to stream; multiple streams read/write the connection | ||||
| 	streams   map[byte]*stream | ||||
| 	streamsMu sync.Mutex | ||||
| 	// setStreamErr holds the error to return to anyone calling setStreams. | ||||
| 	// this is populated in closeAllStreamReaders | ||||
| 	setStreamErr error | ||||
| } | ||||
|  | ||||
| func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { | ||||
| @@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream { | ||||
| 	return c.streams[id] | ||||
| } | ||||
|  | ||||
| func (c *wsStreamCreator) setStream(id byte, s *stream) { | ||||
| func (c *wsStreamCreator) setStream(id byte, s *stream) error { | ||||
| 	c.streamsMu.Lock() | ||||
| 	defer c.streamsMu.Unlock() | ||||
| 	if c.setStreamErr != nil { | ||||
| 		return c.setStreamErr | ||||
| 	} | ||||
| 	c.streams[id] = s | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // CreateStream uses id from passed headers to create a stream over "c.conn" connection. | ||||
| @@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, | ||||
| 		connWriteLock: &c.connWriteLock, | ||||
| 		id:            id, | ||||
| 	} | ||||
| 	c.setStream(id, s) | ||||
| 	if err := c.setStream(id, s); err != nil { | ||||
| 		_ = s.writePipe.Close() | ||||
| 		_ = s.readPipe.Close() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return s, nil | ||||
| } | ||||
|  | ||||
| @@ -312,7 +323,7 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de | ||||
| } | ||||
|  | ||||
| // closeAllStreamReaders closes readers in all streams. | ||||
| // This unblocks all stream.Read() calls. | ||||
| // This unblocks all stream.Read() calls, and keeps any future streams from being created. | ||||
| func (c *wsStreamCreator) closeAllStreamReaders(err error) { | ||||
| 	c.streamsMu.Lock() | ||||
| 	defer c.streamsMu.Unlock() | ||||
| @@ -320,6 +331,12 @@ func (c *wsStreamCreator) closeAllStreamReaders(err error) { | ||||
| 		// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. | ||||
| 		_ = s.writePipe.CloseWithError(err) | ||||
| 	} | ||||
| 	// ensure callers to setStreams receive an error after this point | ||||
| 	if err != nil { | ||||
| 		c.setStreamErr = err | ||||
| 	} else { | ||||
| 		c.setStreamErr = fmt.Errorf("closed all streams") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type stream struct { | ||||
|   | ||||
| @@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) { | ||||
| 	wg.Wait() | ||||
| } | ||||
|  | ||||
| func TestLateStreamCreation(t *testing.T) { | ||||
| 	c := newWSStreamCreator(nil) | ||||
| 	c.closeAllStreamReaders(nil) | ||||
| 	if err := c.setStream(0, nil); err == nil { | ||||
| 		t.Fatal("expected error adding stream after closeAllStreamReaders") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { | ||||
| 	// Validate Stream functions. | ||||
| 	c := newWSStreamCreator(nil) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jordan Liggitt
					Jordan Liggitt