diff --git a/api/next.pb.txt b/api/next.pb.txt index 9443db911..9fab4146d 100644 --- a/api/next.pb.txt +++ b/api/next.pb.txt @@ -6041,7 +6041,6 @@ file { file { name: "github.com/containerd/containerd/api/types/transfer/streaming.proto" package: "containerd.v1.types" - dependency: "google/protobuf/timestamp.proto" message_type { name: "Progress" field { @@ -6073,6 +6072,26 @@ file { json_name: "total" } } + message_type { + name: "Data" + field { + name: "data" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_BYTES + json_name: "data" + } + } + message_type { + name: "WindowUpdate" + field { + name: "update" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + json_name: "update" + } + } options { go_package: "github.com/containerd/containerd/api/types/transfer" } diff --git a/api/types/transfer/streaming.pb.go b/api/types/transfer/streaming.pb.go index 9e6baae1e..e8e0b5796 100644 --- a/api/types/transfer/streaming.pb.go +++ b/api/types/transfer/streaming.pb.go @@ -24,7 +24,6 @@ package transfer import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" - _ "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" ) @@ -107,6 +106,100 @@ func (x *Progress) GetTotal() int64 { return 0 } +type Data struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *Data) Reset() { + *x = Data{} + if protoimpl.UnsafeEnabled { + mi := &file_github_com_containerd_containerd_api_types_transfer_streaming_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Data) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Data) ProtoMessage() {} + +func (x *Data) ProtoReflect() protoreflect.Message { + mi := &file_github_com_containerd_containerd_api_types_transfer_streaming_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Data.ProtoReflect.Descriptor instead. +func (*Data) Descriptor() ([]byte, []int) { + return file_github_com_containerd_containerd_api_types_transfer_streaming_proto_rawDescGZIP(), []int{1} +} + +func (x *Data) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type WindowUpdate struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Update int32 `protobuf:"varint,1,opt,name=update,proto3" json:"update,omitempty"` +} + +func (x *WindowUpdate) Reset() { + *x = WindowUpdate{} + if protoimpl.UnsafeEnabled { + mi := &file_github_com_containerd_containerd_api_types_transfer_streaming_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *WindowUpdate) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WindowUpdate) ProtoMessage() {} + +func (x *WindowUpdate) ProtoReflect() protoreflect.Message { + mi := &file_github_com_containerd_containerd_api_types_transfer_streaming_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WindowUpdate.ProtoReflect.Descriptor instead. +func (*WindowUpdate) Descriptor() ([]byte, []int) { + return file_github_com_containerd_containerd_api_types_transfer_streaming_proto_rawDescGZIP(), []int{2} +} + +func (x *WindowUpdate) GetUpdate() int32 { + if x != nil { + return x.Update + } + return 0 +} + var File_github_com_containerd_containerd_api_types_transfer_streaming_proto protoreflect.FileDescriptor var file_github_com_containerd_containerd_api_types_transfer_streaming_proto_rawDesc = []byte{ @@ -115,20 +208,22 @@ var file_github_com_containerd_containerd_api_types_transfer_streaming_proto_raw 0x72, 0x64, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x66, 0x65, 0x72, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x13, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, - 0x64, 0x2e, 0x76, 0x31, 0x2e, 0x74, 0x79, 0x70, 0x65, 0x73, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x6e, 0x0a, 0x08, 0x50, - 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x72, 0x65, 0x66, 0x65, 0x72, - 0x65, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x66, 0x65, - 0x72, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x16, 0x0a, - 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x6f, - 0x66, 0x66, 0x73, 0x65, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x42, 0x35, 0x5a, 0x33, 0x67, - 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, - 0x6e, 0x65, 0x72, 0x64, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x64, 0x2f, - 0x61, 0x70, 0x69, 0x2f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x66, - 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x64, 0x2e, 0x76, 0x31, 0x2e, 0x74, 0x79, 0x70, 0x65, 0x73, 0x22, 0x6e, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x72, 0x65, 0x66, 0x65, 0x72, 0x65, + 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x66, 0x65, 0x72, + 0x65, 0x6e, 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x16, 0x0a, 0x06, + 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x6f, 0x66, + 0x66, 0x73, 0x65, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x22, 0x1a, 0x0a, 0x04, 0x44, 0x61, + 0x74, 0x61, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0x26, 0x0a, 0x0c, 0x57, 0x69, 0x6e, 0x64, 0x6f, 0x77, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x42, 0x35, + 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, + 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x64, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, + 0x72, 0x64, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2f, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x66, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -143,9 +238,11 @@ func file_github_com_containerd_containerd_api_types_transfer_streaming_proto_ra return file_github_com_containerd_containerd_api_types_transfer_streaming_proto_rawDescData } -var file_github_com_containerd_containerd_api_types_transfer_streaming_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_github_com_containerd_containerd_api_types_transfer_streaming_proto_msgTypes = make([]protoimpl.MessageInfo, 3) var file_github_com_containerd_containerd_api_types_transfer_streaming_proto_goTypes = []interface{}{ - (*Progress)(nil), // 0: containerd.v1.types.Progress + (*Progress)(nil), // 0: containerd.v1.types.Progress + (*Data)(nil), // 1: containerd.v1.types.Data + (*WindowUpdate)(nil), // 2: containerd.v1.types.WindowUpdate } var file_github_com_containerd_containerd_api_types_transfer_streaming_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type @@ -173,6 +270,30 @@ func file_github_com_containerd_containerd_api_types_transfer_streaming_proto_in return nil } } + file_github_com_containerd_containerd_api_types_transfer_streaming_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Data); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_github_com_containerd_containerd_api_types_transfer_streaming_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*WindowUpdate); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -180,7 +301,7 @@ func file_github_com_containerd_containerd_api_types_transfer_streaming_proto_in GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_containerd_containerd_api_types_transfer_streaming_proto_rawDesc, NumEnums: 0, - NumMessages: 1, + NumMessages: 3, NumExtensions: 0, NumServices: 0, }, diff --git a/api/types/transfer/streaming.proto b/api/types/transfer/streaming.proto index fbf51562e..9a0ff559d 100644 --- a/api/types/transfer/streaming.proto +++ b/api/types/transfer/streaming.proto @@ -18,7 +18,7 @@ syntax = "proto3"; package containerd.v1.types; -import "google/protobuf/timestamp.proto"; +//import "google/protobuf/timestamp.proto"; option go_package = "github.com/containerd/containerd/api/types/transfer"; @@ -30,6 +30,14 @@ message Progress { } +message Data { + bytes data = 1; +} + +message WindowUpdate { + int32 update = 1; +} + /* type StatusInfo struct { diff --git a/pkg/transfer/streaming/stream.go b/pkg/transfer/streaming/stream.go new file mode 100644 index 000000000..f3bd4bb8e --- /dev/null +++ b/pkg/transfer/streaming/stream.go @@ -0,0 +1,210 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package streaming + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" + "sync" + "time" + + transferapi "github.com/containerd/containerd/api/types/transfer" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/pkg/streaming" + "github.com/containerd/typeurl" +) + +const maxRead = 32 * 1024 +const windowSize = 2 * maxRead + +var bufPool = &sync.Pool{ + New: func() interface{} { + buffer := make([]byte, maxRead) + return &buffer + }, +} + +func SendStream(ctx context.Context, r io.Reader, stream streaming.Stream) { + window := make(chan int32) + go func() { + defer close(window) + for { + select { + case <-ctx.Done(): + return + default: + } + + any, err := stream.Recv() + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, context.Canceled) { + log.G(ctx).WithError(err).Error("send stream ended without EOF") + } + return + } + i, err := typeurl.UnmarshalAny(any) + if err != nil { + log.G(ctx).WithError(err).Error("failed to unmarshal stream object") + continue + } + switch v := i.(type) { + case *transferapi.WindowUpdate: + select { + case <-ctx.Done(): + return + case window <- v.Update: + } + default: + log.G(ctx).Errorf("unexpected stream object of type %T", i) + } + } + }() + go func() { + defer stream.Close() + + buf := bufPool.Get().(*[]byte) + defer bufPool.Put(buf) + + var remaining int32 + + for { + if remaining > 0 { + // Don't wait for window update since there are remaining + select { + case <-ctx.Done(): + // TODO: Send error message on stream before close to allow remote side to return error + return + case update := <-window: + remaining += update + default: + } + } else { + // Block until window updated + select { + case <-ctx.Done(): + // TODO: Send error message on stream before close to allow remote side to return error + return + case update := <-window: + remaining = update + } + } + var max int32 = maxRead + if max > remaining { + max = remaining + } + b := (*buf)[:max] + n, err := r.Read(b) + if err != nil { + if err != io.EOF { + log.G(ctx).WithError(err).Errorf("failed to read stream source") + // TODO: Send error message on stream before close to allow remote side to return error + } + return + } + remaining = remaining - int32(n) + + data := &transferapi.Data{ + Data: b[:n], + } + any, err := typeurl.MarshalAny(data) + if err != nil { + log.G(ctx).WithError(err).Errorf("failed to marshal data for send") + // TODO: Send error message on stream before close to allow remote side to return error + return + } + if err := stream.Send(any); err != nil { + log.G(ctx).WithError(err).Errorf("send failed") + return + } + } + }() +} + +func ReceiveStream(ctx context.Context, stream streaming.Stream) io.Reader { + r, w := io.Pipe() + go func() { + defer stream.Close() + var window int32 + for { + var werr error + if window < windowSize { + update := &transferapi.WindowUpdate{ + Update: windowSize, + } + any, err := typeurl.MarshalAny(update) + if err != nil { + w.CloseWithError(fmt.Errorf("failed to marshal window update: %w", err)) + return + } + // check window update error after recv, stream may be complete + if werr = stream.Send(any); werr == nil { + window += windowSize + } else if werr == io.EOF { + // TODO: Why does send return EOF here + werr = nil + } + } + any, err := stream.Recv() + if err != nil { + if err == io.EOF { + err = nil + } else { + err = fmt.Errorf("received failed: %w", err) + } + w.CloseWithError(err) + return + } else if werr != nil { + // Try receive before erroring out + w.CloseWithError(fmt.Errorf("failed to send window update: %w", werr)) + return + } + i, err := typeurl.UnmarshalAny(any) + if err != nil { + w.CloseWithError(fmt.Errorf("failed to unmarshal received object: %w", err)) + return + } + switch v := i.(type) { + case *transferapi.Data: + n, err := w.Write(v.Data) + if err != nil { + w.CloseWithError(fmt.Errorf("failed to unmarshal received object: %w", err)) + // Close will error out sender + return + } + window = window - int32(n) + // TODO: Handle error case + default: + log.G(ctx).Warnf("Ignoring unknown stream object of type %T", i) + continue + } + } + + }() + + return r +} + +func GenerateID(prefix string) string { + t := time.Now() + var b [3]byte + rand.Read(b[:]) + return fmt.Sprintf("%s-%d-%s", prefix, t.Nanosecond(), base64.URLEncoding.EncodeToString(b[:])) +} diff --git a/pkg/transfer/streaming/stream_test.go b/pkg/transfer/streaming/stream_test.go new file mode 100644 index 000000000..c42d7ade9 --- /dev/null +++ b/pkg/transfer/streaming/stream_test.go @@ -0,0 +1,178 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package streaming + +import ( + "bytes" + "context" + "io" + "testing" + + "github.com/containerd/containerd/pkg/streaming" + "github.com/containerd/typeurl" +) + +func FuzzSendAndReceive(f *testing.F) { + f.Add([]byte{}) + f.Add([]byte{0}) + f.Add(bytes.Repeat([]byte{0}, windowSize+1)) + f.Add([]byte("hello")) + f.Add(bytes.Repeat([]byte("hello"), windowSize+1)) + f.Fuzz(func(t *testing.T, expected []byte) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rs, ws := pipeStream() + r, w := io.Pipe() + SendStream(ctx, r, ws) + or := ReceiveStream(ctx, rs) + + go func() { + io.Copy(w, bytes.NewBuffer(expected)) + w.Close() + }() + + actual, err := io.ReadAll(or) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expected, actual) { + t.Fatalf("received bytes are not equal\n\tactual: %v\n\texpected:%v", actual, expected) + } + }) +} +func FuzzSendAndReceiveChain(f *testing.F) { + f.Add([]byte{}) + f.Add([]byte{0}) + f.Add(bytes.Repeat([]byte{0}, windowSize+1)) + f.Add([]byte("hello")) + f.Add(bytes.Repeat([]byte("hello"), windowSize+1)) + f.Fuzz(func(t *testing.T, expected []byte) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r, w := io.Pipe() + + or := chainStreams(ctx, chainStreams(ctx, chainStreams(ctx, r))) + + go func() { + io.Copy(w, bytes.NewBuffer(expected)) + w.Close() + }() + + actual, err := io.ReadAll(or) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expected, actual) { + t.Fatalf("received bytes are not equal\n\tactual: %v\n\texpected:%v", actual, expected) + } + }) +} + +func FuzzWriter(f *testing.F) { + f.Add([]byte{}) + f.Add([]byte{0}) + f.Add(bytes.Repeat([]byte{0}, windowSize+1)) + f.Add([]byte("hello")) + f.Add(bytes.Repeat([]byte("hello"), windowSize+1)) + f.Fuzz(func(t *testing.T, expected []byte) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rs, ws := pipeStream() + wc := WriteByteStream(ctx, ws) + or := ReceiveStream(ctx, rs) + + go func() { + io.Copy(wc, bytes.NewBuffer(expected)) + wc.Close() + }() + + actual, err := io.ReadAll(or) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expected, actual) { + t.Fatalf("received bytes are not equal\n\tactual: %v\n\texpected:%v", actual, expected) + } + }) +} + +func chainStreams(ctx context.Context, r io.Reader) io.Reader { + rs, ws := pipeStream() + SendStream(ctx, r, ws) + return ReceiveStream(ctx, rs) +} + +func pipeStream() (streaming.Stream, streaming.Stream) { + r := make(chan typeurl.Any) + rc := make(chan struct{}) + w := make(chan typeurl.Any) + wc := make(chan struct{}) + + rs := &testStream{ + send: w, + recv: r, + closer: wc, + remote: rc, + } + ws := &testStream{ + send: r, + recv: w, + closer: rc, + remote: wc, + } + return rs, ws +} + +type testStream struct { + send chan<- typeurl.Any + recv <-chan typeurl.Any + closer chan struct{} + remote <-chan struct{} +} + +func (ts *testStream) Send(a typeurl.Any) error { + select { + case <-ts.remote: + return io.ErrClosedPipe + case ts.send <- a: + } + return nil +} +func (ts *testStream) Recv() (typeurl.Any, error) { + select { + case <-ts.remote: + return nil, io.EOF + case a := <-ts.recv: + return a, nil + } +} + +func (ts *testStream) Close() error { + select { + case <-ts.closer: + return nil + default: + } + close(ts.closer) + return nil +} diff --git a/pkg/transfer/streaming/writer.go b/pkg/transfer/streaming/writer.go new file mode 100644 index 000000000..bb06572d7 --- /dev/null +++ b/pkg/transfer/streaming/writer.go @@ -0,0 +1,130 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package streaming + +import ( + "context" + "errors" + "io" + "sync/atomic" + + transferapi "github.com/containerd/containerd/api/types/transfer" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/pkg/streaming" + "github.com/containerd/typeurl" +) + +func WriteByteStream(ctx context.Context, stream streaming.Stream) io.WriteCloser { + wbs := &writeByteStream{ + ctx: ctx, + stream: stream, + updated: make(chan struct{}, 1), + } + go func() { + for { + select { + case <-ctx.Done(): + return + default: + } + + any, err := stream.Recv() + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, context.Canceled) { + log.G(ctx).WithError(err).Error("send byte stream ended without EOF") + } + return + } + i, err := typeurl.UnmarshalAny(any) + if err != nil { + log.G(ctx).WithError(err).Error("failed to unmarshal stream object") + continue + } + switch v := i.(type) { + case *transferapi.WindowUpdate: + atomic.AddInt32(&wbs.remaining, v.Update) + select { + case <-ctx.Done(): + return + case wbs.updated <- struct{}{}: + default: + // Don't block if no writes are waiting + } + default: + log.G(ctx).Errorf("unexpected stream object of type %T", i) + } + } + }() + + return wbs +} + +type writeByteStream struct { + ctx context.Context + stream streaming.Stream + remaining int32 + updated chan struct{} +} + +func (wbs *writeByteStream) Write(p []byte) (n int, err error) { + for len(p) > 0 { + remaining := atomic.LoadInt32(&wbs.remaining) + if remaining == 0 { + // Don't wait for window update since there are remaining + select { + case <-wbs.ctx.Done(): + // TODO: Send error message on stream before close to allow remote side to return error + err = io.ErrShortWrite + return + case <-wbs.updated: + continue + } + } + var max int32 = maxRead + if max > int32(len(p)) { + max = int32(len(p)) + } + if max > remaining { + max = remaining + } + // TODO: continue + //remaining = remaining - int32(n) + + data := &transferapi.Data{ + Data: p[:max], + } + var any typeurl.Any + any, err = typeurl.MarshalAny(data) + if err != nil { + log.G(wbs.ctx).WithError(err).Errorf("failed to marshal data for send") + // TODO: Send error message on stream before close to allow remote side to return error + return + } + if err = wbs.stream.Send(any); err != nil { + log.G(wbs.ctx).WithError(err).Errorf("send failed") + return + } + n += int(max) + p = p[max:] + atomic.AddInt32(&wbs.remaining, -1*max) + } + return +} + +func (wbs *writeByteStream) Close() error { + return wbs.stream.Close() +}