Explicitly flush headers when proxying
This commit is contained in:
		@@ -230,7 +230,34 @@ func (h *UpgradeAwareHandler) ServeHTTP(w http.ResponseWriter, req *http.Request
 | 
				
			|||||||
	proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: h.Location.Scheme, Host: h.Location.Host})
 | 
						proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: h.Location.Scheme, Host: h.Location.Host})
 | 
				
			||||||
	proxy.Transport = h.Transport
 | 
						proxy.Transport = h.Transport
 | 
				
			||||||
	proxy.FlushInterval = h.FlushInterval
 | 
						proxy.FlushInterval = h.FlushInterval
 | 
				
			||||||
	proxy.ServeHTTP(w, newReq)
 | 
						proxy.ServeHTTP(maybeWrapFlushHeadersWriter(w), newReq)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// maybeWrapFlushHeadersWriter wraps the given writer to force flushing headers prior to writing the response body.
 | 
				
			||||||
 | 
					// if the given writer does not support http.Flusher, http.Hijacker, and http.CloseNotifier, the original writer is returned.
 | 
				
			||||||
 | 
					// TODO(liggitt): drop this once https://github.com/golang/go/issues/31125 is fixed
 | 
				
			||||||
 | 
					func maybeWrapFlushHeadersWriter(w http.ResponseWriter) http.ResponseWriter {
 | 
				
			||||||
 | 
						flusher, isFlusher := w.(http.Flusher)
 | 
				
			||||||
 | 
						hijacker, isHijacker := w.(http.Hijacker)
 | 
				
			||||||
 | 
						closeNotifier, isCloseNotifier := w.(http.CloseNotifier)
 | 
				
			||||||
 | 
						// flusher, hijacker, and closeNotifier are all used by the ReverseProxy implementation.
 | 
				
			||||||
 | 
						// if the given writer can't support all three, return the original writer.
 | 
				
			||||||
 | 
						if !isFlusher || !isHijacker || !isCloseNotifier {
 | 
				
			||||||
 | 
							return w
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &flushHeadersWriter{w, flusher, hijacker, closeNotifier}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type flushHeadersWriter struct {
 | 
				
			||||||
 | 
						http.ResponseWriter
 | 
				
			||||||
 | 
						http.Flusher
 | 
				
			||||||
 | 
						http.Hijacker
 | 
				
			||||||
 | 
						http.CloseNotifier
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (w *flushHeadersWriter) WriteHeader(code int) {
 | 
				
			||||||
 | 
						w.ResponseWriter.WriteHeader(code)
 | 
				
			||||||
 | 
						w.Flusher.Flush()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// tryUpgrade returns true if the request was handled.
 | 
					// tryUpgrade returns true if the request was handled.
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -865,6 +865,46 @@ func TestProxyRequestContentLengthAndTransferEncoding(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFlushIntervalHeaders(t *testing.T) {
 | 
				
			||||||
 | 
						const expected = "hi"
 | 
				
			||||||
 | 
						stopCh := make(chan struct{})
 | 
				
			||||||
 | 
						backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							w.Header().Add("MyHeader", expected)
 | 
				
			||||||
 | 
							w.WriteHeader(200)
 | 
				
			||||||
 | 
							w.(http.Flusher).Flush()
 | 
				
			||||||
 | 
							<-stopCh
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer backend.Close()
 | 
				
			||||||
 | 
						defer close(stopCh)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						backendURL, err := url.Parse(backend.URL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						proxyHandler := NewUpgradeAwareHandler(backendURL, nil, false, false, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						frontend := httptest.NewServer(proxyHandler)
 | 
				
			||||||
 | 
						defer frontend.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req, _ := http.NewRequest("GET", frontend.URL, nil)
 | 
				
			||||||
 | 
						req.Close = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
 | 
				
			||||||
 | 
						defer cancel()
 | 
				
			||||||
 | 
						req = req.WithContext(ctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						res, err := frontend.Client().Do(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Get: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer res.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if res.Header.Get("MyHeader") != expected {
 | 
				
			||||||
 | 
							t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// exampleCert was generated from crypto/tls/generate_cert.go with the following command:
 | 
					// exampleCert was generated from crypto/tls/generate_cert.go with the following command:
 | 
				
			||||||
//    go run generate_cert.go  --rsa-bits 512 --host example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
 | 
					//    go run generate_cert.go  --rsa-bits 512 --host example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
 | 
				
			||||||
var exampleCert = []byte(`-----BEGIN CERTIFICATE-----
 | 
					var exampleCert = []byte(`-----BEGIN CERTIFICATE-----
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user