Explicitly flush headers when proxying
This commit is contained in:
		| @@ -230,7 +230,34 @@ func (h *UpgradeAwareHandler) ServeHTTP(w http.ResponseWriter, req *http.Request | ||||
| 	proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: h.Location.Scheme, Host: h.Location.Host}) | ||||
| 	proxy.Transport = h.Transport | ||||
| 	proxy.FlushInterval = h.FlushInterval | ||||
| 	proxy.ServeHTTP(w, newReq) | ||||
| 	proxy.ServeHTTP(maybeWrapFlushHeadersWriter(w), newReq) | ||||
| } | ||||
|  | ||||
| // maybeWrapFlushHeadersWriter wraps the given writer to force flushing headers prior to writing the response body. | ||||
| // if the given writer does not support http.Flusher, http.Hijacker, and http.CloseNotifier, the original writer is returned. | ||||
| // TODO(liggitt): drop this once https://github.com/golang/go/issues/31125 is fixed | ||||
| func maybeWrapFlushHeadersWriter(w http.ResponseWriter) http.ResponseWriter { | ||||
| 	flusher, isFlusher := w.(http.Flusher) | ||||
| 	hijacker, isHijacker := w.(http.Hijacker) | ||||
| 	closeNotifier, isCloseNotifier := w.(http.CloseNotifier) | ||||
| 	// flusher, hijacker, and closeNotifier are all used by the ReverseProxy implementation. | ||||
| 	// if the given writer can't support all three, return the original writer. | ||||
| 	if !isFlusher || !isHijacker || !isCloseNotifier { | ||||
| 		return w | ||||
| 	} | ||||
| 	return &flushHeadersWriter{w, flusher, hijacker, closeNotifier} | ||||
| } | ||||
|  | ||||
| type flushHeadersWriter struct { | ||||
| 	http.ResponseWriter | ||||
| 	http.Flusher | ||||
| 	http.Hijacker | ||||
| 	http.CloseNotifier | ||||
| } | ||||
|  | ||||
| func (w *flushHeadersWriter) WriteHeader(code int) { | ||||
| 	w.ResponseWriter.WriteHeader(code) | ||||
| 	w.Flusher.Flush() | ||||
| } | ||||
|  | ||||
| // tryUpgrade returns true if the request was handled. | ||||
|   | ||||
| @@ -865,6 +865,46 @@ func TestProxyRequestContentLengthAndTransferEncoding(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestFlushIntervalHeaders(t *testing.T) { | ||||
| 	const expected = "hi" | ||||
| 	stopCh := make(chan struct{}) | ||||
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.Header().Add("MyHeader", expected) | ||||
| 		w.WriteHeader(200) | ||||
| 		w.(http.Flusher).Flush() | ||||
| 		<-stopCh | ||||
| 	})) | ||||
| 	defer backend.Close() | ||||
| 	defer close(stopCh) | ||||
|  | ||||
| 	backendURL, err := url.Parse(backend.URL) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	proxyHandler := NewUpgradeAwareHandler(backendURL, nil, false, false, nil) | ||||
|  | ||||
| 	frontend := httptest.NewServer(proxyHandler) | ||||
| 	defer frontend.Close() | ||||
|  | ||||
| 	req, _ := http.NewRequest("GET", frontend.URL, nil) | ||||
| 	req.Close = true | ||||
|  | ||||
| 	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) | ||||
| 	defer cancel() | ||||
| 	req = req.WithContext(ctx) | ||||
|  | ||||
| 	res, err := frontend.Client().Do(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Get: %v", err) | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	if res.Header.Get("MyHeader") != expected { | ||||
| 		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // exampleCert was generated from crypto/tls/generate_cert.go with the following command: | ||||
| //    go run generate_cert.go  --rsa-bits 512 --host example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h | ||||
| var exampleCert = []byte(`-----BEGIN CERTIFICATE----- | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jordan Liggitt
					Jordan Liggitt