Keep streams from being set up after closeAllStreamReaders is called
This commit is contained in:
		@@ -187,6 +187,9 @@ type wsStreamCreator struct {
 | 
				
			|||||||
	// map of stream id to stream; multiple streams read/write the connection
 | 
						// map of stream id to stream; multiple streams read/write the connection
 | 
				
			||||||
	streams   map[byte]*stream
 | 
						streams   map[byte]*stream
 | 
				
			||||||
	streamsMu sync.Mutex
 | 
						streamsMu sync.Mutex
 | 
				
			||||||
 | 
						// setStreamErr holds the error to return to anyone calling setStreams.
 | 
				
			||||||
 | 
						// this is populated in closeAllStreamReaders
 | 
				
			||||||
 | 
						setStreamErr error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator {
 | 
					func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator {
 | 
				
			||||||
@@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream {
 | 
				
			|||||||
	return c.streams[id]
 | 
						return c.streams[id]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *wsStreamCreator) setStream(id byte, s *stream) {
 | 
					func (c *wsStreamCreator) setStream(id byte, s *stream) error {
 | 
				
			||||||
	c.streamsMu.Lock()
 | 
						c.streamsMu.Lock()
 | 
				
			||||||
	defer c.streamsMu.Unlock()
 | 
						defer c.streamsMu.Unlock()
 | 
				
			||||||
 | 
						if c.setStreamErr != nil {
 | 
				
			||||||
 | 
							return c.setStreamErr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	c.streams[id] = s
 | 
						c.streams[id] = s
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CreateStream uses id from passed headers to create a stream over "c.conn" connection.
 | 
					// CreateStream uses id from passed headers to create a stream over "c.conn" connection.
 | 
				
			||||||
@@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream,
 | 
				
			|||||||
		connWriteLock: &c.connWriteLock,
 | 
							connWriteLock: &c.connWriteLock,
 | 
				
			||||||
		id:            id,
 | 
							id:            id,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.setStream(id, s)
 | 
						if err := c.setStream(id, s); err != nil {
 | 
				
			||||||
 | 
							_ = s.writePipe.Close()
 | 
				
			||||||
 | 
							_ = s.readPipe.Close()
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return s, nil
 | 
						return s, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -312,7 +323,7 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// closeAllStreamReaders closes readers in all streams.
 | 
					// closeAllStreamReaders closes readers in all streams.
 | 
				
			||||||
// This unblocks all stream.Read() calls.
 | 
					// This unblocks all stream.Read() calls, and keeps any future streams from being created.
 | 
				
			||||||
func (c *wsStreamCreator) closeAllStreamReaders(err error) {
 | 
					func (c *wsStreamCreator) closeAllStreamReaders(err error) {
 | 
				
			||||||
	c.streamsMu.Lock()
 | 
						c.streamsMu.Lock()
 | 
				
			||||||
	defer c.streamsMu.Unlock()
 | 
						defer c.streamsMu.Unlock()
 | 
				
			||||||
@@ -320,6 +331,12 @@ func (c *wsStreamCreator) closeAllStreamReaders(err error) {
 | 
				
			|||||||
		// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes.
 | 
							// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes.
 | 
				
			||||||
		_ = s.writePipe.CloseWithError(err)
 | 
							_ = s.writePipe.CloseWithError(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						// ensure callers to setStreams receive an error after this point
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.setStreamErr = err
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							c.setStreamErr = fmt.Errorf("closed all streams")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type stream struct {
 | 
					type stream struct {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) {
 | 
				
			|||||||
	wg.Wait()
 | 
						wg.Wait()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestLateStreamCreation(t *testing.T) {
 | 
				
			||||||
 | 
						c := newWSStreamCreator(nil)
 | 
				
			||||||
 | 
						c.closeAllStreamReaders(nil)
 | 
				
			||||||
 | 
						if err := c.setStream(0, nil); err == nil {
 | 
				
			||||||
 | 
							t.Fatal("expected error adding stream after closeAllStreamReaders")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
 | 
					func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
 | 
				
			||||||
	// Validate Stream functions.
 | 
						// Validate Stream functions.
 | 
				
			||||||
	c := newWSStreamCreator(nil)
 | 
						c := newWSStreamCreator(nil)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user