diff --git a/container_test.go b/container_test.go index 00b9c43b0..44ac5cf5b 100644 --- a/container_test.go +++ b/container_test.go @@ -3,6 +3,9 @@ package containerd import ( "bytes" "context" + "fmt" + "io/ioutil" + "os" "syscall" "testing" ) @@ -360,3 +363,85 @@ func TestContainerProcesses(t *testing.T) { } <-statusC } + +func TestContainerCloseStdin(t *testing.T) { + if testing.Short() { + t.Skip() + } + client, err := New(address) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + var ( + ctx = context.Background() + id = "ContainerCloseStdin" + ) + image, err := client.GetImage(ctx, testImage) + if err != nil { + t.Error(err) + return + } + spec, err := GenerateSpec(WithImageConfig(ctx, image), WithProcessArgs("cat")) + if err != nil { + t.Error(err) + return + } + container, err := client.NewContainer(ctx, id, spec, WithImage(image), WithNewRootFS(id, image)) + if err != nil { + t.Error(err) + return + } + defer container.Delete(ctx) + + const expected = "hello\n" + stdout := bytes.NewBuffer(nil) + + r, w, err := os.Pipe() + if err != nil { + t.Error(err) + return + } + + task, err := container.NewTask(ctx, NewIO(r, stdout, ioutil.Discard)) + if err != nil { + t.Error(err) + return + } + defer task.Delete(ctx) + + statusC := make(chan uint32, 1) + go func() { + status, err := task.Wait(ctx) + if err != nil { + t.Error(err) + } + statusC <- status + }() + + if err := task.Start(ctx); err != nil { + t.Error(err) + return + } + + if _, err := fmt.Fprint(w, expected); err != nil { + t.Error(err) + } + w.Close() + if err := task.CloseStdin(ctx); err != nil { + t.Error(err) + } + + <-statusC + + if _, err := task.Delete(ctx); err != nil { + t.Error(err) + } + + output := stdout.String() + + if output != expected { + t.Errorf("expected output %q but received %q", expected, output) + } +} diff --git a/io.go b/io.go index ea3325967..016053142 100644 --- a/io.go +++ b/io.go @@ -54,6 +54,32 @@ func BufferedIO(stdin, stdout, stderr *bytes.Buffer) IOCreation { } } +func NewIO(stdin io.Reader, stdout, stderr io.Writer) IOCreation { + return func() (*IO, error) { + paths, err := fifoPaths() + if err != nil { + return nil, err + } + i := &IO{ + Terminal: false, + Stdout: paths.out, + Stderr: paths.err, + Stdin: paths.in, + } + set := &ioSet{ + in: stdin, + out: stdout, + err: stderr, + } + closer, err := copyIO(paths, set, false) + if err != nil { + return nil, err + } + i.closer = closer + return i, nil + } +} + // Stdio returns an IO implementation to be used for a task // that outputs the container's IO as the current processes Stdio func Stdio() (*IO, error) { diff --git a/task.go b/task.go index ba5489a64..b56f0db99 100644 --- a/task.go +++ b/task.go @@ -32,6 +32,7 @@ type Task interface { Wait(context.Context) (uint32, error) Exec(context.Context, *specs.Process, IOCreation) (Process, error) Processes(context.Context) ([]uint32, error) + CloseStdin(context.Context) error } type Process interface { @@ -158,3 +159,11 @@ func (t *task) Processes(ctx context.Context) ([]uint32, error) { } return out, nil } + +func (t *task) CloseStdin(ctx context.Context) error { + _, err := t.client.TaskService().CloseStdin(ctx, &execution.CloseStdinRequest{ + ContainerID: t.containerID, + Pid: t.pid, + }) + return err +}