Merge pull request #9159 from csrwng/remove_cors_headers
Remove CORS headers from pod proxy responses
This commit is contained in:
@@ -214,9 +214,37 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun
|
|||||||
suffix += "/"
|
suffix += "/"
|
||||||
}
|
}
|
||||||
pathPrepend := strings.TrimSuffix(url.Path, suffix)
|
pathPrepend := strings.TrimSuffix(url.Path, suffix)
|
||||||
return &proxy.Transport{
|
internalTransport := &proxy.Transport{
|
||||||
Scheme: scheme,
|
Scheme: scheme,
|
||||||
Host: host,
|
Host: host,
|
||||||
PathPrepend: pathPrepend,
|
PathPrepend: pathPrepend,
|
||||||
}
|
}
|
||||||
|
return &corsRemovingTransport{
|
||||||
|
RoundTripper: internalTransport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers
|
||||||
|
// from the internal response.
|
||||||
|
type corsRemovingTransport struct {
|
||||||
|
http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
resp, err := p.RoundTripper.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
removeCORSHeaders(resp)
|
||||||
|
return resp, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeCORSHeaders strip CORS headers sent from the backend
|
||||||
|
// This should be called on all responses before returning
|
||||||
|
func removeCORSHeaders(resp *http.Response) {
|
||||||
|
resp.Header.Del("Access-Control-Allow-Credentials")
|
||||||
|
resp.Header.Del("Access-Control-Allow-Headers")
|
||||||
|
resp.Header.Del("Access-Control-Allow-Methods")
|
||||||
|
resp.Header.Del("Access-Control-Allow-Origin")
|
||||||
}
|
}
|
||||||
|
@@ -51,8 +51,10 @@ func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range s.responseHeader {
|
if s.responseHeader != nil {
|
||||||
w.Header().Add(k, v)
|
for k, v := range s.responseHeader {
|
||||||
|
w.Header().Add(k, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
w.Write([]byte(s.responseBody))
|
w.Write([]byte(s.responseBody))
|
||||||
}
|
}
|
||||||
@@ -71,7 +73,7 @@ func validateParameters(t *testing.T, name string, actual url.Values, expected m
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string) {
|
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) {
|
||||||
for k, v := range expected {
|
for k, v := range expected {
|
||||||
actualValue, ok := actual[k]
|
actualValue, ok := actual[k]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -83,17 +85,28 @@ func validateHeaders(t *testing.T, name string, actual http.Header, expected map
|
|||||||
name, k, actualValue, v)
|
name, k, actualValue, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if notExpected == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, h := range notExpected {
|
||||||
|
if _, present := actual[h]; present {
|
||||||
|
t.Errorf("%s: unexpected header: %s", name, h)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServeHTTP(t *testing.T) {
|
func TestServeHTTP(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
requestPath string
|
requestPath string
|
||||||
expectedPath string
|
expectedPath string
|
||||||
requestBody string
|
requestBody string
|
||||||
requestParams map[string]string
|
requestParams map[string]string
|
||||||
requestHeader map[string]string
|
requestHeader map[string]string
|
||||||
|
responseHeader map[string]string
|
||||||
|
expectedRespHeader map[string]string
|
||||||
|
notExpectedRespHeader []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "root path, simple get",
|
name: "root path, simple get",
|
||||||
@@ -128,14 +141,37 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
requestPath: "",
|
requestPath: "",
|
||||||
expectedPath: "/",
|
expectedPath: "/",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "remove CORS headers",
|
||||||
|
method: "GET",
|
||||||
|
requestPath: "/some/path",
|
||||||
|
expectedPath: "/some/path",
|
||||||
|
responseHeader: map[string]string{
|
||||||
|
"Header1": "value1",
|
||||||
|
"Access-Control-Allow-Origin": "some.server",
|
||||||
|
"Access-Control-Allow-Methods": "GET"},
|
||||||
|
expectedRespHeader: map[string]string{
|
||||||
|
"Header1": "value1",
|
||||||
|
},
|
||||||
|
notExpectedRespHeader: []string{
|
||||||
|
"Access-Control-Allow-Origin",
|
||||||
|
"Access-Control-Allow-Methods",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
func() {
|
func() {
|
||||||
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
|
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
|
||||||
|
backendResponseHeader := test.responseHeader
|
||||||
|
// Test a simple header if not specified in the test
|
||||||
|
if backendResponseHeader == nil && test.expectedRespHeader == nil {
|
||||||
|
backendResponseHeader = map[string]string{"Content-Type": "text/html"}
|
||||||
|
test.expectedRespHeader = map[string]string{"Content-Type": "text/html"}
|
||||||
|
}
|
||||||
backendHandler := &SimpleBackendHandler{
|
backendHandler := &SimpleBackendHandler{
|
||||||
responseBody: backendResponse,
|
responseBody: backendResponse,
|
||||||
responseHeader: map[string]string{"Content-Type": "text/html"},
|
responseHeader: backendResponseHeader,
|
||||||
}
|
}
|
||||||
backendServer := httptest.NewServer(backendHandler)
|
backendServer := httptest.NewServer(backendHandler)
|
||||||
defer backendServer.Close()
|
defer backendServer.Close()
|
||||||
@@ -197,9 +233,13 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
|
|
||||||
// Headers
|
// Headers
|
||||||
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
|
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
|
||||||
test.requestHeader)
|
test.requestHeader, nil)
|
||||||
|
|
||||||
// Validate proxy response
|
// Validate proxy response
|
||||||
|
|
||||||
|
// Response Headers
|
||||||
|
validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader)
|
||||||
|
|
||||||
// Validate Body
|
// Validate Body
|
||||||
responseBody, err := ioutil.ReadAll(res.Body)
|
responseBody, err := ioutil.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -297,7 +337,7 @@ func TestDefaultProxyTransport(t *testing.T) {
|
|||||||
Location: locURL,
|
Location: locURL,
|
||||||
}
|
}
|
||||||
result := h.defaultProxyTransport(URL)
|
result := h.defaultProxyTransport(URL)
|
||||||
transport := result.(*proxy.Transport)
|
transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport)
|
||||||
if transport.Scheme != test.expectedScheme {
|
if transport.Scheme != test.expectedScheme {
|
||||||
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
|
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user