Merge pull request #9159 from csrwng/remove_cors_headers

Remove CORS headers from pod proxy responses
This commit is contained in:
Quinton Hoole
2015-06-05 11:40:09 -07:00
2 changed files with 82 additions and 14 deletions

View File

@@ -214,9 +214,37 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun
suffix += "/" suffix += "/"
} }
pathPrepend := strings.TrimSuffix(url.Path, suffix) pathPrepend := strings.TrimSuffix(url.Path, suffix)
return &proxy.Transport{ internalTransport := &proxy.Transport{
Scheme: scheme, Scheme: scheme,
Host: host, Host: host,
PathPrepend: pathPrepend, PathPrepend: pathPrepend,
} }
return &corsRemovingTransport{
RoundTripper: internalTransport,
}
}
// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers
// from the internal response.
type corsRemovingTransport struct {
http.RoundTripper
}
func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := p.RoundTripper.RoundTrip(req)
if err != nil {
return nil, err
}
removeCORSHeaders(resp)
return resp, nil
}
// removeCORSHeaders strip CORS headers sent from the backend
// This should be called on all responses before returning
func removeCORSHeaders(resp *http.Response) {
resp.Header.Del("Access-Control-Allow-Credentials")
resp.Header.Del("Access-Control-Allow-Headers")
resp.Header.Del("Access-Control-Allow-Methods")
resp.Header.Del("Access-Control-Allow-Origin")
} }

View File

@@ -51,8 +51,10 @@ func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques
return return
} }
for k, v := range s.responseHeader { if s.responseHeader != nil {
w.Header().Add(k, v) for k, v := range s.responseHeader {
w.Header().Add(k, v)
}
} }
w.Write([]byte(s.responseBody)) w.Write([]byte(s.responseBody))
} }
@@ -71,7 +73,7 @@ func validateParameters(t *testing.T, name string, actual url.Values, expected m
} }
} }
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string) { func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) {
for k, v := range expected { for k, v := range expected {
actualValue, ok := actual[k] actualValue, ok := actual[k]
if !ok { if !ok {
@@ -83,17 +85,28 @@ func validateHeaders(t *testing.T, name string, actual http.Header, expected map
name, k, actualValue, v) name, k, actualValue, v)
} }
} }
if notExpected == nil {
return
}
for _, h := range notExpected {
if _, present := actual[h]; present {
t.Errorf("%s: unexpected header: %s", name, h)
}
}
} }
func TestServeHTTP(t *testing.T) { func TestServeHTTP(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
method string method string
requestPath string requestPath string
expectedPath string expectedPath string
requestBody string requestBody string
requestParams map[string]string requestParams map[string]string
requestHeader map[string]string requestHeader map[string]string
responseHeader map[string]string
expectedRespHeader map[string]string
notExpectedRespHeader []string
}{ }{
{ {
name: "root path, simple get", name: "root path, simple get",
@@ -128,14 +141,37 @@ func TestServeHTTP(t *testing.T) {
requestPath: "", requestPath: "",
expectedPath: "/", expectedPath: "/",
}, },
{
name: "remove CORS headers",
method: "GET",
requestPath: "/some/path",
expectedPath: "/some/path",
responseHeader: map[string]string{
"Header1": "value1",
"Access-Control-Allow-Origin": "some.server",
"Access-Control-Allow-Methods": "GET"},
expectedRespHeader: map[string]string{
"Header1": "value1",
},
notExpectedRespHeader: []string{
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
},
},
} }
for _, test := range tests { for _, test := range tests {
func() { func() {
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>" backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
backendResponseHeader := test.responseHeader
// Test a simple header if not specified in the test
if backendResponseHeader == nil && test.expectedRespHeader == nil {
backendResponseHeader = map[string]string{"Content-Type": "text/html"}
test.expectedRespHeader = map[string]string{"Content-Type": "text/html"}
}
backendHandler := &SimpleBackendHandler{ backendHandler := &SimpleBackendHandler{
responseBody: backendResponse, responseBody: backendResponse,
responseHeader: map[string]string{"Content-Type": "text/html"}, responseHeader: backendResponseHeader,
} }
backendServer := httptest.NewServer(backendHandler) backendServer := httptest.NewServer(backendHandler)
defer backendServer.Close() defer backendServer.Close()
@@ -197,9 +233,13 @@ func TestServeHTTP(t *testing.T) {
// Headers // Headers
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader, validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
test.requestHeader) test.requestHeader, nil)
// Validate proxy response // Validate proxy response
// Response Headers
validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader)
// Validate Body // Validate Body
responseBody, err := ioutil.ReadAll(res.Body) responseBody, err := ioutil.ReadAll(res.Body)
if err != nil { if err != nil {
@@ -297,7 +337,7 @@ func TestDefaultProxyTransport(t *testing.T) {
Location: locURL, Location: locURL,
} }
result := h.defaultProxyTransport(URL) result := h.defaultProxyTransport(URL)
transport := result.(*proxy.Transport) transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport)
if transport.Scheme != test.expectedScheme { if transport.Scheme != test.expectedScheme {
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme) t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
} }