Add staticcheck linter

Fix issues with sync.Pool being passed a slice value rather than a pointer.
See https://github.com/dominikh/go-tools/blob/master/cmd/staticcheck/docs/checks/SA6002

Add missing tests for content.Copy

Fix T.Fatal being called in a goroutine

Signed-off-by: Daniel Nephin <dnephin@gmail.com>
Daniel Nephin 2017-11-27 12:22:22 -05:00
parent 2556c594ec
commit ee04cfa3f9
12 changed files with 165 additions and 63 deletions
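
For context, the SA6002 change that the diffs below apply to every buffer pool follows one pattern: the pool should hold pointers to slices rather than slice values, because passing a bare slice to Put converts it to interface{} and allocates a copy of the slice header on each call. A minimal, self-contained sketch of that pattern (illustrative only, not code from this commit):

```go
package main

import (
	"fmt"
	"sync"
)

// bufPool stores *[]byte rather than []byte. Storing the bare slice would
// allocate a copy of the slice header every time it is passed to Put as an
// interface{}, which is exactly what staticcheck's SA6002 check flags.
var bufPool = sync.Pool{
	New: func() interface{} {
		buffer := make([]byte, 32*1024)
		return &buffer
	},
}

func main() {
	buf := bufPool.Get().(*[]byte) // type-assert to the pointer type
	defer bufPool.Put(buf)         // return the same pointer to the pool

	n := copy(*buf, "hello") // dereference to use the underlying slice
	fmt.Println(string((*buf)[:n]))
}
```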


@@ -13,6 +13,7 @@
     "structcheck",
     "unused",
     "varcheck",
+    "staticcheck",
     "gofmt",
     "goimports",


@@ -19,13 +19,12 @@ import (
     "github.com/pkg/errors"
 )
 
-var (
-    bufferPool = &sync.Pool{
-        New: func() interface{} {
-            return make([]byte, 32*1024)
-        },
-    }
-)
+var bufferPool = &sync.Pool{
+    New: func() interface{} {
+        buffer := make([]byte, 32*1024)
+        return &buffer
+    },
+}
 
 // Diff returns a tar stream of the computed filesystem
 // difference between the provided directories.
@@ -404,8 +403,8 @@ func (cw *changeWriter) HandleChange(k fs.ChangeKind, p string, f os.FileInfo, e
         }
         defer file.Close()
 
-        buf := bufferPool.Get().([]byte)
-        n, err := io.CopyBuffer(cw.tw, file, buf)
+        buf := bufferPool.Get().(*[]byte)
+        n, err := io.CopyBuffer(cw.tw, file, *buf)
         bufferPool.Put(buf)
         if err != nil {
             return errors.Wrap(err, "failed to copy")
@@ -529,7 +528,7 @@ func createTarFile(ctx context.Context, path, extractDir string, hdr *tar.Header
 }
 
 func copyBuffered(ctx context.Context, dst io.Writer, src io.Reader) (written int64, err error) {
-    buf := bufferPool.Get().([]byte)
+    buf := bufferPool.Get().(*[]byte)
     defer bufferPool.Put(buf)
 
     for {
@@ -540,9 +539,9 @@ func copyBuffered(ctx context.Context, dst io.Writer, src io.Reader) (written in
         default:
         }
 
-        nr, er := src.Read(buf)
+        nr, er := src.Read(*buf)
         if nr > 0 {
-            nw, ew := dst.Write(buf[0:nr])
+            nw, ew := dst.Write((*buf)[0:nr])
             if nw > 0 {
                 written += int64(nw)
             }


@@ -10,13 +10,12 @@ import (
     "github.com/pkg/errors"
 )
 
-var (
-    bufPool = sync.Pool{
-        New: func() interface{} {
-            return make([]byte, 1<<20)
-        },
-    }
-)
+var bufPool = sync.Pool{
+    New: func() interface{} {
+        buffer := make([]byte, 1<<20)
+        return &buffer
+    },
+}
 
 // NewReader returns a io.Reader from a ReaderAt
 func NewReader(ra ReaderAt) io.Reader {
@@ -88,10 +87,10 @@ func Copy(ctx context.Context, cw Writer, r io.Reader, size int64, expected dige
         }
     }
 
-    buf := bufPool.Get().([]byte)
+    buf := bufPool.Get().(*[]byte)
     defer bufPool.Put(buf)
 
-    if _, err := io.CopyBuffer(cw, r, buf); err != nil {
+    if _, err := io.CopyBuffer(cw, r, *buf); err != nil {
         return err
     }

content/helpers_test.go (new file, 112 additions)

@@ -0,0 +1,112 @@
package content

import (
    "bytes"
    "context"
    "io"
    "strings"
    "testing"

    "github.com/containerd/containerd/errdefs"
    "github.com/opencontainers/go-digest"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

type copySource struct {
    reader io.Reader
    size   int64
    digest digest.Digest
}

func TestCopy(t *testing.T) {
    defaultSource := newCopySource("this is the source to copy")

    var testcases = []struct {
        name     string
        source   copySource
        writer   fakeWriter
        expected string
    }{
        {
            name:     "copy no offset",
            source:   defaultSource,
            writer:   fakeWriter{},
            expected: "this is the source to copy",
        },
        {
            name:     "copy with offset from seeker",
            source:   defaultSource,
            writer:   fakeWriter{status: Status{Offset: 8}},
            expected: "the source to copy",
        },
        {
            name:     "copy with offset from unseekable source",
            source:   copySource{reader: bytes.NewBufferString("foo"), size: 3},
            writer:   fakeWriter{status: Status{Offset: 8}},
            expected: "foo",
        },
        {
            name:   "commit already exists",
            source: defaultSource,
            writer: fakeWriter{commitFunc: func() error {
                return errdefs.ErrAlreadyExists
            }},
        },
    }

    for _, testcase := range testcases {
        t.Run(testcase.name, func(t *testing.T) {
            err := Copy(context.Background(),
                &testcase.writer,
                testcase.source.reader,
                testcase.source.size,
                testcase.source.digest)

            require.NoError(t, err)
            assert.Equal(t, testcase.source.digest, testcase.writer.commitedDigest)
            assert.Equal(t, testcase.expected, testcase.writer.String())
        })
    }
}

func newCopySource(raw string) copySource {
    return copySource{
        reader: strings.NewReader(raw),
        size:   int64(len(raw)),
        digest: digest.FromBytes([]byte(raw)),
    }
}

type fakeWriter struct {
    bytes.Buffer
    commitedDigest digest.Digest
    status         Status
    commitFunc     func() error
}

func (f *fakeWriter) Close() error {
    f.Buffer.Reset()
    return nil
}

func (f *fakeWriter) Commit(ctx context.Context, size int64, expected digest.Digest, opts ...Opt) error {
    f.commitedDigest = expected
    if f.commitFunc == nil {
        return nil
    }
    return f.commitFunc()
}

func (f *fakeWriter) Digest() digest.Digest {
    return f.commitedDigest
}

func (f *fakeWriter) Status() (Status, error) {
    return f.status, nil
}

func (f *fakeWriter) Truncate(size int64) error {
    f.Buffer.Truncate(int(size))
    return nil
}


@@ -21,13 +21,12 @@ import (
     "github.com/pkg/errors"
 )
 
-var (
-    bufPool = sync.Pool{
-        New: func() interface{} {
-            return make([]byte, 1<<20)
-        },
-    }
-)
+var bufPool = sync.Pool{
+    New: func() interface{} {
+        buffer := make([]byte, 1<<20)
+        return &buffer
+    },
+}
 
 // LabelStore is used to store mutable labels for digests
 type LabelStore interface {
@@ -463,10 +462,10 @@ func (s *store) writer(ctx context.Context, ref string, total int64, expected di
     }
     defer fp.Close()
 
-    p := bufPool.Get().([]byte)
+    p := bufPool.Get().(*[]byte)
     defer bufPool.Put(p)
 
-    offset, err = io.CopyBuffer(digester.Hash(), fp, p)
+    offset, err = io.CopyBuffer(digester.Hash(), fp, *p)
     if err != nil {
         return nil, err
     }


@@ -2,7 +2,6 @@ package exchange
 
 import (
     "context"
-    "fmt"
     "reflect"
     "sync"
     "testing"
@@ -39,13 +38,14 @@ func TestExchangeBasic(t *testing.T) {
     t.Log("publish")
     var wg sync.WaitGroup
     wg.Add(1)
+    errChan := make(chan error)
     go func() {
         defer wg.Done()
+        defer close(errChan)
         for _, event := range testevents {
-            fmt.Println("publish", event)
             if err := exchange.Publish(ctx, "/test", event); err != nil {
-                fmt.Println("publish error", err)
-                t.Fatal(err)
+                errChan <- err
+                return
             }
         }
@@ -54,6 +54,9 @@ func TestExchangeBasic(t *testing.T) {
     t.Log("waiting")
     wg.Wait()
+    if err := <-errChan; err != nil {
+        t.Fatal(err)
+    }
 
     for _, subscriber := range []struct {
         eventq <-chan *events.Envelope
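
The t.Fatal fix above follows the standard pattern for failing a test from a worker goroutine: t.Fatal (and FailNow) only stop the goroutine that calls them, so they must run on the test goroutine, and the worker hands its error back over a channel instead. A generic sketch of that pattern, with a hypothetical doWork helper (not the containerd test itself):

```go
package example

import "testing"

// doWork stands in for any operation that might fail inside the goroutine.
func doWork(i int) error {
	_ = i // pretend to do some work
	return nil
}

func TestWorker(t *testing.T) {
	errChan := make(chan error)

	go func() {
		// Calling t.Fatal here would only exit this goroutine and the
		// test would keep running, so report errors over the channel.
		defer close(errChan)
		for i := 0; i < 5; i++ {
			if err := doWork(i); err != nil {
				errChan <- err
				return
			}
		}
	}()

	// Fail on the test goroutine, where t.Fatal is allowed. A closed
	// channel yields nil, so a clean run passes.
	if err := <-errChan; err != nil {
		t.Fatal(err)
	}
}
```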


@@ -9,13 +9,12 @@ import (
     "github.com/pkg/errors"
 )
 
-var (
-    bufferPool = &sync.Pool{
-        New: func() interface{} {
-            return make([]byte, 32*1024)
-        },
-    }
-)
+var bufferPool = &sync.Pool{
+    New: func() interface{} {
+        buffer := make([]byte, 32*1024)
+        return &buffer
+    },
+}
 
 // CopyDir copies the directory from src to dst.
 // Most efficient copy of files is attempted.


@@ -43,8 +43,8 @@ func copyFileContent(dst, src *os.File) error {
             return errors.Wrap(err, "copy file range failed")
         }
 
-    buf := bufferPool.Get().([]byte)
-    _, err = io.CopyBuffer(dst, src, buf)
+    buf := bufferPool.Get().(*[]byte)
+    _, err = io.CopyBuffer(dst, src, *buf)
     bufferPool.Put(buf)
     return err
 }


@@ -34,8 +34,8 @@ func copyFileInfo(fi os.FileInfo, name string) error {
 }
 
 func copyFileContent(dst, src *os.File) error {
-    buf := bufferPool.Get().([]byte)
-    _, err := io.CopyBuffer(dst, src, buf)
+    buf := bufferPool.Get().(*[]byte)
+    _, err := io.CopyBuffer(dst, src, *buf)
     bufferPool.Put(buf)
     return err


@@ -18,8 +18,8 @@ func copyFileInfo(fi os.FileInfo, name string) error {
 }
 
 func copyFileContent(dst, src *os.File) error {
-    buf := bufferPool.Get().([]byte)
-    _, err := io.CopyBuffer(dst, src, buf)
+    buf := bufferPool.Get().(*[]byte)
+    _, err := io.CopyBuffer(dst, src, *buf)
     bufferPool.Put(buf)
     return err
 }


@@ -69,14 +69,6 @@ func createInitLayer(ctx context.Context, parent, initName string, initFn func(s
     if err != nil {
         return "", err
     }
-    defer func() {
-        if err != nil {
-            // TODO: once implemented uncomment
-            //if rerr := snapshotter.Remove(ctx, td); rerr != nil {
-            //    log.G(ctx).Errorf("Failed to remove snapshot %s: %v", td, merr)
-            //}
-        }
-    }()
 
     if err = mounter.Mount(td, mounts...); err != nil {
         return "", err


@@ -28,7 +28,8 @@ type service struct {
 
 var bufPool = sync.Pool{
     New: func() interface{} {
-        return make([]byte, 1<<20)
+        buffer := make([]byte, 1<<20)
+        return &buffer
     },
 }
@@ -178,7 +179,7 @@ func (s *service) Read(req *api.ReadContentRequest, session api.Content_ReadServ
         // TODO(stevvooe): Using the global buffer pool. At 32KB, it is probably
         // little inefficient for work over a fast network. We can tune this later.
-        p = bufPool.Get().([]byte)
+        p = bufPool.Get().(*[]byte)
     )
 
     defer bufPool.Put(p)
@@ -194,13 +195,10 @@ func (s *service) Read(req *api.ReadContentRequest, session api.Content_ReadServ
         return grpc.Errorf(codes.OutOfRange, "read past object length %v bytes", oi.Size)
     }
 
-    if _, err := io.CopyBuffer(
+    _, err = io.CopyBuffer(
         &readResponseWriter{session: session},
-        io.NewSectionReader(ra, offset, size), p); err != nil {
-        return err
-    }
-
-    return nil
+        io.NewSectionReader(ra, offset, size), *p)
+    return err
 }
 
 // readResponseWriter is a writer that places the output into ReadContentRequest messages.