| @@ -99,6 +99,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int | ||||
| 		cresp = &Response{} | ||||
| 	) | ||||
|  | ||||
| 	if headers, ok := GetHeaders(ctx); ok { | ||||
| 		creq.Headers = headers | ||||
| 	} | ||||
|  | ||||
| 	if dl, ok := ctx.Deadline(); ok { | ||||
| 		creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds() | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										86
									
								
								header.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								header.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | ||||
| /* | ||||
|    Copyright The containerd Authors. | ||||
|  | ||||
|    Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|    you may not use this file except in compliance with the License. | ||||
|    You may obtain a copy of the License at | ||||
|  | ||||
|        http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
|    Unless required by applicable law or agreed to in writing, software | ||||
|    distributed under the License is distributed on an "AS IS" BASIS, | ||||
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|    See the License for the specific language governing permissions and | ||||
|    limitations under the License. | ||||
| */ | ||||
|  | ||||
| package ttrpc | ||||
|  | ||||
| import "context" | ||||
|  | ||||
| // Headers represents the key-value pairs (similar to http.Header) to be passed to ttrpc server from a client. | ||||
| type Headers map[string]StringList | ||||
|  | ||||
| // Get returns the headers for a given key when they exist. | ||||
| // If there are no headers, a nil slice and false are returned. | ||||
| func (h Headers) Get(key string) ([]string, bool) { | ||||
| 	list, ok := h[key] | ||||
| 	if !ok || len(list.List) == 0 { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	return list.List, true | ||||
| } | ||||
|  | ||||
| // Set sets the provided values for a given key. | ||||
| // The values will overwrite any existing values. | ||||
| // If no values provided, a key will be deleted. | ||||
| func (h Headers) Set(key string, values ...string) { | ||||
| 	if len(values) == 0 { | ||||
| 		delete(h, key) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h[key] = StringList{List: values} | ||||
| } | ||||
|  | ||||
| // Append appends additional values to the given key. | ||||
| func (h Headers) Append(key string, values ...string) { | ||||
| 	if len(values) == 0 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	list, ok := h[key] | ||||
| 	if ok { | ||||
| 		h.Set(key, append(list.List, values...)...) | ||||
| 	} else { | ||||
| 		h.Set(key, values...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type headerKey struct{} | ||||
|  | ||||
| // GetHeaders retrieves headers from context.Context (previously attached with WithHeaders) | ||||
| func GetHeaders(ctx context.Context) (Headers, bool) { | ||||
| 	headers, ok := ctx.Value(headerKey{}).(Headers) | ||||
| 	return headers, ok | ||||
| } | ||||
|  | ||||
| // GetHeader gets a specific header value by name from context.Context | ||||
| func GetHeader(ctx context.Context, name string) (string, bool) { | ||||
| 	headers, ok := GetHeaders(ctx) | ||||
| 	if !ok { | ||||
| 		return "", false | ||||
| 	} | ||||
|  | ||||
| 	if list, ok := headers.Get(name); ok { | ||||
| 		return list[0], true | ||||
| 	} | ||||
|  | ||||
| 	return "", false | ||||
| } | ||||
|  | ||||
| // WithHeaders attaches headers map to a context.Context | ||||
| func WithHeaders(ctx context.Context, headers Headers) context.Context { | ||||
| 	return context.WithValue(ctx, headerKey{}, headers) | ||||
| } | ||||
							
								
								
									
										108
									
								
								header_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								header_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,108 @@ | ||||
| /* | ||||
|    Copyright The containerd Authors. | ||||
|  | ||||
|    Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|    you may not use this file except in compliance with the License. | ||||
|    You may obtain a copy of the License at | ||||
|  | ||||
|        http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
|    Unless required by applicable law or agreed to in writing, software | ||||
|    distributed under the License is distributed on an "AS IS" BASIS, | ||||
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|    See the License for the specific language governing permissions and | ||||
|    limitations under the License. | ||||
| */ | ||||
|  | ||||
| package ttrpc | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestHeaders_Get(t *testing.T) { | ||||
| 	hdr := make(Headers) | ||||
| 	hdr.Set("foo", "1", "2") | ||||
|  | ||||
| 	if list, ok := hdr.Get("foo"); !ok { | ||||
| 		t.Error("key not found") | ||||
| 	} else if len(list) != 2 { | ||||
| 		t.Errorf("unexpected number of values: %d", len(list)) | ||||
| 	} else if list[0] != "1" { | ||||
| 		t.Errorf("invalid header value at 0: %s", list[0]) | ||||
| 	} else if list[1] != "2" { | ||||
| 		t.Errorf("invalid header value at 1: %s", list[1]) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHeaders_GetInvalidKey(t *testing.T) { | ||||
| 	hdr := make(Headers) | ||||
| 	hdr.Set("foo", "1", "2") | ||||
|  | ||||
| 	if _, ok := hdr.Get("invalid"); ok { | ||||
| 		t.Error("found invalid key") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHeaders_Unset(t *testing.T) { | ||||
| 	hdr := make(Headers) | ||||
| 	hdr.Set("foo", "1", "2") | ||||
| 	hdr.Set("foo") | ||||
|  | ||||
| 	if _, ok := hdr.Get("foo"); ok { | ||||
| 		t.Error("key not deleted") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHeader_Replace(t *testing.T) { | ||||
| 	hdr := make(Headers) | ||||
| 	hdr.Set("foo", "1", "2") | ||||
| 	hdr.Set("foo", "3", "4") | ||||
|  | ||||
| 	if list, ok := hdr.Get("foo"); !ok { | ||||
| 		t.Error("key not found") | ||||
| 	} else if len(list) != 2 { | ||||
| 		t.Errorf("unexpected number of values: %d", len(list)) | ||||
| 	} else if list[0] != "3" { | ||||
| 		t.Errorf("invalid header value at 0: %s", list[0]) | ||||
| 	} else if list[1] != "4" { | ||||
| 		t.Errorf("invalid header value at 1: %s", list[1]) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHeaders_Append(t *testing.T) { | ||||
| 	hdr := make(Headers) | ||||
| 	hdr.Set("foo", "1") | ||||
| 	hdr.Append("foo", "2") | ||||
| 	hdr.Append("bar", "3") | ||||
|  | ||||
| 	if list, ok := hdr.Get("foo"); !ok { | ||||
| 		t.Error("key not found") | ||||
| 	} else if len(list) != 2 { | ||||
| 		t.Errorf("unexpected number of values: %d", len(list)) | ||||
| 	} else if list[0] != "1" { | ||||
| 		t.Errorf("invalid header value at 0: %s", list[0]) | ||||
| 	} else if list[1] != "2" { | ||||
| 		t.Errorf("invalid header value at 1: %s", list[1]) | ||||
| 	} | ||||
|  | ||||
| 	if list, ok := hdr.Get("bar"); !ok { | ||||
| 		t.Error("key not found") | ||||
| 	} else if list[0] != "3" { | ||||
| 		t.Errorf("invalid value: %s", list[0]) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHeaders_Context(t *testing.T) { | ||||
| 	hdr := make(Headers) | ||||
| 	hdr.Set("foo", "bar") | ||||
|  | ||||
| 	ctx := WithHeaders(context.Background(), hdr) | ||||
|  | ||||
| 	if bar, ok := GetHeader(ctx, "foo"); !ok { | ||||
| 		t.Error("header not found") | ||||
| 	} else if bar != "bar" { | ||||
| 		t.Errorf("invalid header value: %q", bar) | ||||
| 	} | ||||
| } | ||||
| @@ -466,6 +466,10 @@ func (c *serverConn) run(sctx context.Context) { | ||||
| var noopFunc = func() {} | ||||
|  | ||||
| func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) { | ||||
| 	if req.Headers != nil { | ||||
| 		ctx = WithHeaders(ctx, req.Headers) | ||||
| 	} | ||||
|  | ||||
| 	cancel = noopFunc | ||||
| 	if req.TimeoutNano == 0 { | ||||
| 		return ctx, cancel | ||||
|   | ||||
| @@ -61,6 +61,7 @@ func (tc *testingClient) Test(ctx context.Context, req *testPayload) (*testPaylo | ||||
| type testPayload struct { | ||||
| 	Foo      string `protobuf:"bytes,1,opt,name=foo,proto3"` | ||||
| 	Deadline int64  `protobuf:"varint,2,opt,name=deadline,proto3"` | ||||
| 	Hdr      string `protobuf:"bytes,3,opt,name=hdr,proto3"` | ||||
| } | ||||
|  | ||||
| func (r *testPayload) Reset()         { *r = testPayload{} } | ||||
| @@ -75,6 +76,11 @@ func (s *testingServer) Test(ctx context.Context, req *testPayload) (*testPayloa | ||||
| 	if dl, ok := ctx.Deadline(); ok { | ||||
| 		tp.Deadline = dl.UnixNano() | ||||
| 	} | ||||
|  | ||||
| 	if v, ok := GetHeader(ctx, "foo"); ok { | ||||
| 		tp.Hdr = v | ||||
| 	} | ||||
|  | ||||
| 	return tp, nil | ||||
| } | ||||
|  | ||||
| @@ -540,6 +546,8 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s | ||||
| 		} | ||||
| 	) | ||||
|  | ||||
| 	ctx = WithHeaders(ctx, Headers{"foo": makeStringList("bar")}) | ||||
|  | ||||
| 	resp, err := client.Test(ctx, tp) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| @@ -547,7 +555,7 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s | ||||
|  | ||||
| 	results <- callResult{ | ||||
| 		input:    tp, | ||||
| 		expected: &testPayload{Foo: strings.Repeat(tp.Foo, 2)}, | ||||
| 		expected: &testPayload{Foo: strings.Repeat(tp.Foo, 2), Hdr: "bar"}, | ||||
| 		received: resp, | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										11
									
								
								types.go
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								types.go
									
									
									
									
									
								
							| @@ -27,6 +27,7 @@ type Request struct { | ||||
| 	Method      string  `protobuf:"bytes,2,opt,name=method,proto3"` | ||||
| 	Payload     []byte  `protobuf:"bytes,3,opt,name=payload,proto3"` | ||||
| 	TimeoutNano int64   `protobuf:"varint,4,opt,name=timeout_nano,proto3"` | ||||
| 	Headers     Headers `protobuf:"bytes,5,opt,name=headers,proto3" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` | ||||
| } | ||||
|  | ||||
| func (r *Request) Reset()         { *r = Request{} } | ||||
| @@ -41,3 +42,13 @@ type Response struct { | ||||
| func (r *Response) Reset()         { *r = Response{} } | ||||
| func (r *Response) String() string { return fmt.Sprintf("%+#v", r) } | ||||
| func (r *Response) ProtoMessage()  {} | ||||
|  | ||||
| type StringList struct { | ||||
| 	List []string `protobuf:"bytes,1,rep,name=list,proto3"` | ||||
| } | ||||
|  | ||||
| func (r *StringList) Reset()         { *r = StringList{} } | ||||
| func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) } | ||||
| func (r *StringList) ProtoMessage()  {} | ||||
|  | ||||
| func makeStringList(item ...string) StringList { return StringList{List: item} } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Maksym Pavlenko
					Maksym Pavlenko