
Following the convention of HTTP/2, we now use odd stream IDs for client-initiated streams. This makes it easier to tell which side initiated a stream. We enforce the convention on the server side. This allows us to upgrade the protocol in the future to support server-initiated streams. Signed-off-by: Stephen J Day <stephen.day@docker.com>
156 lines
2.9 KiB
Go
156 lines
2.9 KiB
Go
package ttrpc
|
|
|
|
import (
	"context"
	"net"
	"time"

	"github.com/containerd/containerd/log"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)
|
|
|
|
type Server struct {
|
|
services *serviceSet
|
|
codec codec
|
|
}
|
|
|
|
func NewServer() *Server {
|
|
return &Server{
|
|
services: newServiceSet(),
|
|
}
|
|
}
|
|
|
|
func (s *Server) Register(name string, methods map[string]Method) {
|
|
s.services.register(name, methods)
|
|
}
|
|
|
|
func (s *Server) Shutdown(ctx context.Context) error {
|
|
// TODO(stevvooe): Wait on connection shutdown.
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) Serve(l net.Listener) error {
|
|
for {
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
log.L.WithError(err).Error("failed accept")
|
|
continue
|
|
}
|
|
|
|
go s.handleConn(conn)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) handleConn(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
type (
|
|
request struct {
|
|
id uint32
|
|
req *Request
|
|
}
|
|
|
|
response struct {
|
|
id uint32
|
|
resp *Response
|
|
}
|
|
)
|
|
|
|
var (
|
|
ch = newChannel(conn, conn)
|
|
ctx, cancel = context.WithCancel(context.Background())
|
|
responses = make(chan response)
|
|
requests = make(chan request)
|
|
recvErr = make(chan error, 1)
|
|
done = make(chan struct{})
|
|
)
|
|
|
|
defer cancel()
|
|
defer close(done)
|
|
|
|
go func() {
|
|
defer close(recvErr)
|
|
var p [messageLengthMax]byte
|
|
for {
|
|
mh, err := ch.recv(ctx, p[:])
|
|
if err != nil {
|
|
recvErr <- err
|
|
return
|
|
}
|
|
|
|
if mh.Type != messageTypeRequest {
|
|
// we must ignore this for future compat.
|
|
continue
|
|
}
|
|
|
|
var req Request
|
|
if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
|
recvErr <- err
|
|
return
|
|
}
|
|
|
|
if mh.StreamID%2 != 1 {
|
|
// enforce odd client initiated identifiers.
|
|
select {
|
|
case responses <- response{
|
|
// even though we've had an invalid stream id, we send it
|
|
// back on the same stream id so the client knows which
|
|
// stream id was bad.
|
|
id: mh.StreamID,
|
|
resp: &Response{
|
|
Status: status.New(codes.InvalidArgument, "StreamID must be odd for client initiated streams").Proto(),
|
|
},
|
|
}:
|
|
case <-done:
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
select {
|
|
case requests <- request{
|
|
id: mh.StreamID,
|
|
req: &req,
|
|
}:
|
|
case <-done:
|
|
}
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case request := <-requests:
|
|
go func(id uint32) {
|
|
p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
|
resp := &Response{
|
|
Status: status.Proto(),
|
|
Payload: p,
|
|
}
|
|
|
|
select {
|
|
case responses <- response{
|
|
id: id,
|
|
resp: resp,
|
|
}:
|
|
case <-done:
|
|
}
|
|
}(request.id)
|
|
case response := <-responses:
|
|
p, err := s.codec.Marshal(response.resp)
|
|
if err != nil {
|
|
log.L.WithError(err).Error("failed marshaling response")
|
|
return
|
|
}
|
|
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
|
|
log.L.WithError(err).Error("failed sending message on channel")
|
|
return
|
|
}
|
|
case err := <-recvErr:
|
|
log.L.WithError(err).Error("error receiving message")
|
|
return
|
|
}
|
|
}
|
|
}
|