diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 5f30ed59f74..2030531dc3a 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -358,6 +358,10 @@ "Comment": "v0.8.8", "Rev": "afde71eb1740fd763ab9450e1f700ba0e53c36d0" }, + { + "ImportPath": "github.com/kardianos/osext", + "Rev": "8fef92e41e22a70e700a96b29f066cda30ea24ef" + }, { "ImportPath": "github.com/kr/pty", "Comment": "release.r56-25-g05017fc", @@ -367,10 +371,18 @@ "ImportPath": "github.com/matttproud/golang_protobuf_extensions/pbutil", "Rev": "fc2b8d3a73c4867e51861bbdd5ae3c1f0869dd6a" }, + { + "ImportPath": "github.com/mesos/mesos-go/auth", + "Rev": "4b1767c0dfc51020e01f35da5b38472f40ce572a" + }, { "ImportPath": "github.com/mesos/mesos-go/detector", "Rev": "4b1767c0dfc51020e01f35da5b38472f40ce572a" }, + { + "ImportPath": "github.com/mesos/mesos-go/executor", + "Rev": "4b1767c0dfc51020e01f35da5b38472f40ce572a" + }, { "ImportPath": "github.com/mesos/mesos-go/mesosproto", "Rev": "4b1767c0dfc51020e01f35da5b38472f40ce572a" @@ -379,6 +391,14 @@ "ImportPath": "github.com/mesos/mesos-go/mesosutil", "Rev": "4b1767c0dfc51020e01f35da5b38472f40ce572a" }, + { + "ImportPath": "github.com/mesos/mesos-go/messenger", + "Rev": "4b1767c0dfc51020e01f35da5b38472f40ce572a" + }, + { + "ImportPath": "github.com/mesos/mesos-go/scheduler", + "Rev": "4b1767c0dfc51020e01f35da5b38472f40ce572a" + }, { "ImportPath": "github.com/mesos/mesos-go/upid", "Rev": "4b1767c0dfc51020e01f35da5b38472f40ce572a" diff --git a/Godeps/_workspace/src/github.com/kardianos/osext/LICENSE b/Godeps/_workspace/src/github.com/kardianos/osext/LICENSE new file mode 100644 index 00000000000..74487567632 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kardianos/osext/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/github.com/kardianos/osext/README.md b/Godeps/_workspace/src/github.com/kardianos/osext/README.md new file mode 100644 index 00000000000..820e1ecb544 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kardianos/osext/README.md @@ -0,0 +1,14 @@ +### Extensions to the "os" package. + +## Find the current Executable and ExecutableFolder. + +There is sometimes utility in finding the current executable file +that is running. This can be used for upgrading the current executable +or finding resources located relative to the executable file. + +Multi-platform and supports: + * Linux + * OS X + * Windows + * Plan 9 + * BSDs. diff --git a/Godeps/_workspace/src/github.com/kardianos/osext/osext.go b/Godeps/_workspace/src/github.com/kardianos/osext/osext.go new file mode 100644 index 00000000000..4ed4b9aa334 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kardianos/osext/osext.go @@ -0,0 +1,27 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Extensions to the standard "os" package. +package osext + +import "path/filepath" + +// Executable returns an absolute path that can be used to +// re-invoke the current program. +// It may not be valid after the current program exits. +func Executable() (string, error) { + p, err := executable() + return filepath.Clean(p), err +} + +// Returns same path as Executable, returns just the folder +// path. Excludes the executable name. +func ExecutableFolder() (string, error) { + p, err := Executable() + if err != nil { + return "", err + } + folder, _ := filepath.Split(p) + return folder, nil +} diff --git a/Godeps/_workspace/src/github.com/kardianos/osext/osext_plan9.go b/Godeps/_workspace/src/github.com/kardianos/osext/osext_plan9.go new file mode 100644 index 00000000000..655750c5426 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kardianos/osext/osext_plan9.go @@ -0,0 +1,20 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package osext + +import ( + "os" + "strconv" + "syscall" +) + +func executable() (string, error) { + f, err := os.Open("/proc/" + strconv.Itoa(os.Getpid()) + "/text") + if err != nil { + return "", err + } + defer f.Close() + return syscall.Fd2path(int(f.Fd())) +} diff --git a/Godeps/_workspace/src/github.com/kardianos/osext/osext_procfs.go b/Godeps/_workspace/src/github.com/kardianos/osext/osext_procfs.go new file mode 100644 index 00000000000..b2598bc77a4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kardianos/osext/osext_procfs.go @@ -0,0 +1,36 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build linux netbsd openbsd solaris dragonfly + +package osext + +import ( + "errors" + "fmt" + "os" + "runtime" + "strings" +) + +func executable() (string, error) { + switch runtime.GOOS { + case "linux": + const deletedTag = " (deleted)" + execpath, err := os.Readlink("/proc/self/exe") + if err != nil { + return execpath, err + } + execpath = strings.TrimSuffix(execpath, deletedTag) + execpath = strings.TrimPrefix(execpath, deletedTag) + return execpath, nil + case "netbsd": + return os.Readlink("/proc/curproc/exe") + case "openbsd", "dragonfly": + return os.Readlink("/proc/curproc/file") + case "solaris": + return os.Readlink(fmt.Sprintf("/proc/%d/path/a.out", os.Getpid())) + } + return "", errors.New("ExecPath not implemented for " + runtime.GOOS) +} diff --git a/Godeps/_workspace/src/github.com/kardianos/osext/osext_sysctl.go b/Godeps/_workspace/src/github.com/kardianos/osext/osext_sysctl.go new file mode 100644 index 00000000000..b66cac878c4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kardianos/osext/osext_sysctl.go @@ -0,0 +1,79 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd + +package osext + +import ( + "os" + "path/filepath" + "runtime" + "syscall" + "unsafe" +) + +var initCwd, initCwdErr = os.Getwd() + +func executable() (string, error) { + var mib [4]int32 + switch runtime.GOOS { + case "freebsd": + mib = [4]int32{1 /* CTL_KERN */, 14 /* KERN_PROC */, 12 /* KERN_PROC_PATHNAME */, -1} + case "darwin": + mib = [4]int32{1 /* CTL_KERN */, 38 /* KERN_PROCARGS */, int32(os.Getpid()), -1} + } + + n := uintptr(0) + // Get length. + _, _, errNum := syscall.Syscall6(syscall.SYS___SYSCTL, uintptr(unsafe.Pointer(&mib[0])), 4, 0, uintptr(unsafe.Pointer(&n)), 0, 0) + if errNum != 0 { + return "", errNum + } + if n == 0 { // This shouldn't happen. + return "", nil + } + buf := make([]byte, n) + _, _, errNum = syscall.Syscall6(syscall.SYS___SYSCTL, uintptr(unsafe.Pointer(&mib[0])), 4, uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&n)), 0, 0) + if errNum != 0 { + return "", errNum + } + if n == 0 { // This shouldn't happen. + return "", nil + } + for i, v := range buf { + if v == 0 { + buf = buf[:i] + break + } + } + var err error + execPath := string(buf) + // execPath will not be empty due to above checks. + // Try to get the absolute path if the execPath is not rooted. + if execPath[0] != '/' { + execPath, err = getAbs(execPath) + if err != nil { + return execPath, err + } + } + // For darwin KERN_PROCARGS may return the path to a symlink rather than the + // actual executable. + if runtime.GOOS == "darwin" { + if execPath, err = filepath.EvalSymlinks(execPath); err != nil { + return execPath, err + } + } + return execPath, nil +} + +func getAbs(execPath string) (string, error) { + if initCwdErr != nil { + return execPath, initCwdErr + } + // The execPath may begin with a "../" or a "./" so clean it first. + // Join the two paths, trailing and starting slashes undetermined, so use + // the generic Join function. + return filepath.Join(initCwd, filepath.Clean(execPath)), nil +} diff --git a/Godeps/_workspace/src/github.com/kardianos/osext/osext_test.go b/Godeps/_workspace/src/github.com/kardianos/osext/osext_test.go new file mode 100644 index 00000000000..5aafa3af2d2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kardianos/osext/osext_test.go @@ -0,0 +1,180 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin linux freebsd netbsd windows + +package osext + +import ( + "bytes" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "testing" +) + +const ( + executableEnvVar = "OSTEST_OUTPUT_EXECUTABLE" + + executableEnvValueMatch = "match" + executableEnvValueDelete = "delete" +) + +func TestExecutableMatch(t *testing.T) { + ep, err := Executable() + if err != nil { + t.Fatalf("Executable failed: %v", err) + } + + // fullpath to be of the form "dir/prog". + dir := filepath.Dir(filepath.Dir(ep)) + fullpath, err := filepath.Rel(dir, ep) + if err != nil { + t.Fatalf("filepath.Rel: %v", err) + } + // Make child start with a relative program path. + // Alter argv[0] for child to verify getting real path without argv[0]. + cmd := &exec.Cmd{ + Dir: dir, + Path: fullpath, + Env: []string{fmt.Sprintf("%s=%s", executableEnvVar, executableEnvValueMatch)}, + } + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("exec(self) failed: %v", err) + } + outs := string(out) + if !filepath.IsAbs(outs) { + t.Fatalf("Child returned %q, want an absolute path", out) + } + if !sameFile(outs, ep) { + t.Fatalf("Child returned %q, not the same file as %q", out, ep) + } +} + +func TestExecutableDelete(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip() + } + fpath, err := Executable() + if err != nil { + t.Fatalf("Executable failed: %v", err) + } + + r, w := io.Pipe() + stderrBuff := &bytes.Buffer{} + stdoutBuff := &bytes.Buffer{} + cmd := &exec.Cmd{ + Path: fpath, + Env: []string{fmt.Sprintf("%s=%s", executableEnvVar, executableEnvValueDelete)}, + Stdin: r, + Stderr: stderrBuff, + Stdout: stdoutBuff, + } + err = cmd.Start() + if err != nil { + t.Fatalf("exec(self) start failed: %v", err) + } + + tempPath := fpath + "_copy" + _ = os.Remove(tempPath) + + err = copyFile(tempPath, fpath) + if err != nil { + t.Fatalf("copy file failed: %v", err) + } + err = os.Remove(fpath) + if err != nil { + t.Fatalf("remove running test file failed: %v", err) + } + err = os.Rename(tempPath, fpath) + if err != nil { + t.Fatalf("rename copy to previous name failed: %v", err) + } + + w.Write([]byte{0}) + w.Close() + + err = cmd.Wait() + if err != nil { + t.Fatalf("exec wait failed: %v", err) + } + + childPath := stderrBuff.String() + if !filepath.IsAbs(childPath) { + t.Fatalf("Child returned %q, want an absolute path", childPath) + } + if !sameFile(childPath, fpath) { + t.Fatalf("Child returned %q, not the same file as %q", childPath, fpath) + } +} + +func sameFile(fn1, fn2 string) bool { + fi1, err := os.Stat(fn1) + if err != nil { + return false + } + fi2, err := os.Stat(fn2) + if err != nil { + return false + } + return os.SameFile(fi1, fi2) +} +func copyFile(dest, src string) error { + df, err := os.Create(dest) + if err != nil { + return err + } + defer df.Close() + + sf, err := os.Open(src) + if err != nil { + return err + } + defer sf.Close() + + _, err = io.Copy(df, sf) + return err +} + +func TestMain(m *testing.M) { + env := os.Getenv(executableEnvVar) + switch env { + case "": + os.Exit(m.Run()) + case executableEnvValueMatch: + // First chdir to another path. + dir := "/" + if runtime.GOOS == "windows" { + dir = filepath.VolumeName(".") + } + os.Chdir(dir) + if ep, err := Executable(); err != nil { + fmt.Fprint(os.Stderr, "ERROR: ", err) + } else { + fmt.Fprint(os.Stderr, ep) + } + case executableEnvValueDelete: + bb := make([]byte, 1) + var err error + n, err := os.Stdin.Read(bb) + if err != nil { + fmt.Fprint(os.Stderr, "ERROR: ", err) + os.Exit(2) + } + if n != 1 { + fmt.Fprint(os.Stderr, "ERROR: n != 1, n == ", n) + os.Exit(2) + } + if ep, err := Executable(); err != nil { + fmt.Fprint(os.Stderr, "ERROR: ", err) + } else { + fmt.Fprint(os.Stderr, ep) + } + } + os.Exit(0) +} diff --git a/Godeps/_workspace/src/github.com/kardianos/osext/osext_windows.go b/Godeps/_workspace/src/github.com/kardianos/osext/osext_windows.go new file mode 100644 index 00000000000..72d282cf8c0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kardianos/osext/osext_windows.go @@ -0,0 +1,34 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package osext + +import ( + "syscall" + "unicode/utf16" + "unsafe" +) + +var ( + kernel = syscall.MustLoadDLL("kernel32.dll") + getModuleFileNameProc = kernel.MustFindProc("GetModuleFileNameW") +) + +// GetModuleFileName() with hModule = NULL +func executable() (exePath string, err error) { + return getModuleFileName() +} + +func getModuleFileName() (string, error) { + var n uint32 + b := make([]uint16, syscall.MAX_PATH) + size := uint32(len(b)) + + r0, _, e1 := getModuleFileNameProc.Call(0, uintptr(unsafe.Pointer(&b[0])), uintptr(size)) + n = uint32(r0) + if n == 0 { + return "", e1 + } + return string(utf16.Decode(b[0:n])), nil +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/interface.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/interface.go new file mode 100644 index 00000000000..d870fd3c729 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/interface.go @@ -0,0 +1,28 @@ +package callback + +import ( + "fmt" +) + +type Unsupported struct { + Callback Interface +} + +func (uc *Unsupported) Error() string { + return fmt.Sprintf("Unsupported callback <%T>: %v", uc.Callback, uc.Callback) +} + +type Interface interface { + // marker interface +} + +type Handler interface { + // may return an Unsupported error on failure + Handle(callbacks ...Interface) error +} + +type HandlerFunc func(callbacks ...Interface) error + +func (f HandlerFunc) Handle(callbacks ...Interface) error { + return f(callbacks...) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/interprocess.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/interprocess.go new file mode 100644 index 00000000000..d9b389cafc0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/interprocess.go @@ -0,0 +1,27 @@ +package callback + +import ( + "github.com/mesos/mesos-go/upid" +) + +type Interprocess struct { + client upid.UPID + server upid.UPID +} + +func NewInterprocess() *Interprocess { + return &Interprocess{} +} + +func (cb *Interprocess) Client() upid.UPID { + return cb.client +} + +func (cb *Interprocess) Server() upid.UPID { + return cb.server +} + +func (cb *Interprocess) Set(server, client upid.UPID) { + cb.server = server + cb.client = client +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/name.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/name.go new file mode 100644 index 00000000000..246020a9f2a --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/name.go @@ -0,0 +1,17 @@ +package callback + +type Name struct { + name string +} + +func NewName() *Name { + return &Name{} +} + +func (cb *Name) Get() string { + return cb.name +} + +func (cb *Name) Set(name string) { + cb.name = name +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/password.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/password.go new file mode 100644 index 00000000000..6beadd07b42 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/callback/password.go @@ -0,0 +1,20 @@ +package callback + +type Password struct { + password []byte +} + +func NewPassword() *Password { + return &Password{} +} + +func (cb *Password) Get() []byte { + clone := make([]byte, len(cb.password)) + copy(clone, cb.password) + return clone +} + +func (cb *Password) Set(password []byte) { + cb.password = make([]byte, len(password)) + copy(cb.password, password) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/interface.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/interface.go new file mode 100644 index 00000000000..94420f5af26 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/interface.go @@ -0,0 +1,63 @@ +package auth + +import ( + "errors" + "fmt" + "sync" + + log "github.com/golang/glog" + "github.com/mesos/mesos-go/auth/callback" + "golang.org/x/net/context" +) + +// SPI interface: login provider implementations support this interface, clients +// do not authenticate against this directly, instead they should use Login() +type Authenticatee interface { + // Returns no errors if successfully authenticated, otherwise a single + // error. + Authenticate(ctx context.Context, handler callback.Handler) error +} + +// Func adapter for interface: allow func's to implement the Authenticatee interface +// as long as the func signature matches +type AuthenticateeFunc func(ctx context.Context, handler callback.Handler) error + +func (f AuthenticateeFunc) Authenticate(ctx context.Context, handler callback.Handler) error { + return f(ctx, handler) +} + +var ( + // Authentication was attempted and failed (likely due to incorrect credentials, too + // many retries within a time window, etc). Distinctly different from authentication + // errors (e.g. network errors, configuration errors, etc). + AuthenticationFailed = errors.New("authentication failed") + + authenticateeProviders = make(map[string]Authenticatee) // authentication providers dict + providerLock sync.Mutex +) + +// Register an authentication provider (aka "login provider"). packages that +// provide Authenticatee implementations should invoke this func in their +// init() to register. +func RegisterAuthenticateeProvider(name string, auth Authenticatee) (err error) { + providerLock.Lock() + defer providerLock.Unlock() + + if _, found := authenticateeProviders[name]; found { + err = fmt.Errorf("authentication provider already registered: %v", name) + } else { + authenticateeProviders[name] = auth + log.V(1).Infof("registered authentication provider: %v", name) + } + return +} + +// Look up an authentication provider by name, returns non-nil and true if such +// a provider is found. +func getAuthenticateeProvider(name string) (provider Authenticatee, ok bool) { + providerLock.Lock() + defer providerLock.Unlock() + + provider, ok = authenticateeProviders[name] + return +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/login.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/login.go new file mode 100644 index 00000000000..416c2d61274 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/login.go @@ -0,0 +1,80 @@ +package auth + +import ( + "errors" + "fmt" + + "github.com/mesos/mesos-go/auth/callback" + "github.com/mesos/mesos-go/upid" + "golang.org/x/net/context" +) + +var ( + // No login provider name has been specified in a context.Context + NoLoginProviderName = errors.New("missing login provider name in context") +) + +// Main client entrypoint into the authentication APIs: clients are expected to +// invoke this func with a context containing a login provider name value. +// This may be written as: +// providerName := ... // the user has probably configured this via some flag +// handler := ... // handlers provide data like usernames and passwords +// ctx := ... // obtain some initial or timed context +// err := auth.Login(auth.WithLoginProvider(ctx, providerName), handler) +func Login(ctx context.Context, handler callback.Handler) error { + name, ok := LoginProviderFrom(ctx) + if !ok { + return NoLoginProviderName + } + provider, ok := getAuthenticateeProvider(name) + if !ok { + return fmt.Errorf("unrecognized login provider name in context: %s", name) + } + return provider.Authenticate(ctx, handler) +} + +// Unexported key type, avoids conflicts with other context-using packages. All +// context items registered from this package should use keys of this type. +type loginKeyType int + +const ( + loginProviderNameKey loginKeyType = iota // name of login provider to use + parentUpidKey // upid.UPID of some parent process +) + +// Return a context that inherits all values from the parent ctx and specifies +// the login provider name given here. Intended to be invoked before calls to +// Login(). +func WithLoginProvider(ctx context.Context, providerName string) context.Context { + return context.WithValue(ctx, loginProviderNameKey, providerName) +} + +// Return the name of the login provider specified in this context. +func LoginProviderFrom(ctx context.Context) (name string, ok bool) { + name, ok = ctx.Value(loginProviderNameKey).(string) + return +} + +// Return the name of the login provider specified in this context, or empty +// string if none. +func LoginProvider(ctx context.Context) string { + name, _ := LoginProviderFrom(ctx) + return name +} + +func WithParentUPID(ctx context.Context, pid upid.UPID) context.Context { + return context.WithValue(ctx, parentUpidKey, pid) +} + +func ParentUPIDFrom(ctx context.Context) (pid upid.UPID, ok bool) { + pid, ok = ctx.Value(parentUpidKey).(upid.UPID) + return +} + +func ParentUPID(ctx context.Context) (upid *upid.UPID) { + if upid, ok := ParentUPIDFrom(ctx); ok { + return &upid + } else { + return nil + } +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/authenticatee.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/authenticatee.go new file mode 100644 index 00000000000..3d60bdb814a --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/authenticatee.go @@ -0,0 +1,358 @@ +package sasl + +import ( + "errors" + "fmt" + "sync/atomic" + + "github.com/gogo/protobuf/proto" + log "github.com/golang/glog" + "github.com/mesos/mesos-go/auth" + "github.com/mesos/mesos-go/auth/callback" + "github.com/mesos/mesos-go/auth/sasl/mech" + mesos "github.com/mesos/mesos-go/mesosproto" + "github.com/mesos/mesos-go/mesosutil/process" + "github.com/mesos/mesos-go/messenger" + "github.com/mesos/mesos-go/upid" + "golang.org/x/net/context" +) + +var ( + UnexpectedAuthenticationMechanisms = errors.New("Unexpected authentication 'mechanisms' received") + UnexpectedAuthenticationStep = errors.New("Unexpected authentication 'step' received") + UnexpectedAuthenticationCompleted = errors.New("Unexpected authentication 'completed' received") + UnexpectedAuthenticatorPid = errors.New("Unexpected authentator pid") // authenticator pid changed mid-process + UnsupportedMechanism = errors.New("failed to identify a compatible mechanism") +) + +type statusType int32 + +const ( + statusReady statusType = iota + statusStarting + statusStepping + _statusTerminal // meta status, should never be assigned: all status types following are "terminal" + statusCompleted + statusFailed + statusError + statusDiscarded + + // this login provider name is automatically registered with the auth package; see init() + ProviderName = "SASL" +) + +type authenticateeProcess struct { + transport messenger.Messenger + client upid.UPID + status statusType + done chan struct{} + err error + mech mech.Interface + stepFn mech.StepFunc + from *upid.UPID + handler callback.Handler +} + +type authenticateeConfig struct { + client upid.UPID // pid of the client we're attempting to authenticate + handler callback.Handler + transport messenger.Messenger // mesos communications transport +} + +type transportFactory interface { + makeTransport() messenger.Messenger +} + +type transportFactoryFunc func() messenger.Messenger + +func (f transportFactoryFunc) makeTransport() messenger.Messenger { + return f() +} + +func init() { + factory := func(ctx context.Context) transportFactoryFunc { + return transportFactoryFunc(func() messenger.Messenger { + parent := auth.ParentUPID(ctx) + if parent == nil { + log.Fatal("expected to have a parent UPID in context") + } + process := process.New("sasl_authenticatee") + tpid := &upid.UPID{ + ID: process.Label(), + Host: parent.Host, + } + return messenger.NewHttpWithBindingAddress(tpid, BindingAddressFrom(ctx)) + }) + } + delegate := auth.AuthenticateeFunc(func(ctx context.Context, handler callback.Handler) error { + if impl, err := makeAuthenticatee(handler, factory(ctx)); err != nil { + return err + } else { + return impl.Authenticate(ctx, handler) + } + }) + if err := auth.RegisterAuthenticateeProvider(ProviderName, delegate); err != nil { + log.Error(err) + } +} + +func (s *statusType) get() statusType { + return statusType(atomic.LoadInt32((*int32)(s))) +} + +func (s *statusType) swap(old, new statusType) bool { + return old != new && atomic.CompareAndSwapInt32((*int32)(s), int32(old), int32(new)) +} + +// build a new authenticatee implementation using the given callbacks and a new transport instance +func makeAuthenticatee(handler callback.Handler, factory transportFactory) (auth.Authenticatee, error) { + + ip := callback.NewInterprocess() + if err := handler.Handle(ip); err != nil { + return nil, err + } + config := &authenticateeConfig{ + client: ip.Client(), + handler: handler, + transport: factory.makeTransport(), + } + return auth.AuthenticateeFunc(func(ctx context.Context, handler callback.Handler) error { + ctx, auth := newAuthenticatee(ctx, config) + auth.authenticate(ctx, ip.Server()) + + select { + case <-ctx.Done(): + return auth.discard(ctx) + case <-auth.done: + return auth.err + } + }), nil +} + +// Terminate the authentication process upon context cancellation; +// only to be called if/when ctx.Done() has been signalled. +func (self *authenticateeProcess) discard(ctx context.Context) error { + err := ctx.Err() + status := statusFrom(ctx) + for ; status < _statusTerminal; status = (&self.status).get() { + if self.terminate(status, statusDiscarded, err) { + break + } + } + return err +} + +func newAuthenticatee(ctx context.Context, config *authenticateeConfig) (context.Context, *authenticateeProcess) { + initialStatus := statusReady + proc := &authenticateeProcess{ + transport: config.transport, + client: config.client, + handler: config.handler, + status: initialStatus, + done: make(chan struct{}), + } + ctx = withStatus(ctx, initialStatus) + err := proc.installHandlers(ctx) + if err == nil { + err = proc.startTransport() + } + if err != nil { + proc.terminate(initialStatus, statusError, err) + } + return ctx, proc +} + +func (self *authenticateeProcess) startTransport() error { + if err := self.transport.Start(); err != nil { + return err + } else { + go func() { + // stop the authentication transport upon termination of the + // authenticator process + select { + case <-self.done: + log.V(2).Infof("stopping authenticator transport: %v", self.transport.UPID()) + self.transport.Stop() + } + }() + } + return nil +} + +// returns true when handlers are installed without error, otherwise terminates the +// authentication process. +func (self *authenticateeProcess) installHandlers(ctx context.Context) error { + + type handlerFn func(ctx context.Context, from *upid.UPID, pbMsg proto.Message) + + withContext := func(f handlerFn) messenger.MessageHandler { + return func(from *upid.UPID, m proto.Message) { + status := (&self.status).get() + if self.from != nil && !self.from.Equal(from) { + self.terminate(status, statusError, UnexpectedAuthenticatorPid) + } else { + f(withStatus(ctx, status), from, m) + } + } + } + + // Anticipate mechanisms and steps from the server + handlers := []struct { + f handlerFn + m proto.Message + }{ + {self.mechanisms, &mesos.AuthenticationMechanismsMessage{}}, + {self.step, &mesos.AuthenticationStepMessage{}}, + {self.completed, &mesos.AuthenticationCompletedMessage{}}, + {self.failed, &mesos.AuthenticationFailedMessage{}}, + {self.errored, &mesos.AuthenticationErrorMessage{}}, + } + for _, h := range handlers { + if err := self.transport.Install(withContext(h.f), h.m); err != nil { + return err + } + } + return nil +} + +// return true if the authentication status was updated (if true, self.done will have been closed) +func (self *authenticateeProcess) terminate(old, new statusType, err error) bool { + if (&self.status).swap(old, new) { + self.err = err + if self.mech != nil { + self.mech.Discard() + } + close(self.done) + return true + } + return false +} + +func (self *authenticateeProcess) authenticate(ctx context.Context, pid upid.UPID) { + status := statusFrom(ctx) + if status != statusReady { + return + } + message := &mesos.AuthenticateMessage{ + Pid: proto.String(self.client.String()), + } + if err := self.transport.Send(ctx, &pid, message); err != nil { + self.terminate(status, statusError, err) + } else { + (&self.status).swap(status, statusStarting) + } +} + +func (self *authenticateeProcess) mechanisms(ctx context.Context, from *upid.UPID, pbMsg proto.Message) { + status := statusFrom(ctx) + if status != statusStarting { + self.terminate(status, statusError, UnexpectedAuthenticationMechanisms) + return + } + + msg, ok := pbMsg.(*mesos.AuthenticationMechanismsMessage) + if !ok { + self.terminate(status, statusError, fmt.Errorf("Expected AuthenticationMechanismsMessage, not %T", pbMsg)) + return + } + + mechanisms := msg.GetMechanisms() + log.Infof("Received SASL authentication mechanisms: %v", mechanisms) + + selectedMech, factory := mech.SelectSupported(mechanisms) + if selectedMech == "" { + self.terminate(status, statusError, UnsupportedMechanism) + return + } + + if m, f, err := factory(self.handler); err != nil { + self.terminate(status, statusError, err) + return + } else { + self.mech = m + self.stepFn = f + self.from = from + } + + // execute initialization step... + nextf, data, err := self.stepFn(self.mech, nil) + if err != nil { + self.terminate(status, statusError, err) + return + } else { + self.stepFn = nextf + } + + message := &mesos.AuthenticationStartMessage{ + Mechanism: proto.String(selectedMech), + Data: proto.String(string(data)), // may be nil, depends on init step + } + + if err := self.transport.Send(ctx, from, message); err != nil { + self.terminate(status, statusError, err) + } else { + (&self.status).swap(status, statusStepping) + } +} + +func (self *authenticateeProcess) step(ctx context.Context, from *upid.UPID, pbMsg proto.Message) { + status := statusFrom(ctx) + if status != statusStepping { + self.terminate(status, statusError, UnexpectedAuthenticationStep) + return + } + + log.Info("Received SASL authentication step") + + msg, ok := pbMsg.(*mesos.AuthenticationStepMessage) + if !ok { + self.terminate(status, statusError, fmt.Errorf("Expected AuthenticationStepMessage, not %T", pbMsg)) + return + } + + input := msg.GetData() + fn, output, err := self.stepFn(self.mech, input) + + if err != nil { + self.terminate(status, statusError, fmt.Errorf("failed to perform authentication step: %v", err)) + return + } + self.stepFn = fn + + // We don't start the client with SASL_SUCCESS_DATA so we may + // need to send one more "empty" message to the server. + message := &mesos.AuthenticationStepMessage{} + if len(output) > 0 { + message.Data = output + } + if err := self.transport.Send(ctx, from, message); err != nil { + self.terminate(status, statusError, err) + } +} + +func (self *authenticateeProcess) completed(ctx context.Context, from *upid.UPID, pbMsg proto.Message) { + status := statusFrom(ctx) + if status != statusStepping { + self.terminate(status, statusError, UnexpectedAuthenticationCompleted) + return + } + + log.Info("Authentication success") + self.terminate(status, statusCompleted, nil) +} + +func (self *authenticateeProcess) failed(ctx context.Context, from *upid.UPID, pbMsg proto.Message) { + status := statusFrom(ctx) + self.terminate(status, statusFailed, auth.AuthenticationFailed) +} + +func (self *authenticateeProcess) errored(ctx context.Context, from *upid.UPID, pbMsg proto.Message) { + var err error + if msg, ok := pbMsg.(*mesos.AuthenticationErrorMessage); !ok { + err = fmt.Errorf("Expected AuthenticationErrorMessage, not %T", pbMsg) + } else { + err = fmt.Errorf("Authentication error: %s", msg.GetError()) + } + status := statusFrom(ctx) + self.terminate(status, statusError, err) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/authenticatee_test.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/authenticatee_test.go new file mode 100644 index 00000000000..9fd37b6fb96 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/authenticatee_test.go @@ -0,0 +1,98 @@ +package sasl + +import ( + "testing" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/mesos/mesos-go/auth/callback" + "github.com/mesos/mesos-go/auth/sasl/mech/crammd5" + mesos "github.com/mesos/mesos-go/mesosproto" + "github.com/mesos/mesos-go/messenger" + "github.com/mesos/mesos-go/upid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/net/context" +) + +type MockTransport struct { + *messenger.MockedMessenger +} + +func (m *MockTransport) Send(ctx context.Context, upid *upid.UPID, msg proto.Message) error { + return m.Called(mock.Anything, upid, msg).Error(0) +} + +func TestAuthticatee_validLogin(t *testing.T) { + assert := assert.New(t) + ctx := context.TODO() + client := upid.UPID{ + ID: "someFramework", + Host: "b.net", + Port: "789", + } + server := upid.UPID{ + ID: "serv", + Host: "a.com", + Port: "123", + } + tpid := upid.UPID{ + ID: "sasl_transport", + Host: "g.org", + Port: "456", + } + handler := callback.HandlerFunc(func(cb ...callback.Interface) error { + for _, c := range cb { + switch c := c.(type) { + case *callback.Name: + c.Set("foo") + case *callback.Password: + c.Set([]byte("bar")) + case *callback.Interprocess: + c.Set(server, client) + default: + return &callback.Unsupported{Callback: c} + } + } + return nil + }) + var transport *MockTransport + factory := transportFactoryFunc(func() messenger.Messenger { + transport = &MockTransport{messenger.NewMockedMessenger()} + transport.On("Install").Return(nil) + transport.On("UPID").Return(&tpid) + transport.On("Start").Return(nil) + transport.On("Stop").Return(nil) + transport.On("Send", mock.Anything, &server, &mesos.AuthenticateMessage{ + Pid: proto.String(client.String()), + }).Return(nil).Once() + + transport.On("Send", mock.Anything, &server, &mesos.AuthenticationStartMessage{ + Mechanism: proto.String(crammd5.Name), + Data: proto.String(""), // may be nil, depends on init step + }).Return(nil).Once() + + transport.On("Send", mock.Anything, &server, &mesos.AuthenticationStepMessage{ + Data: []byte(`foo cc7fd96cd80123ea844a7dba29a594ed`), + }).Return(nil).Once() + + go func() { + transport.Recv(&server, &mesos.AuthenticationMechanismsMessage{ + Mechanisms: []string{crammd5.Name}, + }) + transport.Recv(&server, &mesos.AuthenticationStepMessage{ + Data: []byte(`lsd;lfkgjs;dlfkgjs;dfklg`), + }) + transport.Recv(&server, &mesos.AuthenticationCompletedMessage{}) + }() + return transport + }) + login, err := makeAuthenticatee(handler, factory) + assert.Nil(err) + + err = login.Authenticate(ctx, handler) + assert.Nil(err) + assert.NotNil(transport) + time.Sleep(1 * time.Second) // wait for the authenticator to shut down + transport.AssertExpectations(t) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/context.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/context.go new file mode 100644 index 00000000000..8058ac34e6c --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/context.go @@ -0,0 +1,43 @@ +package sasl + +import ( + "net" + + "golang.org/x/net/context" +) + +// unexported to prevent collisions with context keys defined in +// other packages. +type _key int + +// If this package defined other context keys, they would have +// different integer values. +const ( + statusKey _key = iota + bindingAddressKey // bind address for login-related network ops +) + +func withStatus(ctx context.Context, s statusType) context.Context { + return context.WithValue(ctx, statusKey, s) +} + +func statusFrom(ctx context.Context) statusType { + s, ok := ctx.Value(statusKey).(statusType) + if !ok { + panic("missing status in context") + } + return s +} + +func WithBindingAddress(ctx context.Context, address net.IP) context.Context { + return context.WithValue(ctx, bindingAddressKey, address) +} + +func BindingAddressFrom(ctx context.Context) net.IP { + obj := ctx.Value(bindingAddressKey) + if addr, ok := obj.(net.IP); ok { + return addr + } else { + return nil + } +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/crammd5/mechanism.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/crammd5/mechanism.go new file mode 100644 index 00000000000..d6b4dafa155 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/crammd5/mechanism.go @@ -0,0 +1,72 @@ +package crammd5 + +import ( + "crypto/hmac" + "crypto/md5" + "encoding/hex" + "errors" + "io" + + log "github.com/golang/glog" + "github.com/mesos/mesos-go/auth/callback" + "github.com/mesos/mesos-go/auth/sasl/mech" +) + +var ( + Name = "CRAM-MD5" // name this mechanism is registered with + + //TODO(jdef) is this a generic SASL error? if so, move it up to mech + challengeDataRequired = errors.New("challenge data may not be empty") +) + +func init() { + mech.Register(Name, newInstance) +} + +type mechanism struct { + handler callback.Handler +} + +func (m *mechanism) Handler() callback.Handler { + return m.handler +} + +func (m *mechanism) Discard() { + // noop +} + +func newInstance(h callback.Handler) (mech.Interface, mech.StepFunc, error) { + m := &mechanism{ + handler: h, + } + fn := func(m mech.Interface, data []byte) (mech.StepFunc, []byte, error) { + // noop: no initialization needed + return challengeResponse, nil, nil + } + return m, fn, nil +} + +// algorithm lifted from wikipedia: http://en.wikipedia.org/wiki/CRAM-MD5 +// except that the SASL mechanism used by Mesos doesn't leverage base64 encoding +func challengeResponse(m mech.Interface, data []byte) (mech.StepFunc, []byte, error) { + if len(data) == 0 { + return mech.IllegalState, nil, challengeDataRequired + } + decoded := string(data) + log.V(4).Infof("challenge(decoded): %s", decoded) // for deep debugging only + + username := callback.NewName() + secret := callback.NewPassword() + + if err := m.Handler().Handle(username, secret); err != nil { + return mech.IllegalState, nil, err + } + hash := hmac.New(md5.New, secret.Get()) + if _, err := io.WriteString(hash, decoded); err != nil { + return mech.IllegalState, nil, err + } + + codes := hex.EncodeToString(hash.Sum(nil)) + msg := username.Get() + " " + codes + return nil, []byte(msg), nil +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/interface.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/interface.go new file mode 100644 index 00000000000..56b53bf56a5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/interface.go @@ -0,0 +1,33 @@ +package mech + +import ( + "errors" + + "github.com/mesos/mesos-go/auth/callback" +) + +var ( + IllegalStateErr = errors.New("illegal mechanism state") +) + +type Interface interface { + Handler() callback.Handler + Discard() // clean up resources or sensitive information; idempotent +} + +// return a mechanism and it's initialization step (may be a noop that returns +// a nil data blob and handle to the first "real" challenge step). +type Factory func(h callback.Handler) (Interface, StepFunc, error) + +// StepFunc implementations should never return a nil StepFunc result. This +// helps keep the logic in the SASL authticatee simpler: step functions are +// never nil. Mechanisms that end up an error state (for example, some decoding +// logic fails...) should return a StepFunc that represents an error state. +// Some mechanisms may be able to recover from such. +type StepFunc func(m Interface, data []byte) (StepFunc, []byte, error) + +// reflects an unrecoverable, illegal mechanism state; always returns IllegalState +// as the next step along with an IllegalStateErr +func IllegalState(m Interface, data []byte) (StepFunc, []byte, error) { + return IllegalState, nil, IllegalStateErr +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/plugins.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/plugins.go new file mode 100644 index 00000000000..3642fccbeed --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/auth/sasl/mech/plugins.go @@ -0,0 +1,49 @@ +package mech + +import ( + "fmt" + "sync" + + log "github.com/golang/glog" +) + +var ( + mechLock sync.Mutex + supportedMechs = make(map[string]Factory) +) + +func Register(name string, f Factory) error { + mechLock.Lock() + defer mechLock.Unlock() + + if _, found := supportedMechs[name]; found { + return fmt.Errorf("Mechanism registered twice: %s", name) + } + supportedMechs[name] = f + log.V(1).Infof("Registered mechanism %s", name) + return nil +} + +func ListSupported() (list []string) { + mechLock.Lock() + defer mechLock.Unlock() + + for mechname := range supportedMechs { + list = append(list, mechname) + } + return list +} + +func SelectSupported(mechanisms []string) (selectedMech string, factory Factory) { + mechLock.Lock() + defer mechLock.Unlock() + + for _, m := range mechanisms { + if f, ok := supportedMechs[m]; ok { + selectedMech = m + factory = f + break + } + } + return +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/doc.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/doc.go new file mode 100644 index 00000000000..0f37d2c2237 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/doc.go @@ -0,0 +1,5 @@ +/* +Package executor includes the interfaces of the mesos executor and +the mesos executor driver, as well as an implementation of the driver. +*/ +package executor diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/exectype.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/exectype.go new file mode 100644 index 00000000000..1c70b4450af --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/exectype.go @@ -0,0 +1,142 @@ +package executor + +import ( + "github.com/mesos/mesos-go/mesosproto" +) + +/** + * Executor callback interface to be implemented by frameworks' executors. Note + * that only one callback will be invoked at a time, so it is not + * recommended that you block within a callback because it may cause a + * deadlock. + * + * Each callback includes an instance to the executor driver that was + * used to run this executor. The driver will not change for the + * duration of an executor (i.e., from the point you do + * ExecutorDriver.Start() to the point that ExecutorDriver.Join() + * returns). This is intended for convenience so that an executor + * doesn't need to store a pointer to the driver itself. + */ +type Executor interface { + /** + * Invoked once the executor driver has been able to successfully + * connect with Mesos. In particular, a scheduler can pass some + * data to its executors through the FrameworkInfo.ExecutorInfo's + * data field. + */ + Registered(ExecutorDriver, *mesosproto.ExecutorInfo, *mesosproto.FrameworkInfo, *mesosproto.SlaveInfo) + + /** + * Invoked when the executor re-registers with a restarted slave. + */ + Reregistered(ExecutorDriver, *mesosproto.SlaveInfo) + + /** + * Invoked when the executor becomes "disconnected" from the slave + * (e.g., the slave is being restarted due to an upgrade). + */ + Disconnected(ExecutorDriver) + + /** + * Invoked when a task has been launched on this executor (initiated + * via SchedulerDriver.LaunchTasks). Note that this task can be realized + * with a goroutine, an external process, or some simple computation, however, + * no other callbacks will be invoked on this executor until this + * callback has returned. + */ + LaunchTask(ExecutorDriver, *mesosproto.TaskInfo) + + /** + * Invoked when a task running within this executor has been killed + * (via SchedulerDriver.KillTask). Note that no status update will + * be sent on behalf of the executor, the executor is responsible + * for creating a new TaskStatus (i.e., with TASK_KILLED) and + * invoking ExecutorDriver.SendStatusUpdate. + */ + KillTask(ExecutorDriver, *mesosproto.TaskID) + + /** + * Invoked when a framework message has arrived for this + * executor. These messages are best effort; do not expect a + * framework message to be retransmitted in any reliable fashion. + */ + FrameworkMessage(ExecutorDriver, string) + + /** + * Invoked when the executor should terminate all of its currently + * running tasks. Note that after Mesos has determined that an + * executor has terminated, any tasks that the executor did not send + * terminal status updates for (e.g., TASK_KILLED, TASK_FINISHED, + * TASK_FAILED, etc) a TASK_LOST status update will be created. + */ + Shutdown(ExecutorDriver) + + /** + * Invoked when a fatal error has occured with the executor and/or + * executor driver. The driver will be aborted BEFORE invoking this + * callback. + */ + Error(ExecutorDriver, string) +} + +/** + * ExecutorDriver interface for connecting an executor to Mesos. This + * interface is used both to manage the executor's lifecycle (start + * it, stop it, or wait for it to finish) and to interact with Mesos + * (e.g., send status updates, send framework messages, etc.). + * A driver method is expected to fail-fast and return an error when possible. + * Other internal errors (or remote error) that occur asynchronously are handled + * using the the Executor.Error() callback. + */ +type ExecutorDriver interface { + /** + * Starts the executor driver. This needs to be called before any + * other driver calls are made. + */ + Start() (mesosproto.Status, error) + + /** + * Stops the executor driver. + */ + Stop() (mesosproto.Status, error) + + /** + * Aborts the driver so that no more callbacks can be made to the + * executor. The semantics of abort and stop have deliberately been + * separated so that code can detect an aborted driver (i.e., via + * the return status of ExecutorDriver.Join, see below), and + * instantiate and start another driver if desired (from within the + * same process ... although this functionality is currently not + * supported for executors). + */ + Abort() (mesosproto.Status, error) + + /** + * Waits for the driver to be stopped or aborted, possibly + * blocking the calling goroutine indefinitely. The return status of + * this function can be used to determine if the driver was aborted + * (see package mesosproto for a description of Status). + */ + Join() (mesosproto.Status, error) + + /** + * Starts and immediately joins (i.e., blocks on) the driver. + */ + Run() (mesosproto.Status, error) + + /** + * Sends a status update to the framework scheduler, retrying as + * necessary until an acknowledgement has been received or the + * executor is terminated (in which case, a TASK_LOST status update + * will be sent). See Scheduler.StatusUpdate for more information + * about status update acknowledgements. + */ + SendStatusUpdate(*mesosproto.TaskStatus) (mesosproto.Status, error) + + /** + * Sends a message to the framework scheduler. These messages are + * best effort; do not expect a framework message to be + * retransmitted in any reliable fashion. + */ + SendFrameworkMessage(string) (mesosproto.Status, error) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor.go new file mode 100644 index 00000000000..05ed98581ee --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor.go @@ -0,0 +1,583 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package executor + +import ( + "fmt" + "net" + "os" + "sync" + "time" + + "code.google.com/p/go-uuid/uuid" + "github.com/gogo/protobuf/proto" + log "github.com/golang/glog" + "github.com/mesos/mesos-go/mesosproto" + "github.com/mesos/mesos-go/mesosutil" + "github.com/mesos/mesos-go/mesosutil/process" + "github.com/mesos/mesos-go/messenger" + "github.com/mesos/mesos-go/upid" + "golang.org/x/net/context" +) + +type DriverConfig struct { + Executor Executor + HostnameOverride string // optional + BindingAddress net.IP // optional + BindingPort uint16 // optional + NewMessenger func() (messenger.Messenger, error) // optional +} + +// MesosExecutorDriver is a implementation of the ExecutorDriver. +type MesosExecutorDriver struct { + lock sync.RWMutex + self *upid.UPID + exec Executor + stopCh chan struct{} + destroyCh chan struct{} + stopped bool + status mesosproto.Status + messenger messenger.Messenger + slaveUPID *upid.UPID + slaveID *mesosproto.SlaveID + frameworkID *mesosproto.FrameworkID + executorID *mesosproto.ExecutorID + workDir string + connected bool + connection uuid.UUID + local bool // TODO(yifan): Not used yet. + directory string // TODO(yifan): Not used yet. + checkpoint bool + recoveryTimeout time.Duration + updates map[string]*mesosproto.StatusUpdate // Key is a UUID string. TODO(yifan): Not used yet. + tasks map[string]*mesosproto.TaskInfo // Key is a UUID string. TODO(yifan): Not used yet. +} + +// NewMesosExecutorDriver creates a new mesos executor driver. +func NewMesosExecutorDriver(config DriverConfig) (*MesosExecutorDriver, error) { + if config.Executor == nil { + msg := "Executor callback interface cannot be nil." + log.Errorln(msg) + return nil, fmt.Errorf(msg) + } + + hostname := mesosutil.GetHostname(config.HostnameOverride) + newMessenger := config.NewMessenger + if newMessenger == nil { + newMessenger = func() (messenger.Messenger, error) { + process := process.New("executor") + return messenger.ForHostname(process, hostname, config.BindingAddress, config.BindingPort) + } + } + + driver := &MesosExecutorDriver{ + exec: config.Executor, + status: mesosproto.Status_DRIVER_NOT_STARTED, + stopCh: make(chan struct{}), + destroyCh: make(chan struct{}), + stopped: true, + updates: make(map[string]*mesosproto.StatusUpdate), + tasks: make(map[string]*mesosproto.TaskInfo), + workDir: ".", + } + var err error + if driver.messenger, err = newMessenger(); err != nil { + return nil, err + } + if err = driver.init(); err != nil { + log.Errorf("failed to initialize the driver: %v", err) + return nil, err + } + return driver, nil +} + +// init initializes the driver. +func (driver *MesosExecutorDriver) init() error { + log.Infof("Init mesos executor driver\n") + log.Infof("Version: %v\n", mesosutil.MesosVersion) + + // Parse environments. + if err := driver.parseEnviroments(); err != nil { + log.Errorf("Failed to parse environments: %v\n", err) + return err + } + + // Install handlers. + driver.messenger.Install(driver.registered, &mesosproto.ExecutorRegisteredMessage{}) + driver.messenger.Install(driver.reregistered, &mesosproto.ExecutorReregisteredMessage{}) + driver.messenger.Install(driver.reconnect, &mesosproto.ReconnectExecutorMessage{}) + driver.messenger.Install(driver.runTask, &mesosproto.RunTaskMessage{}) + driver.messenger.Install(driver.killTask, &mesosproto.KillTaskMessage{}) + driver.messenger.Install(driver.statusUpdateAcknowledgement, &mesosproto.StatusUpdateAcknowledgementMessage{}) + driver.messenger.Install(driver.frameworkMessage, &mesosproto.FrameworkToExecutorMessage{}) + driver.messenger.Install(driver.shutdown, &mesosproto.ShutdownExecutorMessage{}) + driver.messenger.Install(driver.frameworkError, &mesosproto.FrameworkErrorMessage{}) + return nil +} + +func (driver *MesosExecutorDriver) parseEnviroments() error { + var value string + + value = os.Getenv("MESOS_LOCAL") + if len(value) > 0 { + driver.local = true + } + + value = os.Getenv("MESOS_SLAVE_PID") + if len(value) == 0 { + return fmt.Errorf("Cannot find MESOS_SLAVE_PID in the environment") + } + upid, err := upid.Parse(value) + if err != nil { + log.Errorf("Cannot parse UPID %v\n", err) + return err + } + driver.slaveUPID = upid + + value = os.Getenv("MESOS_SLAVE_ID") + driver.slaveID = &mesosproto.SlaveID{Value: proto.String(value)} + + value = os.Getenv("MESOS_FRAMEWORK_ID") + driver.frameworkID = &mesosproto.FrameworkID{Value: proto.String(value)} + + value = os.Getenv("MESOS_EXECUTOR_ID") + driver.executorID = &mesosproto.ExecutorID{Value: proto.String(value)} + + value = os.Getenv("MESOS_DIRECTORY") + if len(value) > 0 { + driver.workDir = value + } + + value = os.Getenv("MESOS_CHECKPOINT") + if value == "1" { + driver.checkpoint = true + } + // TODO(yifan): Parse the duration. For now just use default. + return nil +} + +// ------------------------- Accessors ----------------------- // +func (driver *MesosExecutorDriver) Status() mesosproto.Status { + driver.lock.RLock() + defer driver.lock.RUnlock() + return driver.status +} +func (driver *MesosExecutorDriver) setStatus(stat mesosproto.Status) { + driver.lock.Lock() + driver.status = stat + driver.lock.Unlock() +} + +func (driver *MesosExecutorDriver) Stopped() bool { + return driver.stopped +} + +func (driver *MesosExecutorDriver) setStopped(val bool) { + driver.lock.Lock() + driver.stopped = val + driver.lock.Unlock() +} + +func (driver *MesosExecutorDriver) Connected() bool { + return driver.connected +} + +func (driver *MesosExecutorDriver) setConnected(val bool) { + driver.lock.Lock() + driver.connected = val + driver.lock.Unlock() +} + +// --------------------- Message Handlers --------------------- // + +func (driver *MesosExecutorDriver) registered(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor driver registered") + + msg := pbMsg.(*mesosproto.ExecutorRegisteredMessage) + slaveID := msg.GetSlaveId() + executorInfo := msg.GetExecutorInfo() + frameworkInfo := msg.GetFrameworkInfo() + slaveInfo := msg.GetSlaveInfo() + + if driver.stopped { + log.Infof("Ignoring registered message from slave %v, because the driver is stopped!\n", slaveID) + return + } + + log.Infof("Registered on slave %v\n", slaveID) + driver.setConnected(true) + driver.connection = uuid.NewUUID() + driver.exec.Registered(driver, executorInfo, frameworkInfo, slaveInfo) +} + +func (driver *MesosExecutorDriver) reregistered(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor driver reregistered") + + msg := pbMsg.(*mesosproto.ExecutorReregisteredMessage) + slaveID := msg.GetSlaveId() + slaveInfo := msg.GetSlaveInfo() + + if driver.stopped { + log.Infof("Ignoring re-registered message from slave %v, because the driver is stopped!\n", slaveID) + return + } + + log.Infof("Re-registered on slave %v\n", slaveID) + driver.setConnected(true) + driver.connection = uuid.NewUUID() + driver.exec.Reregistered(driver, slaveInfo) +} + +func (driver *MesosExecutorDriver) send(upid *upid.UPID, msg proto.Message) error { + //TODO(jdef) should implement timeout here + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + c := make(chan error, 1) + go func() { c <- driver.messenger.Send(ctx, upid, msg) }() + + select { + case <-ctx.Done(): + <-c // wait for Send(...) + return ctx.Err() + case err := <-c: + return err + } +} + +func (driver *MesosExecutorDriver) reconnect(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor driver reconnect") + + msg := pbMsg.(*mesosproto.ReconnectExecutorMessage) + slaveID := msg.GetSlaveId() + + if driver.stopped { + log.Infof("Ignoring reconnect message from slave %v, because the driver is stopped!\n", slaveID) + return + } + + log.Infof("Received reconnect request from slave %v\n", slaveID) + driver.slaveUPID = from + + message := &mesosproto.ReregisterExecutorMessage{ + ExecutorId: driver.executorID, + FrameworkId: driver.frameworkID, + } + // Send all unacknowledged updates. + for _, u := range driver.updates { + message.Updates = append(message.Updates, u) + } + // Send all unacknowledged tasks. + for _, t := range driver.tasks { + message.Tasks = append(message.Tasks, t) + } + // Send the message. + if err := driver.send(driver.slaveUPID, message); err != nil { + log.Errorf("Failed to send %v: %v\n", message, err) + } +} + +func (driver *MesosExecutorDriver) runTask(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor driver runTask") + + msg := pbMsg.(*mesosproto.RunTaskMessage) + task := msg.GetTask() + taskID := task.GetTaskId() + + if driver.stopped { + log.Infof("Ignoring run task message for task %v because the driver is stopped!\n", taskID) + return + } + if _, ok := driver.tasks[taskID.String()]; ok { + log.Fatalf("Unexpected duplicate task %v\n", taskID) + } + + log.Infof("Executor asked to run task '%v'\n", taskID) + driver.tasks[taskID.String()] = task + driver.exec.LaunchTask(driver, task) +} + +func (driver *MesosExecutorDriver) killTask(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor driver killTask") + + msg := pbMsg.(*mesosproto.KillTaskMessage) + taskID := msg.GetTaskId() + + if driver.stopped { + log.Infof("Ignoring kill task message for task %v, because the driver is stopped!\n", taskID) + return + } + + log.Infof("Executor driver is asked to kill task '%v'\n", taskID) + driver.exec.KillTask(driver, taskID) +} + +func (driver *MesosExecutorDriver) statusUpdateAcknowledgement(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor statusUpdateAcknowledgement") + + msg := pbMsg.(*mesosproto.StatusUpdateAcknowledgementMessage) + log.Infof("Receiving status update acknowledgement %v", msg) + + frameworkID := msg.GetFrameworkId() + taskID := msg.GetTaskId() + uuid := uuid.UUID(msg.GetUuid()) + + if driver.stopped { + log.Infof("Ignoring status update acknowledgement %v for task %v of framework %v because the driver is stopped!\n", + uuid, taskID, frameworkID) + } + + // Remove the corresponding update. + delete(driver.updates, uuid.String()) + // Remove the corresponding task. + delete(driver.tasks, taskID.String()) +} + +func (driver *MesosExecutorDriver) frameworkMessage(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor driver received frameworkMessage") + + msg := pbMsg.(*mesosproto.FrameworkToExecutorMessage) + data := msg.GetData() + + if driver.stopped { + log.Infof("Ignoring framework message because the driver is stopped!\n") + return + } + + log.Infof("Executor driver receives framework message\n") + driver.exec.FrameworkMessage(driver, string(data)) +} + +func (driver *MesosExecutorDriver) shutdown(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor driver received shutdown") + + _, ok := pbMsg.(*mesosproto.ShutdownExecutorMessage) + if !ok { + panic("Not a ShutdownExecutorMessage! This should not happen") + } + + if driver.stopped { + log.Infof("Ignoring shutdown message because the driver is stopped!\n") + return + } + + log.Infof("Executor driver is asked to shutdown\n") + + driver.exec.Shutdown(driver) + // driver.Stop() will cause process to eventually stop. + driver.Stop() +} + +func (driver *MesosExecutorDriver) frameworkError(from *upid.UPID, pbMsg proto.Message) { + log.Infoln("Executor driver received error") + + msg := pbMsg.(*mesosproto.FrameworkErrorMessage) + driver.exec.Error(driver, msg.GetMessage()) +} + +// ------------------------ Driver Implementation ----------------- // + +// Start starts the executor driver +func (driver *MesosExecutorDriver) Start() (mesosproto.Status, error) { + log.Infoln("Starting the executor driver") + + if stat := driver.Status(); stat != mesosproto.Status_DRIVER_NOT_STARTED { + return stat, fmt.Errorf("Unable to Start, expecting status %s, but got %s", mesosproto.Status_DRIVER_NOT_STARTED, stat) + } + + driver.setStatus(mesosproto.Status_DRIVER_NOT_STARTED) + driver.setStopped(true) + + // Start the messenger. + if err := driver.messenger.Start(); err != nil { + log.Errorf("Failed to start executor: %v\n", err) + return driver.Status(), err + } + + driver.self = driver.messenger.UPID() + + // Register with slave. + log.V(3).Infoln("Sending Executor registration") + message := &mesosproto.RegisterExecutorMessage{ + FrameworkId: driver.frameworkID, + ExecutorId: driver.executorID, + } + + if err := driver.send(driver.slaveUPID, message); err != nil { + stat := driver.Status() + log.Errorf("Stopping the executor, failed to send %v: %v\n", message, err) + err0 := driver.stop(stat) + if err0 != nil { + log.Errorf("Failed to stop executor: %v\n", err) + return stat, err0 + } + return stat, err + } + driver.setStopped(false) + driver.setStatus(mesosproto.Status_DRIVER_RUNNING) + + log.Infoln("Mesos executor is started with PID=", driver.self.String()) + + return driver.Status(), nil +} + +// Stop stops the driver by sending a 'stopEvent' to the event loop, and +// receives the result from the response channel. +func (driver *MesosExecutorDriver) Stop() (mesosproto.Status, error) { + log.Infoln("Stopping the executor driver") + if stat := driver.Status(); stat != mesosproto.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to Stop, expecting status %s, but got %s", mesosproto.Status_DRIVER_RUNNING, stat) + } + stopStat := mesosproto.Status_DRIVER_STOPPED + return stopStat, driver.stop(stopStat) +} + +// internal function for stopping the driver and set reason for stopping +// Note that messages inflight or queued will not be processed. +func (driver *MesosExecutorDriver) stop(stopStatus mesosproto.Status) error { + err := driver.messenger.Stop() + defer close(driver.destroyCh) + defer close(driver.stopCh) + + driver.setStatus(stopStatus) + driver.setStopped(true) + + if err != nil { + return err + } + + return nil +} + +// Abort aborts the driver by sending an 'abortEvent' to the event loop, and +// receives the result from the response channel. +func (driver *MesosExecutorDriver) Abort() (mesosproto.Status, error) { + if stat := driver.Status(); stat != mesosproto.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to Stop, expecting status %s, but got %s", mesosproto.Status_DRIVER_RUNNING, stat) + } + + log.Infoln("Aborting the executor driver") + abortStat := mesosproto.Status_DRIVER_ABORTED + return abortStat, driver.stop(abortStat) +} + +// Join waits for the driver by sending a 'joinEvent' to the event loop, and wait +// on a channel for the notification of driver termination. +func (driver *MesosExecutorDriver) Join() (mesosproto.Status, error) { + log.Infoln("Waiting for the executor driver to stop") + if stat := driver.Status(); stat != mesosproto.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to Join, expecting status %s, but got %s", mesosproto.Status_DRIVER_RUNNING, stat) + } + <-driver.stopCh // wait for stop signal + return driver.Status(), nil +} + +// Run starts the driver and calls Join() to wait for stop request. +func (driver *MesosExecutorDriver) Run() (mesosproto.Status, error) { + stat, err := driver.Start() + + if err != nil { + return driver.Stop() + } + + if stat != mesosproto.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to continue to Run, expecting status %s, but got %s", mesosproto.Status_DRIVER_RUNNING, driver.status) + } + + return driver.Join() +} + +// SendStatusUpdate sends status updates to the slave. +func (driver *MesosExecutorDriver) SendStatusUpdate(taskStatus *mesosproto.TaskStatus) (mesosproto.Status, error) { + log.V(3).Infoln("Sending task status update: ", taskStatus.String()) + + if stat := driver.Status(); stat != mesosproto.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to SendStatusUpdate, expecting driver.status %s, but got %s", mesosproto.Status_DRIVER_RUNNING, stat) + } + + if taskStatus.GetState() == mesosproto.TaskState_TASK_STAGING { + err := fmt.Errorf("Executor is not allowed to send TASK_STAGING status update. Aborting!") + log.Errorln(err) + if err0 := driver.stop(mesosproto.Status_DRIVER_ABORTED); err0 != nil { + log.Errorln("Error while stopping the driver", err0) + } + + return driver.Status(), err + } + + // Set up status update. + update := driver.makeStatusUpdate(taskStatus) + log.Infof("Executor sending status update %v\n", update.String()) + + // Capture the status update. + driver.updates[uuid.UUID(update.GetUuid()).String()] = update + + // Put the status update in the message. + message := &mesosproto.StatusUpdateMessage{ + Update: update, + Pid: proto.String(driver.self.String()), + } + // Send the message. + if err := driver.send(driver.slaveUPID, message); err != nil { + log.Errorf("Failed to send %v: %v\n", message, err) + return driver.status, err + } + + return driver.Status(), nil +} + +func (driver *MesosExecutorDriver) makeStatusUpdate(taskStatus *mesosproto.TaskStatus) *mesosproto.StatusUpdate { + now := float64(time.Now().Unix()) + // Fill in all the fields. + taskStatus.Timestamp = proto.Float64(now) + taskStatus.SlaveId = driver.slaveID + update := &mesosproto.StatusUpdate{ + FrameworkId: driver.frameworkID, + ExecutorId: driver.executorID, + SlaveId: driver.slaveID, + Status: taskStatus, + Timestamp: proto.Float64(now), + Uuid: uuid.NewUUID(), + } + return update +} + +// SendFrameworkMessage sends the framework message by sending a 'sendFrameworkMessageEvent' +// to the event loop, and receives the result from the response channel. +func (driver *MesosExecutorDriver) SendFrameworkMessage(data string) (mesosproto.Status, error) { + log.V(3).Infoln("Sending framework message", string(data)) + + if stat := driver.Status(); stat != mesosproto.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to SendFrameworkMessage, expecting status %s, but got %s", mesosproto.Status_DRIVER_RUNNING, stat) + } + + message := &mesosproto.ExecutorToFrameworkMessage{ + SlaveId: driver.slaveID, + FrameworkId: driver.frameworkID, + ExecutorId: driver.executorID, + Data: []byte(data), + } + + // Send the message. + if err := driver.send(driver.slaveUPID, message); err != nil { + log.Errorln("Failed to send message %v: %v", message, err) + return driver.status, err + } + return driver.status, nil +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor_intgr_test.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor_intgr_test.go new file mode 100644 index 00000000000..38b72731872 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor_intgr_test.go @@ -0,0 +1,531 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package executor + +import ( + "io/ioutil" + "net/http" + "net/url" + "os" + "strings" + "sync" + "testing" + "time" + + "code.google.com/p/go-uuid/uuid" + "github.com/gogo/protobuf/proto" + log "github.com/golang/glog" + mesos "github.com/mesos/mesos-go/mesosproto" + util "github.com/mesos/mesos-go/mesosutil" + "github.com/mesos/mesos-go/testutil" + "github.com/stretchr/testify/assert" +) + +// testScuduler is used for testing Schduler callbacks. +type testExecutor struct { + ch chan bool + wg *sync.WaitGroup + t *testing.T +} + +func newTestExecutor(t *testing.T) *testExecutor { + return &testExecutor{ch: make(chan bool), t: t} +} + +func (exec *testExecutor) Registered(driver ExecutorDriver, execinfo *mesos.ExecutorInfo, fwinfo *mesos.FrameworkInfo, slaveinfo *mesos.SlaveInfo) { + log.Infoln("Exec.Registered() called.") + assert.NotNil(exec.t, execinfo) + assert.NotNil(exec.t, fwinfo) + assert.NotNil(exec.t, slaveinfo) + exec.ch <- true +} + +func (exec *testExecutor) Reregistered(driver ExecutorDriver, slaveinfo *mesos.SlaveInfo) { + log.Infoln("Exec.Re-registered() called.") + assert.NotNil(exec.t, slaveinfo) + exec.ch <- true +} + +func (e *testExecutor) Disconnected(ExecutorDriver) {} + +func (exec *testExecutor) LaunchTask(driver ExecutorDriver, taskinfo *mesos.TaskInfo) { + log.Infoln("Exec.LaunchTask() called.") + assert.NotNil(exec.t, taskinfo) + assert.True(exec.t, util.NewTaskID("test-task-001").Equal(taskinfo.TaskId)) + exec.ch <- true +} + +func (exec *testExecutor) KillTask(driver ExecutorDriver, taskid *mesos.TaskID) { + log.Infoln("Exec.KillTask() called.") + assert.NotNil(exec.t, taskid) + assert.True(exec.t, util.NewTaskID("test-task-001").Equal(taskid)) + exec.ch <- true +} + +func (exec *testExecutor) FrameworkMessage(driver ExecutorDriver, message string) { + log.Infoln("Exec.FrameworkMessage() called.") + assert.NotNil(exec.t, message) + assert.Equal(exec.t, "Hello-Test", message) + exec.ch <- true +} + +func (exec *testExecutor) Shutdown(ExecutorDriver) { + log.Infoln("Exec.Shutdown() called.") + exec.ch <- true +} + +func (exec *testExecutor) Error(driver ExecutorDriver, err string) { + log.Infoln("Exec.Error() called.") + log.Infoln("Got error ", err) + driver.Stop() + exec.ch <- true +} + +// ------------------------ Test Functions -------------------- // + +func setTestEnv(t *testing.T) { + assert.NoError(t, os.Setenv("MESOS_FRAMEWORK_ID", frameworkID)) + assert.NoError(t, os.Setenv("MESOS_EXECUTOR_ID", executorID)) +} + +func newIntegrationTestDriver(t *testing.T, exec Executor) *MesosExecutorDriver { + dconfig := DriverConfig{ + Executor: exec, + } + driver, err := NewMesosExecutorDriver(dconfig) + if err != nil { + t.Fatal(err) + } + return driver +} + +func TestExecutorDriverRegisterExecutorMessage(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + + data, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatalf("Missing RegisteredExecutor data from scheduler.") + } + defer req.Body.Close() + + message := new(mesos.RegisterExecutorMessage) + err = proto.Unmarshal(data, message) + assert.NoError(t, err) + assert.Equal(t, frameworkID, message.GetFrameworkId().GetValue()) + assert.Equal(t, executorID, message.GetExecutorId().GetValue()) + + ch <- true + + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.ch = ch + + driver := newIntegrationTestDriver(t, exec) + assert.True(t, driver.stopped) + + stat, err := driver.Start() + assert.NoError(t, err) + assert.False(t, driver.stopped) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + + select { + case <-ch: + case <-time.After(time.Millisecond * 2): + log.Errorf("Tired of waiting...") + } +} + +func TestExecutorDriverExecutorRegisteredEvent(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.ch = ch + exec.t = t + + // start + driver := newIntegrationTestDriver(t, exec) + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + + //simulate sending ExecutorRegisteredMessage from server to exec pid. + pbMsg := &mesos.ExecutorRegisteredMessage{ + ExecutorInfo: util.NewExecutorInfo(util.NewExecutorID(executorID), nil), + FrameworkId: util.NewFrameworkID(frameworkID), + FrameworkInfo: util.NewFrameworkInfo("test", "test-framework", util.NewFrameworkID(frameworkID)), + SlaveId: util.NewSlaveID(slaveID), + SlaveInfo: &mesos.SlaveInfo{Hostname: proto.String("localhost")}, + } + c := testutil.NewMockMesosClient(t, server.PID) + c.SendMessage(driver.self, pbMsg) + assert.True(t, driver.connected) + select { + case <-ch: + case <-time.After(time.Millisecond * 2): + log.Errorf("Tired of waiting...") + } +} + +func TestExecutorDriverExecutorReregisteredEvent(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.ch = ch + exec.t = t + + // start + driver := newIntegrationTestDriver(t, exec) + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + + //simulate sending ExecutorRegisteredMessage from server to exec pid. + pbMsg := &mesos.ExecutorReregisteredMessage{ + SlaveId: util.NewSlaveID(slaveID), + SlaveInfo: &mesos.SlaveInfo{Hostname: proto.String("localhost")}, + } + c := testutil.NewMockMesosClient(t, server.PID) + c.SendMessage(driver.self, pbMsg) + assert.True(t, driver.connected) + select { + case <-ch: + case <-time.After(time.Millisecond * 2): + log.Errorf("Tired of waiting...") + } +} + +func TestExecutorDriverReconnectEvent(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + + // exec registration request + if strings.Contains(reqPath, "RegisterExecutorMessage") { + log.Infoln("Got Executor registration request") + } + + if strings.Contains(reqPath, "ReregisterExecutorMessage") { + log.Infoln("Got Executor Re-registration request") + ch <- true + } + + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.t = t + + // start + driver := newIntegrationTestDriver(t, exec) + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + driver.connected = true + + // send "reconnect" event to driver + pbMsg := &mesos.ReconnectExecutorMessage{ + SlaveId: util.NewSlaveID(slaveID), + } + c := testutil.NewMockMesosClient(t, server.PID) + c.SendMessage(driver.self, pbMsg) + + select { + case <-ch: + case <-time.After(time.Millisecond * 2): + log.Errorf("Tired of waiting...") + } + +} + +func TestExecutorDriverRunTaskEvent(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.ch = ch + exec.t = t + + // start + driver := newIntegrationTestDriver(t, exec) + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + driver.connected = true + + // send runtask event to driver + pbMsg := &mesos.RunTaskMessage{ + FrameworkId: util.NewFrameworkID(frameworkID), + Framework: util.NewFrameworkInfo( + "test", "test-framework-001", util.NewFrameworkID(frameworkID), + ), + Pid: proto.String(server.PID.String()), + Task: util.NewTaskInfo( + "test-task", + util.NewTaskID("test-task-001"), + util.NewSlaveID(slaveID), + []*mesos.Resource{ + util.NewScalarResource("mem", 112), + util.NewScalarResource("cpus", 2), + }, + ), + } + + c := testutil.NewMockMesosClient(t, server.PID) + c.SendMessage(driver.self, pbMsg) + + select { + case <-ch: + case <-time.After(time.Millisecond * 2): + log.Errorf("Tired of waiting...") + } + +} + +func TestExecutorDriverKillTaskEvent(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.ch = ch + exec.t = t + + // start + driver := newIntegrationTestDriver(t, exec) + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + driver.connected = true + + // send runtask event to driver + pbMsg := &mesos.KillTaskMessage{ + FrameworkId: util.NewFrameworkID(frameworkID), + TaskId: util.NewTaskID("test-task-001"), + } + + c := testutil.NewMockMesosClient(t, server.PID) + c.SendMessage(driver.self, pbMsg) + + select { + case <-ch: + case <-time.After(time.Millisecond * 2): + log.Errorf("Tired of waiting...") + } +} + +func TestExecutorDriverStatusUpdateAcknowledgement(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.ch = ch + exec.t = t + + // start + driver := newIntegrationTestDriver(t, exec) + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + driver.connected = true + + // send ACK from server + pbMsg := &mesos.StatusUpdateAcknowledgementMessage{ + SlaveId: util.NewSlaveID(slaveID), + FrameworkId: util.NewFrameworkID(frameworkID), + TaskId: util.NewTaskID("test-task-001"), + Uuid: []byte(uuid.NewRandom().String()), + } + + c := testutil.NewMockMesosClient(t, server.PID) + c.SendMessage(driver.self, pbMsg) + <-time.After(time.Millisecond * 2) +} + +func TestExecutorDriverFrameworkToExecutorMessageEvent(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.ch = ch + exec.t = t + + // start + driver := newIntegrationTestDriver(t, exec) + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + driver.connected = true + + // send runtask event to driver + pbMsg := &mesos.FrameworkToExecutorMessage{ + SlaveId: util.NewSlaveID(slaveID), + ExecutorId: util.NewExecutorID(executorID), + FrameworkId: util.NewFrameworkID(frameworkID), + Data: []byte("Hello-Test"), + } + + c := testutil.NewMockMesosClient(t, server.PID) + c.SendMessage(driver.self, pbMsg) + + select { + case <-ch: + case <-time.After(time.Millisecond * 2): + log.Errorf("Tired of waiting...") + } +} + +func TestExecutorDriverShutdownEvent(t *testing.T) { + setTestEnv(t) + ch := make(chan bool) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + rsp.WriteHeader(http.StatusAccepted) + }) + + defer server.Close() + + exec := newTestExecutor(t) + exec.ch = ch + exec.t = t + + // start + driver := newIntegrationTestDriver(t, exec) + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_RUNNING, stat) + driver.connected = true + + // send runtask event to driver + pbMsg := &mesos.ShutdownExecutorMessage{} + + c := testutil.NewMockMesosClient(t, server.PID) + c.SendMessage(driver.self, pbMsg) + + select { + case <-ch: + case <-time.After(time.Millisecond * 5): + log.Errorf("Tired of waiting...") + } + + <-time.After(time.Millisecond * 5) // wait for shutdown to finish. + assert.Equal(t, mesos.Status_DRIVER_STOPPED, driver.Status()) +} + +func TestExecutorDriverError(t *testing.T) { + setTestEnv(t) + // Mock Slave process to respond to registration event. + server := testutil.NewMockSlaveHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + reqPath, err := url.QueryUnescape(req.URL.String()) + assert.NoError(t, err) + log.Infoln("RCVD request", reqPath) + rsp.WriteHeader(http.StatusAccepted) + }) + + ch := make(chan bool) + exec := newTestExecutor(t) + exec.ch = ch + exec.t = t + + driver := newIntegrationTestDriver(t, exec) + server.Close() // will cause error + // Run() cause async message processing to start + // Therefore, error-handling will be done via Executor.Error callaback. + stat, err := driver.Run() + assert.NoError(t, err) + assert.Equal(t, mesos.Status_DRIVER_STOPPED, stat) + + select { + case <-ch: + case <-time.After(time.Millisecond * 5): + log.Errorf("Tired of waiting...") + } +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor_test.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor_test.go new file mode 100644 index 00000000000..a2894b2c299 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/executor_test.go @@ -0,0 +1,396 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package executor + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/mesos/mesos-go/healthchecker" + "github.com/mesos/mesos-go/mesosproto" + util "github.com/mesos/mesos-go/mesosutil" + "github.com/mesos/mesos-go/messenger" + "github.com/mesos/mesos-go/upid" + "github.com/stretchr/testify/assert" +) + +var ( + slavePID = "slave(1)@127.0.0.1:8080" + slaveID = "some-slave-id-uuid" + frameworkID = "some-framework-id-uuid" + executorID = "some-executor-id-uuid" +) + +func setEnvironments(t *testing.T, workDir string, checkpoint bool) { + assert.NoError(t, os.Setenv("MESOS_SLAVE_PID", slavePID)) + assert.NoError(t, os.Setenv("MESOS_SLAVE_ID", slaveID)) + assert.NoError(t, os.Setenv("MESOS_FRAMEWORK_ID", frameworkID)) + assert.NoError(t, os.Setenv("MESOS_EXECUTOR_ID", executorID)) + if len(workDir) > 0 { + assert.NoError(t, os.Setenv("MESOS_DIRECTORY", workDir)) + } + if checkpoint { + assert.NoError(t, os.Setenv("MESOS_CHECKPOINT", "1")) + } +} + +func clearEnvironments(t *testing.T) { + assert.NoError(t, os.Setenv("MESOS_SLAVE_PID", "")) + assert.NoError(t, os.Setenv("MESOS_SLAVE_ID", "")) + assert.NoError(t, os.Setenv("MESOS_FRAMEWORK_ID", "")) + assert.NoError(t, os.Setenv("MESOS_EXECUTOR_ID", "")) +} + +func newTestExecutorDriver(t *testing.T, exec Executor) *MesosExecutorDriver { + dconfig := DriverConfig{ + Executor: exec, + } + driver, err := NewMesosExecutorDriver(dconfig) + if err != nil { + t.Fatal(err) + } + return driver +} + +func createTestExecutorDriver(t *testing.T) ( + *MesosExecutorDriver, + *messenger.MockedMessenger, + *healthchecker.MockedHealthChecker) { + + exec := NewMockedExecutor() + + setEnvironments(t, "", false) + driver := newTestExecutorDriver(t, exec) + + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + + checker := healthchecker.NewMockedHealthChecker() + checker.On("Start").Return() + checker.On("Stop").Return() + + driver.messenger = messenger + return driver, messenger, checker +} + +func TestExecutorDriverStartFailedToParseEnvironment(t *testing.T) { + clearEnvironments(t) + exec := NewMockedExecutor() + exec.On("Error").Return(nil) + driver := newTestExecutorDriver(t, exec) + assert.Nil(t, driver) +} + +func TestExecutorDriverStartFailedToStartMessenger(t *testing.T) { + exec := NewMockedExecutor() + + setEnvironments(t, "", false) + driver := newTestExecutorDriver(t, exec) + assert.NotNil(t, driver) + messenger := messenger.NewMockedMessenger() + driver.messenger = messenger + + // Set expections and return values. + messenger.On("Start").Return(fmt.Errorf("messenger failed to start")) + messenger.On("Stop").Return(nil) + + status, err := driver.Start() + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_NOT_STARTED, status) + + messenger.Stop() + + messenger.AssertNumberOfCalls(t, "Start", 1) + messenger.AssertNumberOfCalls(t, "Stop", 1) +} + +func TestExecutorDriverStartFailedToSendRegisterMessage(t *testing.T) { + exec := NewMockedExecutor() + + setEnvironments(t, "", false) + driver := newTestExecutorDriver(t, exec) + messenger := messenger.NewMockedMessenger() + driver.messenger = messenger + + // Set expections and return values. + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(fmt.Errorf("messenger failed to send")) + messenger.On("Stop").Return(nil) + + status, err := driver.Start() + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_NOT_STARTED, status) + + messenger.AssertNumberOfCalls(t, "Start", 1) + messenger.AssertNumberOfCalls(t, "UPID", 1) + messenger.AssertNumberOfCalls(t, "Send", 1) + messenger.AssertNumberOfCalls(t, "Stop", 1) +} + +func TestExecutorDriverStartSucceed(t *testing.T) { + setEnvironments(t, "", false) + + exec := NewMockedExecutor() + exec.On("Error").Return(nil) + + driver := newTestExecutorDriver(t, exec) + + messenger := messenger.NewMockedMessenger() + driver.messenger = messenger + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + + checker := healthchecker.NewMockedHealthChecker() + checker.On("Start").Return() + checker.On("Stop").Return() + + assert.True(t, driver.stopped) + status, err := driver.Start() + assert.False(t, driver.stopped) + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, status) + + messenger.AssertNumberOfCalls(t, "Start", 1) + messenger.AssertNumberOfCalls(t, "UPID", 1) + messenger.AssertNumberOfCalls(t, "Send", 1) +} + +func TestExecutorDriverRun(t *testing.T) { + setEnvironments(t, "", false) + + // Set expections and return values. + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + + exec := NewMockedExecutor() + exec.On("Error").Return(nil) + + driver := newTestExecutorDriver(t, exec) + driver.messenger = messenger + assert.True(t, driver.stopped) + + checker := healthchecker.NewMockedHealthChecker() + checker.On("Start").Return() + checker.On("Stop").Return() + + go func() { + stat, err := driver.Run() + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_STOPPED, stat) + }() + time.Sleep(time.Millisecond * 1) // allow for things to settle + assert.False(t, driver.stopped) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, driver.Status()) + + // mannually close it all + driver.setStatus(mesosproto.Status_DRIVER_STOPPED) + close(driver.stopCh) + time.Sleep(time.Millisecond * 1) +} + +func TestExecutorDriverJoin(t *testing.T) { + setEnvironments(t, "", false) + + // Set expections and return values. + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + + exec := NewMockedExecutor() + exec.On("Error").Return(nil) + + driver := newTestExecutorDriver(t, exec) + driver.messenger = messenger + assert.True(t, driver.stopped) + + checker := healthchecker.NewMockedHealthChecker() + checker.On("Start").Return() + checker.On("Stop").Return() + + stat, err := driver.Start() + assert.NoError(t, err) + assert.False(t, driver.stopped) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, stat) + + testCh := make(chan mesosproto.Status) + go func() { + stat, _ := driver.Join() + testCh <- stat + }() + + close(driver.stopCh) // manually stopping + stat = <-testCh // when Stop() is called, stat will be DRIVER_STOPPED. + +} + +func TestExecutorDriverAbort(t *testing.T) { + statusChan := make(chan mesosproto.Status) + driver, messenger, _ := createTestExecutorDriver(t) + + assert.True(t, driver.stopped) + stat, err := driver.Start() + assert.False(t, driver.stopped) + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, stat) + go func() { + st, _ := driver.Join() + statusChan <- st + }() + + stat, err = driver.Abort() + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_ABORTED, stat) + assert.Equal(t, mesosproto.Status_DRIVER_ABORTED, <-statusChan) + assert.True(t, driver.stopped) + + // Abort for the second time, should return directly. + stat, err = driver.Abort() + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_ABORTED, stat) + stat, err = driver.Stop() + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_ABORTED, stat) + assert.True(t, driver.stopped) + + // Restart should not start. + stat, err = driver.Start() + assert.True(t, driver.stopped) + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_ABORTED, stat) + + messenger.AssertNumberOfCalls(t, "Start", 1) + messenger.AssertNumberOfCalls(t, "UPID", 1) + messenger.AssertNumberOfCalls(t, "Send", 1) + messenger.AssertNumberOfCalls(t, "Stop", 1) +} + +func TestExecutorDriverStop(t *testing.T) { + statusChan := make(chan mesosproto.Status) + driver, messenger, _ := createTestExecutorDriver(t) + + assert.True(t, driver.stopped) + stat, err := driver.Start() + assert.False(t, driver.stopped) + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, stat) + go func() { + stat, _ := driver.Join() + statusChan <- stat + }() + stat, err = driver.Stop() + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_STOPPED, stat) + assert.Equal(t, mesosproto.Status_DRIVER_STOPPED, <-statusChan) + assert.True(t, driver.stopped) + + // Stop for the second time, should return directly. + stat, err = driver.Stop() + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_STOPPED, stat) + stat, err = driver.Abort() + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_STOPPED, stat) + assert.True(t, driver.stopped) + + // Restart should not start. + stat, err = driver.Start() + assert.True(t, driver.stopped) + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_STOPPED, stat) + + messenger.AssertNumberOfCalls(t, "Start", 1) + messenger.AssertNumberOfCalls(t, "UPID", 1) + messenger.AssertNumberOfCalls(t, "Send", 1) + messenger.AssertNumberOfCalls(t, "Stop", 1) +} + +func TestExecutorDriverSendStatusUpdate(t *testing.T) { + + driver, _, _ := createTestExecutorDriver(t) + + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, stat) + driver.connected = true + driver.stopped = false + + taskStatus := util.NewTaskStatus( + util.NewTaskID("test-task-001"), + mesosproto.TaskState_TASK_RUNNING, + ) + + stat, err = driver.SendStatusUpdate(taskStatus) + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, stat) +} + +func TestExecutorDriverSendStatusUpdateStaging(t *testing.T) { + + driver, _, _ := createTestExecutorDriver(t) + + exec := NewMockedExecutor() + exec.On("Error").Return(nil) + driver.exec = exec + + stat, err := driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, stat) + driver.connected = true + driver.stopped = false + + taskStatus := util.NewTaskStatus( + util.NewTaskID("test-task-001"), + mesosproto.TaskState_TASK_STAGING, + ) + + stat, err = driver.SendStatusUpdate(taskStatus) + assert.Error(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_ABORTED, stat) +} + +func TestExecutorDriverSendFrameworkMessage(t *testing.T) { + + driver, _, _ := createTestExecutorDriver(t) + + stat, err := driver.SendFrameworkMessage("failed") + assert.Error(t, err) + + stat, err = driver.Start() + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, stat) + driver.connected = true + driver.stopped = false + + stat, err = driver.SendFrameworkMessage("Testing Mesos") + assert.NoError(t, err) + assert.Equal(t, mesosproto.Status_DRIVER_RUNNING, stat) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/mocked_executor.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/mocked_executor.go new file mode 100644 index 00000000000..2b4853f3d55 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/executor/mocked_executor.go @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package executor + +import ( + "github.com/mesos/mesos-go/mesosproto" + "github.com/stretchr/testify/mock" +) + +// MockedExecutor is used for testing the executor driver. +type MockedExecutor struct { + mock.Mock +} + +// NewMockedExecutor returns a mocked executor. +func NewMockedExecutor() *MockedExecutor { + return &MockedExecutor{} +} + +// Registered implements the Registered handler. +func (e *MockedExecutor) Registered(ExecutorDriver, *mesosproto.ExecutorInfo, *mesosproto.FrameworkInfo, *mesosproto.SlaveInfo) { + e.Called() +} + +// Reregistered implements the Reregistered handler. +func (e *MockedExecutor) Reregistered(ExecutorDriver, *mesosproto.SlaveInfo) { + e.Called() +} + +// Disconnected implements the Disconnected handler. +func (e *MockedExecutor) Disconnected(ExecutorDriver) { + e.Called() +} + +// LaunchTask implements the LaunchTask handler. +func (e *MockedExecutor) LaunchTask(ExecutorDriver, *mesosproto.TaskInfo) { + e.Called() +} + +// KillTask implements the KillTask handler. +func (e *MockedExecutor) KillTask(ExecutorDriver, *mesosproto.TaskID) { + e.Called() +} + +// FrameworkMessage implements the FrameworkMessage handler. +func (e *MockedExecutor) FrameworkMessage(ExecutorDriver, string) { + e.Called() +} + +// Shutdown implements the Shutdown handler. +func (e *MockedExecutor) Shutdown(ExecutorDriver) { + e.Called() +} + +// Error implements the Error handler. +func (e *MockedExecutor) Error(ExecutorDriver, string) { + e.Called() +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/README.md b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/README.md new file mode 100644 index 00000000000..da0673e78a0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/README.md @@ -0,0 +1,39 @@ +####Benchmark of the messenger. + +```shell +$ go test -v -run=Benckmark* -bench=. +PASS +BenchmarkMessengerSendSmallMessage 50000 70568 ns/op +BenchmarkMessengerSendMediumMessage 50000 70265 ns/op +BenchmarkMessengerSendBigMessage 50000 72693 ns/op +BenchmarkMessengerSendLargeMessage 50000 72896 ns/op +BenchmarkMessengerSendMixedMessage 50000 72631 ns/op +BenchmarkMessengerSendRecvSmallMessage 20000 78409 ns/op +BenchmarkMessengerSendRecvMediumMessage 20000 80471 ns/op +BenchmarkMessengerSendRecvBigMessage 20000 82629 ns/op +BenchmarkMessengerSendRecvLargeMessage 20000 85987 ns/op +BenchmarkMessengerSendRecvMixedMessage 20000 83678 ns/op +ok github.com/mesos/mesos-go/messenger 115.135s + +$ go test -v -run=Benckmark* -bench=. -cpu=4 -send-routines=4 2>/dev/null +PASS +BenchmarkMessengerSendSmallMessage-4 50000 35529 ns/op +BenchmarkMessengerSendMediumMessage-4 50000 35997 ns/op +BenchmarkMessengerSendBigMessage-4 50000 36871 ns/op +BenchmarkMessengerSendLargeMessage-4 50000 37310 ns/op +BenchmarkMessengerSendMixedMessage-4 50000 37419 ns/op +BenchmarkMessengerSendRecvSmallMessage-4 50000 39320 ns/op +BenchmarkMessengerSendRecvMediumMessage-4 50000 41990 ns/op +BenchmarkMessengerSendRecvBigMessage-4 50000 42157 ns/op +BenchmarkMessengerSendRecvLargeMessage-4 50000 45472 ns/op +BenchmarkMessengerSendRecvMixedMessage-4 50000 47393 ns/op +ok github.com/mesos/mesos-go/messenger 105.173s +``` + +####environment: + +``` +OS: Linux yifan-laptop 3.13.0-32-generic #57-Ubuntu SMP Tue Jul 15 03:51:08 UTC 2014 x86_64 x86_64 x86_64 GNU/Linux +CPU: Intel(R) Core(TM) i5-3210M CPU @ 2.50GHz +MEM: 4G DDR3 1600MHz +``` diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/doc.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/doc.go new file mode 100644 index 00000000000..3b7bd8147c8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/doc.go @@ -0,0 +1,7 @@ +/* +Package messenger includes a messenger and a transporter. +The messenger provides interfaces to send a protobuf message +through the underlying transporter. It also dispatches messages +to installed handlers. +*/ +package messenger diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/http_transporter.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/http_transporter.go new file mode 100644 index 00000000000..30370b04835 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/http_transporter.go @@ -0,0 +1,371 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package messenger + +import ( + "bytes" + "fmt" + "github.com/mesos/mesos-go/upid" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + log "github.com/golang/glog" + "golang.org/x/net/context" +) + +var ( + discardOnStopError = fmt.Errorf("discarding message because transport is shutting down") +) + +// HTTPTransporter implements the interfaces of the Transporter. +type HTTPTransporter struct { + // If the host is empty("") then it will listen on localhost. + // If the port is empty("") then it will listen on random port. + upid *upid.UPID + listener net.Listener // TODO(yifan): Change to TCPListener. + mux *http.ServeMux + tr *http.Transport + client *http.Client // TODO(yifan): Set read/write deadline. + messageQueue chan *Message + address net.IP // optional binding address + started chan struct{} + stopped chan struct{} + stopping int32 + lifeLock sync.Mutex // protect lifecycle (start/stop) funcs +} + +// NewHTTPTransporter creates a new http transporter with an optional binding address. +func NewHTTPTransporter(upid *upid.UPID, address net.IP) *HTTPTransporter { + tr := &http.Transport{} + result := &HTTPTransporter{ + upid: upid, + messageQueue: make(chan *Message, defaultQueueSize), + mux: http.NewServeMux(), + client: &http.Client{Transport: tr}, + tr: tr, + address: address, + started: make(chan struct{}), + stopped: make(chan struct{}), + } + close(result.stopped) + return result +} + +// some network errors are probably recoverable, attempt to determine that here. +func isRecoverableError(err error) bool { + if urlErr, ok := err.(*url.Error); ok { + log.V(2).Infof("checking url.Error for recoverability") + return urlErr.Op == "Post" && isRecoverableError(urlErr.Err) + } else if netErr, ok := err.(*net.OpError); ok && netErr.Err != nil { + log.V(2).Infof("checking net.OpError for recoverability: %#v", err) + if netErr.Temporary() { + return true + } + //TODO(jdef) this is pretty hackish, there's probably a better way + return (netErr.Op == "dial" && netErr.Net == "tcp" && netErr.Err == syscall.ECONNREFUSED) + } + log.V(2).Infof("unrecoverable error: %#v", err) + return false +} + +type recoverableError struct { + Err error +} + +func (e *recoverableError) Error() string { + if e == nil { + return "" + } + return e.Err.Error() +} + +// Send sends the message to its specified upid. +func (t *HTTPTransporter) Send(ctx context.Context, msg *Message) (sendError error) { + log.V(2).Infof("Sending message to %v via http\n", msg.UPID) + req, err := t.makeLibprocessRequest(msg) + if err != nil { + log.Errorf("Failed to make libprocess request: %v\n", err) + return err + } + duration := 1 * time.Second + for attempt := 0; attempt < 5; attempt++ { //TODO(jdef) extract/parameterize constant + if sendError != nil { + duration *= 2 + log.Warningf("attempting to recover from error '%v', waiting before retry: %v", sendError, duration) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(duration): + // ..retry request, continue + case <-t.stopped: + return discardOnStopError + } + } + sendError = t.httpDo(ctx, req, func(resp *http.Response, err error) error { + if err != nil { + if isRecoverableError(err) { + return &recoverableError{Err: err} + } + log.Infof("Failed to POST: %v\n", err) + return err + } + defer resp.Body.Close() + + // ensure master acknowledgement. + if (resp.StatusCode != http.StatusOK) && + (resp.StatusCode != http.StatusAccepted) { + msg := fmt.Sprintf("Master %s rejected %s. Returned status %s.", + msg.UPID, msg.RequestURI(), resp.Status) + log.Warning(msg) + return fmt.Errorf(msg) + } + return nil + }) + if sendError == nil { + // success + return + } else if _, ok := sendError.(*recoverableError); ok { + // recoverable, attempt backoff? + continue + } + // unrecoverable + break + } + if recoverable, ok := sendError.(*recoverableError); ok { + sendError = recoverable.Err + } + return +} + +func (t *HTTPTransporter) httpDo(ctx context.Context, req *http.Request, f func(*http.Response, error) error) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.stopped: + return discardOnStopError + default: // continue + } + + c := make(chan error, 1) + go func() { c <- f(t.client.Do(req)) }() + select { + case <-ctx.Done(): + t.tr.CancelRequest(req) + <-c // Wait for f to return. + return ctx.Err() + case err := <-c: + return err + case <-t.stopped: + t.tr.CancelRequest(req) + <-c // Wait for f to return. + return discardOnStopError + } +} + +// Recv returns the message, one at a time. +func (t *HTTPTransporter) Recv() (*Message, error) { + select { + default: + select { + case msg := <-t.messageQueue: + return msg, nil + case <-t.stopped: + } + case <-t.stopped: + } + return nil, discardOnStopError +} + +//Inject places a message into the incoming message queue. +func (t *HTTPTransporter) Inject(ctx context.Context, msg *Message) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.stopped: + return discardOnStopError + default: // continue + } + + select { + case t.messageQueue <- msg: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-t.stopped: + return discardOnStopError + } +} + +// Install the request URI according to the message's name. +func (t *HTTPTransporter) Install(msgName string) { + requestURI := fmt.Sprintf("/%s/%s", t.upid.ID, msgName) + t.mux.HandleFunc(requestURI, t.messageHandler) +} + +// Listen starts listen on UPID. If UPID is empty, the transporter +// will listen on a random port, and then fill the UPID with the +// host:port it is listening. +func (t *HTTPTransporter) listen() error { + var host string + if t.address != nil { + host = t.address.String() + } else { + host = t.upid.Host + } + port := t.upid.Port + // NOTE: Explicitly specifies IPv4 because Libprocess + // only supports IPv4 for now. + ln, err := net.Listen("tcp4", net.JoinHostPort(host, port)) + if err != nil { + log.Errorf("HTTPTransporter failed to listen: %v\n", err) + return err + } + // Save the host:port in case they are not specified in upid. + host, port, _ = net.SplitHostPort(ln.Addr().String()) + t.upid.Host, t.upid.Port = host, port + t.listener = ln + return nil +} + +// Start starts the http transporter +func (t *HTTPTransporter) Start() <-chan error { + t.lifeLock.Lock() + defer t.lifeLock.Unlock() + + select { + case <-t.started: + // already started + return nil + case <-t.stopped: + defer close(t.started) + t.stopped = make(chan struct{}) + atomic.StoreInt32(&t.stopping, 0) + default: + panic("not started, not stopped, what am i? how can i start?") + } + + ch := make(chan error, 1) + if err := t.listen(); err != nil { + ch <- err + } else { + // TODO(yifan): Set read/write deadline. + log.Infof("http transport listening on %v", t.listener.Addr()) + go func() { + err := http.Serve(t.listener, t.mux) + if atomic.CompareAndSwapInt32(&t.stopping, 1, 0) { + ch <- nil + } else { + ch <- err + } + }() + } + return ch +} + +// Stop stops the http transporter by closing the listener. +func (t *HTTPTransporter) Stop(graceful bool) error { + t.lifeLock.Lock() + defer t.lifeLock.Unlock() + + select { + case <-t.stopped: + // already stopped + return nil + case <-t.started: + defer close(t.stopped) + t.started = make(chan struct{}) + default: + panic("not started, not stopped, what am i? how can i stop?") + } + //TODO(jdef) if graceful, wait for pending requests to terminate + atomic.StoreInt32(&t.stopping, 1) + err := t.listener.Close() + return err +} + +// UPID returns the upid of the transporter. +func (t *HTTPTransporter) UPID() *upid.UPID { + return t.upid +} + +func (t *HTTPTransporter) messageHandler(w http.ResponseWriter, r *http.Request) { + // Verify it's a libprocess request. + from, err := getLibprocessFrom(r) + if err != nil { + log.Errorf("Ignoring the request, because it's not a libprocess request: %v\n", err) + w.WriteHeader(http.StatusBadRequest) + return + } + data, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Errorf("Failed to read HTTP body: %v\n", err) + w.WriteHeader(http.StatusBadRequest) + return + } + log.V(2).Infof("Receiving message from %v, length %v\n", from, len(data)) + w.WriteHeader(http.StatusAccepted) + t.messageQueue <- &Message{ + UPID: from, + Name: extractNameFromRequestURI(r.RequestURI), + Bytes: data, + } +} + +func (t *HTTPTransporter) makeLibprocessRequest(msg *Message) (*http.Request, error) { + if msg.UPID == nil { + panic(fmt.Sprintf("message is missing UPID: %+v", msg)) + } + hostport := net.JoinHostPort(msg.UPID.Host, msg.UPID.Port) + targetURL := fmt.Sprintf("http://%s%s", hostport, msg.RequestURI()) + log.V(2).Infof("libproc target URL %s", targetURL) + req, err := http.NewRequest("POST", targetURL, bytes.NewReader(msg.Bytes)) + if err != nil { + log.Errorf("Failed to create request: %v\n", err) + return nil, err + } + req.Header.Add("Libprocess-From", t.upid.String()) + req.Header.Add("Content-Type", "application/x-protobuf") + req.Header.Add("Connection", "Keep-Alive") + + return req, nil +} + +func getLibprocessFrom(r *http.Request) (*upid.UPID, error) { + if r.Method != "POST" { + return nil, fmt.Errorf("Not a POST request") + } + ua, ok := r.Header["User-Agent"] + if ok && strings.HasPrefix(ua[0], "libprocess/") { + // TODO(yifan): Just take the first field for now. + return upid.Parse(ua[0][len("libprocess/"):]) + } + lf, ok := r.Header["Libprocess-From"] + if ok { + // TODO(yifan): Just take the first field for now. + return upid.Parse(lf[0]) + } + return nil, fmt.Errorf("Cannot find 'User-Agent' or 'Libprocess-From'") +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/http_transporter_test.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/http_transporter_test.go new file mode 100644 index 00000000000..e1d14096526 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/http_transporter_test.go @@ -0,0 +1,273 @@ +package messenger + +import ( + "fmt" + "net/http" + "net/http/httptest" + "regexp" + "strconv" + "testing" + "time" + + "github.com/mesos/mesos-go/messenger/testmessage" + "github.com/mesos/mesos-go/upid" + "github.com/stretchr/testify/assert" + "golang.org/x/net/context" +) + +func TestTransporterNew(t *testing.T) { + id, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(t, err) + trans := NewHTTPTransporter(id, nil) + assert.NotNil(t, trans) + assert.NotNil(t, trans.upid) + assert.NotNil(t, trans.messageQueue) + assert.NotNil(t, trans.client) +} + +func TestTransporterSend(t *testing.T) { + idreg := regexp.MustCompile(`[A-Za-z0-9_\-]+@[A-Za-z0-9_\-\.]+:[0-9]+`) + serverId := "testserver" + + // setup mesos client-side + fromUpid, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(t, err) + + protoMsg := testmessage.GenerateSmallMessage() + msgName := getMessageName(protoMsg) + msg := &Message{ + Name: msgName, + ProtoMessage: protoMsg, + } + requestURI := fmt.Sprintf("/%s/%s", serverId, msgName) + + // setup server-side + msgReceived := make(chan struct{}) + srv := makeMockServer(requestURI, func(rsp http.ResponseWriter, req *http.Request) { + defer close(msgReceived) + from := req.Header.Get("Libprocess-From") + assert.NotEmpty(t, from) + assert.True(t, idreg.MatchString(from), fmt.Sprintf("regexp failed for '%v'", from)) + }) + defer srv.Close() + toUpid, err := upid.Parse(fmt.Sprintf("%s@%s", serverId, srv.Listener.Addr().String())) + assert.NoError(t, err) + + // make transport call. + transport := NewHTTPTransporter(fromUpid, nil) + errch := transport.Start() + defer transport.Stop(false) + + msg.UPID = toUpid + err = transport.Send(context.TODO(), msg) + assert.NoError(t, err) + + select { + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for message receipt") + case <-msgReceived: + case err := <-errch: + if err != nil { + t.Fatalf(err.Error()) + } + } +} + +func TestTransporter_DiscardedSend(t *testing.T) { + serverId := "testserver" + + // setup mesos client-side + fromUpid, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(t, err) + + protoMsg := testmessage.GenerateSmallMessage() + msgName := getMessageName(protoMsg) + msg := &Message{ + Name: msgName, + ProtoMessage: protoMsg, + } + requestURI := fmt.Sprintf("/%s/%s", serverId, msgName) + + // setup server-side + msgReceived := make(chan struct{}) + srv := makeMockServer(requestURI, func(rsp http.ResponseWriter, req *http.Request) { + close(msgReceived) + time.Sleep(2 * time.Second) // long enough that we should be able to stop it + }) + defer srv.Close() + toUpid, err := upid.Parse(fmt.Sprintf("%s@%s", serverId, srv.Listener.Addr().String())) + assert.NoError(t, err) + + // make transport call. + transport := NewHTTPTransporter(fromUpid, nil) + errch := transport.Start() + defer transport.Stop(false) + + msg.UPID = toUpid + senderr := make(chan struct{}) + go func() { + defer close(senderr) + err = transport.Send(context.TODO(), msg) + assert.NotNil(t, err) + assert.Equal(t, discardOnStopError, err) + }() + + // wait for message to be received + select { + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for message receipt") + return + case <-msgReceived: + transport.Stop(false) + case err := <-errch: + if err != nil { + t.Fatalf(err.Error()) + return + } + } + + // wait for send() to process discarded-error + select { + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for aborted send") + return + case <-senderr: // continue + } +} + +func TestTransporterStartAndRcvd(t *testing.T) { + serverId := "testserver" + serverPort := getNewPort() + serverAddr := "127.0.0.1:" + strconv.Itoa(serverPort) + protoMsg := testmessage.GenerateSmallMessage() + msgName := getMessageName(protoMsg) + ctrl := make(chan struct{}) + + // setup receiver (server) process + rcvPid, err := upid.Parse(fmt.Sprintf("%s@%s", serverId, serverAddr)) + assert.NoError(t, err) + receiver := NewHTTPTransporter(rcvPid, nil) + receiver.Install(msgName) + + go func() { + defer close(ctrl) + msg, err := receiver.Recv() + assert.Nil(t, err) + assert.NotNil(t, msg) + if msg != nil { + assert.Equal(t, msgName, msg.Name) + } + }() + + errch := receiver.Start() + defer receiver.Stop(false) + assert.NotNil(t, errch) + + time.Sleep(time.Millisecond * 7) // time to catchup + + // setup sender (client) process + sndUpid, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(t, err) + + sender := NewHTTPTransporter(sndUpid, nil) + msg := &Message{ + UPID: rcvPid, + Name: msgName, + ProtoMessage: protoMsg, + } + errch2 := sender.Start() + defer sender.Stop(false) + + sender.Send(context.TODO(), msg) + + select { + case <-time.After(time.Second * 5): + t.Fatalf("Timeout") + case <-ctrl: + case err := <-errch: + if err != nil { + t.Fatalf(err.Error()) + } + case err := <-errch2: + if err != nil { + t.Fatalf(err.Error()) + } + } +} + +func TestTransporterStartAndInject(t *testing.T) { + serverId := "testserver" + serverPort := getNewPort() + serverAddr := "127.0.0.1:" + strconv.Itoa(serverPort) + protoMsg := testmessage.GenerateSmallMessage() + msgName := getMessageName(protoMsg) + ctrl := make(chan struct{}) + + // setup receiver (server) process + rcvPid, err := upid.Parse(fmt.Sprintf("%s@%s", serverId, serverAddr)) + assert.NoError(t, err) + receiver := NewHTTPTransporter(rcvPid, nil) + receiver.Install(msgName) + errch := receiver.Start() + defer receiver.Stop(false) + + msg := &Message{ + UPID: rcvPid, + Name: msgName, + ProtoMessage: protoMsg, + } + + receiver.Inject(context.TODO(), msg) + + go func() { + defer close(ctrl) + msg, err := receiver.Recv() + assert.Nil(t, err) + assert.NotNil(t, msg) + if msg != nil { + assert.Equal(t, msgName, msg.Name) + } + }() + + select { + case <-time.After(time.Second * 1): + t.Fatalf("Timeout") + case <-ctrl: + case err := <-errch: + if err != nil { + t.Fatalf(err.Error()) + } + } +} + +func TestTransporterStartAndStop(t *testing.T) { + serverId := "testserver" + serverPort := getNewPort() + serverAddr := "127.0.0.1:" + strconv.Itoa(serverPort) + + // setup receiver (server) process + rcvPid, err := upid.Parse(fmt.Sprintf("%s@%s", serverId, serverAddr)) + assert.NoError(t, err) + receiver := NewHTTPTransporter(rcvPid, nil) + + errch := receiver.Start() + assert.NotNil(t, errch) + + time.Sleep(1 * time.Second) + receiver.Stop(false) + + select { + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for transport to stop") + case err := <-errch: + if err != nil { + t.Fatalf(err.Error()) + } + } +} + +func makeMockServer(path string, handler func(rsp http.ResponseWriter, req *http.Request)) *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc(path, handler) + return httptest.NewServer(mux) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/message.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/message.go new file mode 100644 index 00000000000..331317f45cd --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/message.go @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package messenger + +import ( + "fmt" + "strings" + + "github.com/gogo/protobuf/proto" + "github.com/mesos/mesos-go/upid" +) + +// Message defines the type that passes in the Messenger. +type Message struct { + UPID *upid.UPID + Name string + ProtoMessage proto.Message + Bytes []byte +} + +// RequestURI returns the request URI of the message. +func (m *Message) RequestURI() string { + return fmt.Sprintf("/%s/%s", m.UPID.ID, m.Name) +} + +// NOTE: This should not fail or panic. +func extractNameFromRequestURI(requestURI string) string { + return strings.Split(requestURI, "/")[2] +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/messenger.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/messenger.go new file mode 100644 index 00000000000..5b242e5bce3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/messenger.go @@ -0,0 +1,357 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package messenger + +import ( + "flag" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/gogo/protobuf/proto" + log "github.com/golang/glog" + mesos "github.com/mesos/mesos-go/mesosproto" + "github.com/mesos/mesos-go/mesosutil/process" + "github.com/mesos/mesos-go/upid" + "golang.org/x/net/context" +) + +const ( + defaultQueueSize = 1024 + preparePeriod = time.Second * 1 +) + +var ( + sendRoutines int + encodeRoutines int + decodeRoutines int +) + +func init() { + flag.IntVar(&sendRoutines, "send-routines", 1, "Number of network sending routines") + flag.IntVar(&encodeRoutines, "encode-routines", 1, "Number of encoding routines") + flag.IntVar(&decodeRoutines, "decode-routines", 1, "Number of decoding routines") +} + +// MessageHandler is the callback of the message. When the callback +// is invoked, the sender's upid and the message is passed to the callback. +type MessageHandler func(from *upid.UPID, pbMsg proto.Message) + +// Messenger defines the interfaces that should be implemented. +type Messenger interface { + Install(handler MessageHandler, msg proto.Message) error + Send(ctx context.Context, upid *upid.UPID, msg proto.Message) error + Route(ctx context.Context, from *upid.UPID, msg proto.Message) error + Start() error + Stop() error + UPID() *upid.UPID +} + +// MesosMessenger is an implementation of the Messenger interface. +type MesosMessenger struct { + upid *upid.UPID + encodingQueue chan *Message + sendingQueue chan *Message + installedMessages map[string]reflect.Type + installedHandlers map[string]MessageHandler + stop chan struct{} + tr Transporter +} + +// create a new default messenger (HTTP). If a non-nil, non-wildcard bindingAddress is +// specified then it will be used for both the UPID and Transport binding address. Otherwise +// hostname is resolved to an IP address and the UPID.Host is set to that address and the +// bindingAddress is passed through to the Transport. +func ForHostname(proc *process.Process, hostname string, bindingAddress net.IP, port uint16) (Messenger, error) { + upid := &upid.UPID{ + ID: proc.Label(), + Port: strconv.Itoa(int(port)), + } + if bindingAddress != nil && "0.0.0.0" != bindingAddress.String() { + upid.Host = bindingAddress.String() + } else { + ips, err := net.LookupIP(hostname) + if err != nil { + return nil, err + } + // try to find an ipv4 and use that + ip := net.IP(nil) + for _, addr := range ips { + if ip = addr.To4(); ip != nil { + break + } + } + if ip == nil { + // no ipv4? best guess, just take the first addr + if len(ips) > 0 { + ip = ips[0] + log.Warningf("failed to find an IPv4 address for '%v', best guess is '%v'", hostname, ip) + } else { + return nil, fmt.Errorf("failed to determine IP address for host '%v'", hostname) + } + } + upid.Host = ip.String() + } + return NewHttpWithBindingAddress(upid, bindingAddress), nil +} + +// NewMesosMessenger creates a new mesos messenger. +func NewHttp(upid *upid.UPID) *MesosMessenger { + return NewHttpWithBindingAddress(upid, nil) +} + +func NewHttpWithBindingAddress(upid *upid.UPID, address net.IP) *MesosMessenger { + return New(upid, NewHTTPTransporter(upid, address)) +} + +func New(upid *upid.UPID, t Transporter) *MesosMessenger { + return &MesosMessenger{ + upid: upid, + encodingQueue: make(chan *Message, defaultQueueSize), + sendingQueue: make(chan *Message, defaultQueueSize), + installedMessages: make(map[string]reflect.Type), + installedHandlers: make(map[string]MessageHandler), + tr: t, + } +} + +/// Install installs the handler with the given message. +func (m *MesosMessenger) Install(handler MessageHandler, msg proto.Message) error { + // Check if the message is a pointer. + mtype := reflect.TypeOf(msg) + if mtype.Kind() != reflect.Ptr { + return fmt.Errorf("Message %v is not a Ptr type", msg) + } + + // Check if the message is already installed. + name := getMessageName(msg) + if _, ok := m.installedMessages[name]; ok { + return fmt.Errorf("Message %v is already installed", name) + } + m.installedMessages[name] = mtype.Elem() + m.installedHandlers[name] = handler + m.tr.Install(name) + return nil +} + +// Send puts a message into the outgoing queue, waiting to be sent. +// With buffered channels, this will not block under moderate throughput. +// When an error is generated, the error can be communicated by placing +// a message on the incoming queue to be handled upstream. +func (m *MesosMessenger) Send(ctx context.Context, upid *upid.UPID, msg proto.Message) error { + if upid == nil { + panic("cannot sent a message to a nil pid") + } else if upid.Equal(m.upid) { + return fmt.Errorf("Send the message to self") + } + name := getMessageName(msg) + log.V(2).Infof("Sending message %v to %v\n", name, upid) + select { + case <-ctx.Done(): + return ctx.Err() + case m.encodingQueue <- &Message{upid, name, msg, nil}: + return nil + } +} + +// Route puts a message either in the incoming or outgoing queue. +// This method is useful for: +// 1) routing internal error to callback handlers +// 2) testing components without starting remote servers. +func (m *MesosMessenger) Route(ctx context.Context, upid *upid.UPID, msg proto.Message) error { + // if destination is not self, send to outbound. + if !upid.Equal(m.upid) { + return m.Send(ctx, upid, msg) + } + + data, err := proto.Marshal(msg) + if err != nil { + return err + } + name := getMessageName(msg) + return m.tr.Inject(ctx, &Message{upid, name, msg, data}) +} + +// Start starts the messenger. +func (m *MesosMessenger) Start() error { + + m.stop = make(chan struct{}) + errChan := m.tr.Start() + + select { + case err := <-errChan: + log.Errorf("failed to start messenger: %v", err) + return err + case <-time.After(preparePeriod): // continue + } + + m.upid = m.tr.UPID() + + for i := 0; i < sendRoutines; i++ { + go m.sendLoop() + } + for i := 0; i < encodeRoutines; i++ { + go m.encodeLoop() + } + for i := 0; i < decodeRoutines; i++ { + go m.decodeLoop() + } + go func() { + select { + case err := <-errChan: + if err != nil { + //TODO(jdef) should the driver abort in this case? probably + //since this messenger will never attempt to re-establish the + //transport + log.Error(err) + } + case <-m.stop: + } + }() + return nil +} + +// Stop stops the messenger and clean up all the goroutines. +func (m *MesosMessenger) Stop() error { + //TODO(jdef) don't hardcode the graceful flag here + if err := m.tr.Stop(true); err != nil { + log.Errorf("Failed to stop the transporter: %v\n", err) + return err + } + close(m.stop) + return nil +} + +// UPID returns the upid of the messenger. +func (m *MesosMessenger) UPID() *upid.UPID { + return m.upid +} + +func (m *MesosMessenger) encodeLoop() { + for { + select { + case <-m.stop: + return + case msg := <-m.encodingQueue: + e := func() error { + //TODO(jdef) implement timeout for context + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + b, err := proto.Marshal(msg.ProtoMessage) + if err != nil { + return err + } + msg.Bytes = b + select { + case <-ctx.Done(): + return ctx.Err() + case m.sendingQueue <- msg: + return nil + } + }() + if e != nil { + m.reportError(fmt.Errorf("Failed to enqueue message %v: %v", msg, e)) + } + } + } +} + +func (m *MesosMessenger) reportError(err error) { + log.V(2).Info(err) + //TODO(jdef) implement timeout for context + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + c := make(chan error, 1) + go func() { c <- m.Route(ctx, m.UPID(), &mesos.FrameworkErrorMessage{Message: proto.String(err.Error())}) }() + select { + case <-ctx.Done(): + <-c // wait for Route to return + case e := <-c: + if e != nil { + log.Errorf("failed to report error %v due to: %v", err, e) + } + } +} + +func (m *MesosMessenger) sendLoop() { + for { + select { + case <-m.stop: + return + case msg := <-m.sendingQueue: + e := func() error { + //TODO(jdef) implement timeout for context + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + c := make(chan error, 1) + go func() { c <- m.tr.Send(ctx, msg) }() + + select { + case <-ctx.Done(): + // Transport layer must use the context to detect cancelled requests. + <-c // wait for Send to return + return ctx.Err() + case err := <-c: + return err + } + }() + if e != nil { + m.reportError(fmt.Errorf("Failed to send message %v: %v", msg.Name, e)) + } + } + } +} + +// Since HTTPTransporter.Recv() is already buffered, so we don't need a 'recvLoop' here. +func (m *MesosMessenger) decodeLoop() { + for { + select { + case <-m.stop: + return + default: + } + msg, err := m.tr.Recv() + if err != nil { + if err == discardOnStopError { + log.V(1).Info("exiting decodeLoop, transport shutting down") + return + } else { + panic(fmt.Sprintf("unexpected transport error: %v", err)) + } + } + log.V(2).Infof("Receiving message %v from %v\n", msg.Name, msg.UPID) + msg.ProtoMessage = reflect.New(m.installedMessages[msg.Name]).Interface().(proto.Message) + if err := proto.Unmarshal(msg.Bytes, msg.ProtoMessage); err != nil { + log.Errorf("Failed to unmarshal message %v: %v\n", msg, err) + continue + } + // TODO(yifan): Catch panic. + m.installedHandlers[msg.Name](msg.UPID, msg.ProtoMessage) + } +} + +// getMessageName returns the name of the message in the mesos manner. +func getMessageName(msg proto.Message) string { + return fmt.Sprintf("%v.%v", "mesos.internal", reflect.TypeOf(msg).Elem().Name()) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/messenger_test.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/messenger_test.go new file mode 100644 index 00000000000..096f201116c --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/messenger_test.go @@ -0,0 +1,433 @@ +package messenger + +import ( + "fmt" + "math/rand" + "net/http" + "net/http/httptest" + "strconv" + "sync" + "testing" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/mesos/mesos-go/messenger/testmessage" + "github.com/mesos/mesos-go/upid" + "github.com/stretchr/testify/assert" + "golang.org/x/net/context" +) + +var ( + startPort = 10000 + rand.Intn(30000) + globalWG = new(sync.WaitGroup) +) + +func noopHandler(*upid.UPID, proto.Message) { + globalWG.Done() +} + +func getNewPort() int { + startPort++ + return startPort +} + +func shuffleMessages(queue *[]proto.Message) { + for i := range *queue { + index := rand.Intn(i + 1) + (*queue)[i], (*queue)[index] = (*queue)[index], (*queue)[i] + } +} + +func generateSmallMessages(n int) []proto.Message { + queue := make([]proto.Message, n) + for i := range queue { + queue[i] = testmessage.GenerateSmallMessage() + } + return queue +} + +func generateMediumMessages(n int) []proto.Message { + queue := make([]proto.Message, n) + for i := range queue { + queue[i] = testmessage.GenerateMediumMessage() + } + return queue +} + +func generateBigMessages(n int) []proto.Message { + queue := make([]proto.Message, n) + for i := range queue { + queue[i] = testmessage.GenerateBigMessage() + } + return queue +} + +func generateLargeMessages(n int) []proto.Message { + queue := make([]proto.Message, n) + for i := range queue { + queue[i] = testmessage.GenerateLargeMessage() + } + return queue +} + +func generateMixedMessages(n int) []proto.Message { + queue := make([]proto.Message, n*4) + for i := 0; i < n*4; i = i + 4 { + queue[i] = testmessage.GenerateSmallMessage() + queue[i+1] = testmessage.GenerateMediumMessage() + queue[i+2] = testmessage.GenerateBigMessage() + queue[i+3] = testmessage.GenerateLargeMessage() + } + shuffleMessages(&queue) + return queue +} + +func installMessages(t *testing.T, m Messenger, queue *[]proto.Message, counts *[]int, done chan struct{}) { + testCounts := func(counts []int, done chan struct{}) { + for i := range counts { + if counts[i] != cap(*queue)/4 { + return + } + } + close(done) + } + hander1 := func(from *upid.UPID, pbMsg proto.Message) { + (*queue) = append(*queue, pbMsg) + (*counts)[0]++ + testCounts(*counts, done) + } + hander2 := func(from *upid.UPID, pbMsg proto.Message) { + (*queue) = append(*queue, pbMsg) + (*counts)[1]++ + testCounts(*counts, done) + } + hander3 := func(from *upid.UPID, pbMsg proto.Message) { + (*queue) = append(*queue, pbMsg) + (*counts)[2]++ + testCounts(*counts, done) + } + hander4 := func(from *upid.UPID, pbMsg proto.Message) { + (*queue) = append(*queue, pbMsg) + (*counts)[3]++ + testCounts(*counts, done) + } + assert.NoError(t, m.Install(hander1, &testmessage.SmallMessage{})) + assert.NoError(t, m.Install(hander2, &testmessage.MediumMessage{})) + assert.NoError(t, m.Install(hander3, &testmessage.BigMessage{})) + assert.NoError(t, m.Install(hander4, &testmessage.LargeMessage{})) +} + +func runTestServer(b *testing.B, wg *sync.WaitGroup) *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc("/testserver/mesos.internal.SmallMessage", func(http.ResponseWriter, *http.Request) { + wg.Done() + }) + mux.HandleFunc("/testserver/mesos.internal.MediumMessage", func(http.ResponseWriter, *http.Request) { + wg.Done() + }) + mux.HandleFunc("/testserver/mesos.internal.BigMessage", func(http.ResponseWriter, *http.Request) { + wg.Done() + }) + mux.HandleFunc("/testserver/mesos.internal.LargeMessage", func(http.ResponseWriter, *http.Request) { + wg.Done() + }) + return httptest.NewServer(mux) +} + +func TestMessengerFailToInstall(t *testing.T) { + m := NewHttp(&upid.UPID{ID: "mesos"}) + handler := func(from *upid.UPID, pbMsg proto.Message) {} + assert.NotNil(t, m) + assert.NoError(t, m.Install(handler, &testmessage.SmallMessage{})) + assert.Error(t, m.Install(handler, &testmessage.SmallMessage{})) +} + +func TestMessengerFailToStart(t *testing.T) { + port := strconv.Itoa(getNewPort()) + m1 := NewHttp(&upid.UPID{ID: "mesos", Host: "localhost", Port: port}) + m2 := NewHttp(&upid.UPID{ID: "mesos", Host: "localhost", Port: port}) + assert.NoError(t, m1.Start()) + assert.Error(t, m2.Start()) +} + +func TestMessengerFailToSend(t *testing.T) { + upid, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(t, err) + m := NewHttp(upid) + assert.NoError(t, m.Start()) + assert.Error(t, m.Send(context.TODO(), upid, &testmessage.SmallMessage{})) +} + +func TestMessenger(t *testing.T) { + messages := generateMixedMessages(1000) + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(t, err) + upid2, err := upid.Parse(fmt.Sprintf("mesos2@localhost:%d", getNewPort())) + assert.NoError(t, err) + + m1 := NewHttp(upid1) + m2 := NewHttp(upid2) + + done := make(chan struct{}) + counts := make([]int, 4) + msgQueue := make([]proto.Message, 0, len(messages)) + installMessages(t, m2, &msgQueue, &counts, done) + + assert.NoError(t, m1.Start()) + assert.NoError(t, m2.Start()) + + go func() { + for _, msg := range messages { + assert.NoError(t, m1.Send(context.TODO(), upid2, msg)) + } + }() + + select { + case <-time.After(time.Second * 10): + t.Fatalf("Timeout") + case <-done: + } + + for i := range counts { + assert.Equal(t, 1000, counts[i]) + } + assert.Equal(t, messages, msgQueue) +} + +func BenchmarkMessengerSendSmallMessage(b *testing.B) { + messages := generateSmallMessages(1000) + + wg := new(sync.WaitGroup) + wg.Add(b.N) + srv := runTestServer(b, wg) + defer srv.Close() + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("testserver@%s", srv.Listener.Addr().String())) + + assert.NoError(b, err) + + m1 := NewHttp(upid1) + assert.NoError(b, m1.Start()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + wg.Wait() +} + +func BenchmarkMessengerSendMediumMessage(b *testing.B) { + messages := generateMediumMessages(1000) + + wg := new(sync.WaitGroup) + wg.Add(b.N) + srv := runTestServer(b, wg) + defer srv.Close() + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("testserver@%s", srv.Listener.Addr().String())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + assert.NoError(b, m1.Start()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + wg.Wait() +} + +func BenchmarkMessengerSendBigMessage(b *testing.B) { + messages := generateBigMessages(1000) + + wg := new(sync.WaitGroup) + wg.Add(b.N) + srv := runTestServer(b, wg) + defer srv.Close() + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("testserver@%s", srv.Listener.Addr().String())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + assert.NoError(b, m1.Start()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + wg.Wait() +} + +func BenchmarkMessengerSendLargeMessage(b *testing.B) { + messages := generateLargeMessages(1000) + + wg := new(sync.WaitGroup) + wg.Add(b.N) + srv := runTestServer(b, wg) + defer srv.Close() + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("testserver@%s", srv.Listener.Addr().String())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + assert.NoError(b, m1.Start()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + wg.Wait() +} + +func BenchmarkMessengerSendMixedMessage(b *testing.B) { + messages := generateMixedMessages(1000) + + wg := new(sync.WaitGroup) + wg.Add(b.N) + srv := runTestServer(b, wg) + defer srv.Close() + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("testserver@%s", srv.Listener.Addr().String())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + assert.NoError(b, m1.Start()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + wg.Wait() +} + +func BenchmarkMessengerSendRecvSmallMessage(b *testing.B) { + globalWG.Add(b.N) + + messages := generateSmallMessages(1000) + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("mesos2@localhost:%d", getNewPort())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + m2 := NewHttp(upid2) + assert.NoError(b, m1.Start()) + assert.NoError(b, m2.Start()) + assert.NoError(b, m2.Install(noopHandler, &testmessage.SmallMessage{})) + + time.Sleep(time.Second) // Avoid race on upid. + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + globalWG.Wait() +} + +func BenchmarkMessengerSendRecvMediumMessage(b *testing.B) { + globalWG.Add(b.N) + + messages := generateMediumMessages(1000) + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("mesos2@localhost:%d", getNewPort())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + m2 := NewHttp(upid2) + assert.NoError(b, m1.Start()) + assert.NoError(b, m2.Start()) + assert.NoError(b, m2.Install(noopHandler, &testmessage.MediumMessage{})) + + time.Sleep(time.Second) // Avoid race on upid. + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + globalWG.Wait() +} + +func BenchmarkMessengerSendRecvBigMessage(b *testing.B) { + globalWG.Add(b.N) + + messages := generateBigMessages(1000) + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("mesos2@localhost:%d", getNewPort())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + m2 := NewHttp(upid2) + assert.NoError(b, m1.Start()) + assert.NoError(b, m2.Start()) + assert.NoError(b, m2.Install(noopHandler, &testmessage.BigMessage{})) + + time.Sleep(time.Second) // Avoid race on upid. + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + globalWG.Wait() +} + +func BenchmarkMessengerSendRecvLargeMessage(b *testing.B) { + globalWG.Add(b.N) + messages := generateLargeMessages(1000) + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("mesos2@localhost:%d", getNewPort())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + m2 := NewHttp(upid2) + assert.NoError(b, m1.Start()) + assert.NoError(b, m2.Start()) + assert.NoError(b, m2.Install(noopHandler, &testmessage.LargeMessage{})) + + time.Sleep(time.Second) // Avoid race on upid. + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + globalWG.Wait() +} + +func BenchmarkMessengerSendRecvMixedMessage(b *testing.B) { + globalWG.Add(b.N) + messages := generateMixedMessages(1000) + + upid1, err := upid.Parse(fmt.Sprintf("mesos1@localhost:%d", getNewPort())) + assert.NoError(b, err) + upid2, err := upid.Parse(fmt.Sprintf("mesos2@localhost:%d", getNewPort())) + assert.NoError(b, err) + + m1 := NewHttp(upid1) + m2 := NewHttp(upid2) + assert.NoError(b, m1.Start()) + assert.NoError(b, m2.Start()) + assert.NoError(b, m2.Install(noopHandler, &testmessage.SmallMessage{})) + assert.NoError(b, m2.Install(noopHandler, &testmessage.MediumMessage{})) + assert.NoError(b, m2.Install(noopHandler, &testmessage.BigMessage{})) + assert.NoError(b, m2.Install(noopHandler, &testmessage.LargeMessage{})) + + time.Sleep(time.Second) // Avoid race on upid. + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1.Send(context.TODO(), upid2, messages[i%1000]) + } + globalWG.Wait() +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/mocked_messenger.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/mocked_messenger.go new file mode 100644 index 00000000000..34d53d0868f --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/mocked_messenger.go @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package messenger + +import ( + "reflect" + + "github.com/gogo/protobuf/proto" + "github.com/mesos/mesos-go/upid" + "github.com/stretchr/testify/mock" + "golang.org/x/net/context" +) + +type message struct { + from *upid.UPID + msg proto.Message +} + +// MockedMessenger is a messenger that returns error on every operation. +type MockedMessenger struct { + mock.Mock + messageQueue chan *message + handlers map[string]MessageHandler + stop chan struct{} +} + +// NewMockedMessenger returns a mocked messenger used for testing. +func NewMockedMessenger() *MockedMessenger { + return &MockedMessenger{ + messageQueue: make(chan *message, 1), + handlers: make(map[string]MessageHandler), + stop: make(chan struct{}), + } +} + +// Install is a mocked implementation. +func (m *MockedMessenger) Install(handler MessageHandler, msg proto.Message) error { + m.handlers[reflect.TypeOf(msg).Elem().Name()] = handler + return m.Called().Error(0) +} + +// Send is a mocked implementation. +func (m *MockedMessenger) Send(ctx context.Context, upid *upid.UPID, msg proto.Message) error { + return m.Called().Error(0) +} + +func (m *MockedMessenger) Route(ctx context.Context, upid *upid.UPID, msg proto.Message) error { + return m.Called().Error(0) +} + +// Start is a mocked implementation. +func (m *MockedMessenger) Start() error { + go m.recvLoop() + return m.Called().Error(0) +} + +// Stop is a mocked implementation. +func (m *MockedMessenger) Stop() error { + // don't close an already-closed channel + select { + case <-m.stop: + // noop + default: + close(m.stop) + } + return m.Called().Error(0) +} + +// UPID is a mocked implementation. +func (m *MockedMessenger) UPID() *upid.UPID { + return m.Called().Get(0).(*upid.UPID) +} + +func (m *MockedMessenger) recvLoop() { + for { + select { + case <-m.stop: + return + case msg := <-m.messageQueue: + name := reflect.TypeOf(msg.msg).Elem().Name() + m.handlers[name](msg.from, msg.msg) + } + } +} + +// Recv receives a upid and a message, it will dispatch the message to its handler +// with the upid. This is for testing. +func (m *MockedMessenger) Recv(from *upid.UPID, msg proto.Message) { + m.messageQueue <- &message{from, msg} +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/Makefile b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/Makefile new file mode 100644 index 00000000000..9bf30108452 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/Makefile @@ -0,0 +1,2 @@ +all: testmessage.proto + protoc --proto_path=${GOPATH}/src:${GOPATH}/src/github.com/gogo/protobuf/protobuf:. --gogo_out=. testmessage.proto diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/generator.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/generator.go new file mode 100644 index 00000000000..56cbe13b8e0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/generator.go @@ -0,0 +1,49 @@ +package testmessage + +import ( + "math/rand" +) + +func generateRandomString(length int) string { + b := make([]byte, length) + for i := range b { + b[i] = byte(rand.Int()) + } + return string(b) +} + +// GenerateSmallMessage generates a small size message. +func GenerateSmallMessage() *SmallMessage { + v := make([]string, 3) + for i := range v { + v[i] = generateRandomString(5) + } + return &SmallMessage{Values: v} +} + +// GenerateMediumMessage generates a medium size message. +func GenerateMediumMessage() *MediumMessage { + v := make([]string, 10) + for i := range v { + v[i] = generateRandomString(10) + } + return &MediumMessage{Values: v} +} + +// GenerateBigMessage generates a big size message. +func GenerateBigMessage() *BigMessage { + v := make([]string, 20) + for i := range v { + v[i] = generateRandomString(20) + } + return &BigMessage{Values: v} +} + +// GenerateLargeMessage generates a large size message. +func GenerateLargeMessage() *LargeMessage { + v := make([]string, 30) + for i := range v { + v[i] = generateRandomString(30) + } + return &LargeMessage{Values: v} +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/testmessage.pb.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/testmessage.pb.go new file mode 100644 index 00000000000..11035be133b --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/testmessage.pb.go @@ -0,0 +1,1114 @@ +// Code generated by protoc-gen-gogo. +// source: testmessage.proto +// DO NOT EDIT! + +/* +Package testmessage is a generated protocol buffer package. + +It is generated from these files: + testmessage.proto + +It has these top-level messages: + SmallMessage + MediumMessage + BigMessage + LargeMessage +*/ +package testmessage + +import proto "github.com/gogo/protobuf/proto" +import math "math" + +// discarding unused import gogoproto "github.com/gogo/protobuf/gogoproto/gogo.pb" + +import io "io" +import fmt "fmt" +import github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" + +import fmt1 "fmt" +import strings "strings" +import reflect "reflect" + +import fmt2 "fmt" +import strings1 "strings" +import github_com_gogo_protobuf_proto1 "github.com/gogo/protobuf/proto" +import sort "sort" +import strconv "strconv" +import reflect1 "reflect" + +import fmt3 "fmt" +import bytes "bytes" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = math.Inf + +type SmallMessage struct { + Values []string `protobuf:"bytes,1,rep" json:"Values,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SmallMessage) Reset() { *m = SmallMessage{} } +func (*SmallMessage) ProtoMessage() {} + +func (m *SmallMessage) GetValues() []string { + if m != nil { + return m.Values + } + return nil +} + +type MediumMessage struct { + Values []string `protobuf:"bytes,1,rep" json:"Values,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MediumMessage) Reset() { *m = MediumMessage{} } +func (*MediumMessage) ProtoMessage() {} + +func (m *MediumMessage) GetValues() []string { + if m != nil { + return m.Values + } + return nil +} + +type BigMessage struct { + Values []string `protobuf:"bytes,1,rep" json:"Values,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *BigMessage) Reset() { *m = BigMessage{} } +func (*BigMessage) ProtoMessage() {} + +func (m *BigMessage) GetValues() []string { + if m != nil { + return m.Values + } + return nil +} + +type LargeMessage struct { + Values []string `protobuf:"bytes,1,rep" json:"Values,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *LargeMessage) Reset() { *m = LargeMessage{} } +func (*LargeMessage) ProtoMessage() {} + +func (m *LargeMessage) GetValues() []string { + if m != nil { + return m.Values + } + return nil +} + +func init() { +} +func (m *SmallMessage) Unmarshal(data []byte) error { + l := len(data) + index := 0 + for index < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Values", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + postIndex := index + int(stringLen) + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Values = append(m.Values, string(data[index:postIndex])) + index = postIndex + default: + var sizeOfWire int + for { + sizeOfWire++ + wire >>= 7 + if wire == 0 { + break + } + } + index -= sizeOfWire + skippy, err := github_com_gogo_protobuf_proto.Skip(data[index:]) + if err != nil { + return err + } + if (index + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...) + index += skippy + } + } + return nil +} +func (m *MediumMessage) Unmarshal(data []byte) error { + l := len(data) + index := 0 + for index < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Values", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + postIndex := index + int(stringLen) + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Values = append(m.Values, string(data[index:postIndex])) + index = postIndex + default: + var sizeOfWire int + for { + sizeOfWire++ + wire >>= 7 + if wire == 0 { + break + } + } + index -= sizeOfWire + skippy, err := github_com_gogo_protobuf_proto.Skip(data[index:]) + if err != nil { + return err + } + if (index + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...) + index += skippy + } + } + return nil +} +func (m *BigMessage) Unmarshal(data []byte) error { + l := len(data) + index := 0 + for index < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Values", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + postIndex := index + int(stringLen) + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Values = append(m.Values, string(data[index:postIndex])) + index = postIndex + default: + var sizeOfWire int + for { + sizeOfWire++ + wire >>= 7 + if wire == 0 { + break + } + } + index -= sizeOfWire + skippy, err := github_com_gogo_protobuf_proto.Skip(data[index:]) + if err != nil { + return err + } + if (index + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...) + index += skippy + } + } + return nil +} +func (m *LargeMessage) Unmarshal(data []byte) error { + l := len(data) + index := 0 + for index < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Values", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + postIndex := index + int(stringLen) + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Values = append(m.Values, string(data[index:postIndex])) + index = postIndex + default: + var sizeOfWire int + for { + sizeOfWire++ + wire >>= 7 + if wire == 0 { + break + } + } + index -= sizeOfWire + skippy, err := github_com_gogo_protobuf_proto.Skip(data[index:]) + if err != nil { + return err + } + if (index + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...) + index += skippy + } + } + return nil +} +func (this *SmallMessage) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&SmallMessage{`, + `Values:` + fmt1.Sprintf("%v", this.Values) + `,`, + `XXX_unrecognized:` + fmt1.Sprintf("%v", this.XXX_unrecognized) + `,`, + `}`, + }, "") + return s +} +func (this *MediumMessage) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&MediumMessage{`, + `Values:` + fmt1.Sprintf("%v", this.Values) + `,`, + `XXX_unrecognized:` + fmt1.Sprintf("%v", this.XXX_unrecognized) + `,`, + `}`, + }, "") + return s +} +func (this *BigMessage) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&BigMessage{`, + `Values:` + fmt1.Sprintf("%v", this.Values) + `,`, + `XXX_unrecognized:` + fmt1.Sprintf("%v", this.XXX_unrecognized) + `,`, + `}`, + }, "") + return s +} +func (this *LargeMessage) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&LargeMessage{`, + `Values:` + fmt1.Sprintf("%v", this.Values) + `,`, + `XXX_unrecognized:` + fmt1.Sprintf("%v", this.XXX_unrecognized) + `,`, + `}`, + }, "") + return s +} +func valueToStringTestmessage(v interface{}) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt1.Sprintf("*%v", pv) +} +func (m *SmallMessage) Size() (n int) { + var l int + _ = l + if len(m.Values) > 0 { + for _, s := range m.Values { + l = len(s) + n += 1 + l + sovTestmessage(uint64(l)) + } + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *MediumMessage) Size() (n int) { + var l int + _ = l + if len(m.Values) > 0 { + for _, s := range m.Values { + l = len(s) + n += 1 + l + sovTestmessage(uint64(l)) + } + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *BigMessage) Size() (n int) { + var l int + _ = l + if len(m.Values) > 0 { + for _, s := range m.Values { + l = len(s) + n += 1 + l + sovTestmessage(uint64(l)) + } + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *LargeMessage) Size() (n int) { + var l int + _ = l + if len(m.Values) > 0 { + for _, s := range m.Values { + l = len(s) + n += 1 + l + sovTestmessage(uint64(l)) + } + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func sovTestmessage(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozTestmessage(x uint64) (n int) { + return sovTestmessage(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func NewPopulatedSmallMessage(r randyTestmessage, easy bool) *SmallMessage { + this := &SmallMessage{} + if r.Intn(10) != 0 { + v1 := r.Intn(10) + this.Values = make([]string, v1) + for i := 0; i < v1; i++ { + this.Values[i] = randStringTestmessage(r) + } + } + if !easy && r.Intn(10) != 0 { + this.XXX_unrecognized = randUnrecognizedTestmessage(r, 2) + } + return this +} + +func NewPopulatedMediumMessage(r randyTestmessage, easy bool) *MediumMessage { + this := &MediumMessage{} + if r.Intn(10) != 0 { + v2 := r.Intn(10) + this.Values = make([]string, v2) + for i := 0; i < v2; i++ { + this.Values[i] = randStringTestmessage(r) + } + } + if !easy && r.Intn(10) != 0 { + this.XXX_unrecognized = randUnrecognizedTestmessage(r, 2) + } + return this +} + +func NewPopulatedBigMessage(r randyTestmessage, easy bool) *BigMessage { + this := &BigMessage{} + if r.Intn(10) != 0 { + v3 := r.Intn(10) + this.Values = make([]string, v3) + for i := 0; i < v3; i++ { + this.Values[i] = randStringTestmessage(r) + } + } + if !easy && r.Intn(10) != 0 { + this.XXX_unrecognized = randUnrecognizedTestmessage(r, 2) + } + return this +} + +func NewPopulatedLargeMessage(r randyTestmessage, easy bool) *LargeMessage { + this := &LargeMessage{} + if r.Intn(10) != 0 { + v4 := r.Intn(10) + this.Values = make([]string, v4) + for i := 0; i < v4; i++ { + this.Values[i] = randStringTestmessage(r) + } + } + if !easy && r.Intn(10) != 0 { + this.XXX_unrecognized = randUnrecognizedTestmessage(r, 2) + } + return this +} + +type randyTestmessage interface { + Float32() float32 + Float64() float64 + Int63() int64 + Int31() int32 + Uint32() uint32 + Intn(n int) int +} + +func randUTF8RuneTestmessage(r randyTestmessage) rune { + res := rune(r.Uint32() % 1112064) + if 55296 <= res { + res += 2047 + } + return res +} +func randStringTestmessage(r randyTestmessage) string { + v5 := r.Intn(100) + tmps := make([]rune, v5) + for i := 0; i < v5; i++ { + tmps[i] = randUTF8RuneTestmessage(r) + } + return string(tmps) +} +func randUnrecognizedTestmessage(r randyTestmessage, maxFieldNumber int) (data []byte) { + l := r.Intn(5) + for i := 0; i < l; i++ { + wire := r.Intn(4) + if wire == 3 { + wire = 5 + } + fieldNumber := maxFieldNumber + r.Intn(100) + data = randFieldTestmessage(data, r, fieldNumber, wire) + } + return data +} +func randFieldTestmessage(data []byte, r randyTestmessage, fieldNumber int, wire int) []byte { + key := uint32(fieldNumber)<<3 | uint32(wire) + switch wire { + case 0: + data = encodeVarintPopulateTestmessage(data, uint64(key)) + v6 := r.Int63() + if r.Intn(2) == 0 { + v6 *= -1 + } + data = encodeVarintPopulateTestmessage(data, uint64(v6)) + case 1: + data = encodeVarintPopulateTestmessage(data, uint64(key)) + data = append(data, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256))) + case 2: + data = encodeVarintPopulateTestmessage(data, uint64(key)) + ll := r.Intn(100) + data = encodeVarintPopulateTestmessage(data, uint64(ll)) + for j := 0; j < ll; j++ { + data = append(data, byte(r.Intn(256))) + } + default: + data = encodeVarintPopulateTestmessage(data, uint64(key)) + data = append(data, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256))) + } + return data +} +func encodeVarintPopulateTestmessage(data []byte, v uint64) []byte { + for v >= 1<<7 { + data = append(data, uint8(uint64(v)&0x7f|0x80)) + v >>= 7 + } + data = append(data, uint8(v)) + return data +} +func (m *SmallMessage) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *SmallMessage) MarshalTo(data []byte) (n int, err error) { + var i int + _ = i + var l int + _ = l + if len(m.Values) > 0 { + for _, s := range m.Values { + data[i] = 0xa + i++ + l = len(s) + for l >= 1<<7 { + data[i] = uint8(uint64(l)&0x7f | 0x80) + l >>= 7 + i++ + } + data[i] = uint8(l) + i++ + i += copy(data[i:], s) + } + } + if m.XXX_unrecognized != nil { + i += copy(data[i:], m.XXX_unrecognized) + } + return i, nil +} + +func (m *MediumMessage) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *MediumMessage) MarshalTo(data []byte) (n int, err error) { + var i int + _ = i + var l int + _ = l + if len(m.Values) > 0 { + for _, s := range m.Values { + data[i] = 0xa + i++ + l = len(s) + for l >= 1<<7 { + data[i] = uint8(uint64(l)&0x7f | 0x80) + l >>= 7 + i++ + } + data[i] = uint8(l) + i++ + i += copy(data[i:], s) + } + } + if m.XXX_unrecognized != nil { + i += copy(data[i:], m.XXX_unrecognized) + } + return i, nil +} + +func (m *BigMessage) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *BigMessage) MarshalTo(data []byte) (n int, err error) { + var i int + _ = i + var l int + _ = l + if len(m.Values) > 0 { + for _, s := range m.Values { + data[i] = 0xa + i++ + l = len(s) + for l >= 1<<7 { + data[i] = uint8(uint64(l)&0x7f | 0x80) + l >>= 7 + i++ + } + data[i] = uint8(l) + i++ + i += copy(data[i:], s) + } + } + if m.XXX_unrecognized != nil { + i += copy(data[i:], m.XXX_unrecognized) + } + return i, nil +} + +func (m *LargeMessage) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *LargeMessage) MarshalTo(data []byte) (n int, err error) { + var i int + _ = i + var l int + _ = l + if len(m.Values) > 0 { + for _, s := range m.Values { + data[i] = 0xa + i++ + l = len(s) + for l >= 1<<7 { + data[i] = uint8(uint64(l)&0x7f | 0x80) + l >>= 7 + i++ + } + data[i] = uint8(l) + i++ + i += copy(data[i:], s) + } + } + if m.XXX_unrecognized != nil { + i += copy(data[i:], m.XXX_unrecognized) + } + return i, nil +} + +func encodeFixed64Testmessage(data []byte, offset int, v uint64) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + data[offset+4] = uint8(v >> 32) + data[offset+5] = uint8(v >> 40) + data[offset+6] = uint8(v >> 48) + data[offset+7] = uint8(v >> 56) + return offset + 8 +} +func encodeFixed32Testmessage(data []byte, offset int, v uint32) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + return offset + 4 +} +func encodeVarintTestmessage(data []byte, offset int, v uint64) int { + for v >= 1<<7 { + data[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + data[offset] = uint8(v) + return offset + 1 +} +func (this *SmallMessage) GoString() string { + if this == nil { + return "nil" + } + s := strings1.Join([]string{`&testmessage.SmallMessage{` + + `Values:` + fmt2.Sprintf("%#v", this.Values), + `XXX_unrecognized:` + fmt2.Sprintf("%#v", this.XXX_unrecognized) + `}`}, ", ") + return s +} +func (this *MediumMessage) GoString() string { + if this == nil { + return "nil" + } + s := strings1.Join([]string{`&testmessage.MediumMessage{` + + `Values:` + fmt2.Sprintf("%#v", this.Values), + `XXX_unrecognized:` + fmt2.Sprintf("%#v", this.XXX_unrecognized) + `}`}, ", ") + return s +} +func (this *BigMessage) GoString() string { + if this == nil { + return "nil" + } + s := strings1.Join([]string{`&testmessage.BigMessage{` + + `Values:` + fmt2.Sprintf("%#v", this.Values), + `XXX_unrecognized:` + fmt2.Sprintf("%#v", this.XXX_unrecognized) + `}`}, ", ") + return s +} +func (this *LargeMessage) GoString() string { + if this == nil { + return "nil" + } + s := strings1.Join([]string{`&testmessage.LargeMessage{` + + `Values:` + fmt2.Sprintf("%#v", this.Values), + `XXX_unrecognized:` + fmt2.Sprintf("%#v", this.XXX_unrecognized) + `}`}, ", ") + return s +} +func valueToGoStringTestmessage(v interface{}, typ string) string { + rv := reflect1.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect1.Indirect(rv).Interface() + return fmt2.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) +} +func extensionToGoStringTestmessage(e map[int32]github_com_gogo_protobuf_proto1.Extension) string { + if e == nil { + return "nil" + } + s := "map[int32]proto.Extension{" + keys := make([]int, 0, len(e)) + for k := range e { + keys = append(keys, int(k)) + } + sort.Ints(keys) + ss := []string{} + for _, k := range keys { + ss = append(ss, strconv.Itoa(k)+": "+e[int32(k)].GoString()) + } + s += strings1.Join(ss, ",") + "}" + return s +} +func (this *SmallMessage) VerboseEqual(that interface{}) error { + if that == nil { + if this == nil { + return nil + } + return fmt3.Errorf("that == nil && this != nil") + } + + that1, ok := that.(*SmallMessage) + if !ok { + return fmt3.Errorf("that is not of type *SmallMessage") + } + if that1 == nil { + if this == nil { + return nil + } + return fmt3.Errorf("that is type *SmallMessage but is nil && this != nil") + } else if this == nil { + return fmt3.Errorf("that is type *SmallMessagebut is not nil && this == nil") + } + if len(this.Values) != len(that1.Values) { + return fmt3.Errorf("Values this(%v) Not Equal that(%v)", len(this.Values), len(that1.Values)) + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return fmt3.Errorf("Values this[%v](%v) Not Equal that[%v](%v)", i, this.Values[i], i, that1.Values[i]) + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return fmt3.Errorf("XXX_unrecognized this(%v) Not Equal that(%v)", this.XXX_unrecognized, that1.XXX_unrecognized) + } + return nil +} +func (this *SmallMessage) Equal(that interface{}) bool { + if that == nil { + if this == nil { + return true + } + return false + } + + that1, ok := that.(*SmallMessage) + if !ok { + return false + } + if that1 == nil { + if this == nil { + return true + } + return false + } else if this == nil { + return false + } + if len(this.Values) != len(that1.Values) { + return false + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return false + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return false + } + return true +} +func (this *MediumMessage) VerboseEqual(that interface{}) error { + if that == nil { + if this == nil { + return nil + } + return fmt3.Errorf("that == nil && this != nil") + } + + that1, ok := that.(*MediumMessage) + if !ok { + return fmt3.Errorf("that is not of type *MediumMessage") + } + if that1 == nil { + if this == nil { + return nil + } + return fmt3.Errorf("that is type *MediumMessage but is nil && this != nil") + } else if this == nil { + return fmt3.Errorf("that is type *MediumMessagebut is not nil && this == nil") + } + if len(this.Values) != len(that1.Values) { + return fmt3.Errorf("Values this(%v) Not Equal that(%v)", len(this.Values), len(that1.Values)) + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return fmt3.Errorf("Values this[%v](%v) Not Equal that[%v](%v)", i, this.Values[i], i, that1.Values[i]) + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return fmt3.Errorf("XXX_unrecognized this(%v) Not Equal that(%v)", this.XXX_unrecognized, that1.XXX_unrecognized) + } + return nil +} +func (this *MediumMessage) Equal(that interface{}) bool { + if that == nil { + if this == nil { + return true + } + return false + } + + that1, ok := that.(*MediumMessage) + if !ok { + return false + } + if that1 == nil { + if this == nil { + return true + } + return false + } else if this == nil { + return false + } + if len(this.Values) != len(that1.Values) { + return false + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return false + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return false + } + return true +} +func (this *BigMessage) VerboseEqual(that interface{}) error { + if that == nil { + if this == nil { + return nil + } + return fmt3.Errorf("that == nil && this != nil") + } + + that1, ok := that.(*BigMessage) + if !ok { + return fmt3.Errorf("that is not of type *BigMessage") + } + if that1 == nil { + if this == nil { + return nil + } + return fmt3.Errorf("that is type *BigMessage but is nil && this != nil") + } else if this == nil { + return fmt3.Errorf("that is type *BigMessagebut is not nil && this == nil") + } + if len(this.Values) != len(that1.Values) { + return fmt3.Errorf("Values this(%v) Not Equal that(%v)", len(this.Values), len(that1.Values)) + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return fmt3.Errorf("Values this[%v](%v) Not Equal that[%v](%v)", i, this.Values[i], i, that1.Values[i]) + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return fmt3.Errorf("XXX_unrecognized this(%v) Not Equal that(%v)", this.XXX_unrecognized, that1.XXX_unrecognized) + } + return nil +} +func (this *BigMessage) Equal(that interface{}) bool { + if that == nil { + if this == nil { + return true + } + return false + } + + that1, ok := that.(*BigMessage) + if !ok { + return false + } + if that1 == nil { + if this == nil { + return true + } + return false + } else if this == nil { + return false + } + if len(this.Values) != len(that1.Values) { + return false + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return false + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return false + } + return true +} +func (this *LargeMessage) VerboseEqual(that interface{}) error { + if that == nil { + if this == nil { + return nil + } + return fmt3.Errorf("that == nil && this != nil") + } + + that1, ok := that.(*LargeMessage) + if !ok { + return fmt3.Errorf("that is not of type *LargeMessage") + } + if that1 == nil { + if this == nil { + return nil + } + return fmt3.Errorf("that is type *LargeMessage but is nil && this != nil") + } else if this == nil { + return fmt3.Errorf("that is type *LargeMessagebut is not nil && this == nil") + } + if len(this.Values) != len(that1.Values) { + return fmt3.Errorf("Values this(%v) Not Equal that(%v)", len(this.Values), len(that1.Values)) + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return fmt3.Errorf("Values this[%v](%v) Not Equal that[%v](%v)", i, this.Values[i], i, that1.Values[i]) + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return fmt3.Errorf("XXX_unrecognized this(%v) Not Equal that(%v)", this.XXX_unrecognized, that1.XXX_unrecognized) + } + return nil +} +func (this *LargeMessage) Equal(that interface{}) bool { + if that == nil { + if this == nil { + return true + } + return false + } + + that1, ok := that.(*LargeMessage) + if !ok { + return false + } + if that1 == nil { + if this == nil { + return true + } + return false + } else if this == nil { + return false + } + if len(this.Values) != len(that1.Values) { + return false + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return false + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return false + } + return true +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/testmessage.proto b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/testmessage.proto new file mode 100644 index 00000000000..b1fa57fbdec --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/testmessage/testmessage.proto @@ -0,0 +1,31 @@ +package testmessage; + +import "github.com/gogo/protobuf/gogoproto/gogo.proto"; + +option (gogoproto.gostring_all) = true; +option (gogoproto.equal_all) = true; +option (gogoproto.verbose_equal_all) = true; +option (gogoproto.goproto_stringer_all) = false; +option (gogoproto.stringer_all) = true; +option (gogoproto.populate_all) = true; +option (gogoproto.testgen_all) = false; +option (gogoproto.benchgen_all) = false; +option (gogoproto.marshaler_all) = true; +option (gogoproto.sizer_all) = true; +option (gogoproto.unmarshaler_all) = true; + +message SmallMessage { + repeated string Values = 1; +} + +message MediumMessage { + repeated string Values = 1; +} + +message BigMessage { + repeated string Values = 1; +} + +message LargeMessage { + repeated string Values = 1; +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/transporter.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/transporter.go new file mode 100644 index 00000000000..7d920c08b02 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/messenger/transporter.go @@ -0,0 +1,53 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package messenger + +import ( + "github.com/mesos/mesos-go/upid" + "golang.org/x/net/context" +) + +// Transporter defines methods for communicating with remote processes. +type Transporter interface { + //Send sends message to remote process. Must use context to determine + //cancelled requests. Will stop sending when transport is stopped. + Send(ctx context.Context, msg *Message) error + + //Rcvd receives and delegate message handling to installed handlers. + //Will stop receiving when transport is stopped. + Recv() (*Message, error) + + //Inject injects a message to the incoming queue. Must use context to + //determine cancelled requests. Injection is aborted if the transport + //is stopped. + Inject(ctx context.Context, msg *Message) error + + //Install mount an handler based on incoming message name. + Install(messageName string) + + //Start starts the transporter and returns immediately. The error chan + //is never nil. + Start() <-chan error + + //Stop kills the transporter. + Stop(graceful bool) error + + //UPID returns the PID for transporter. + UPID() *upid.UPID +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/doc.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/doc.go new file mode 100644 index 00000000000..94cfbacd6af --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/doc.go @@ -0,0 +1,6 @@ +/* +Package scheduler includes the interfaces for the mesos scheduler and +the mesos executor driver. It also contains as well as an implementation +of the driver that you can use in your code. +*/ +package scheduler diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/handler.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/handler.go new file mode 100644 index 00000000000..fc7fe6ab8ad --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/handler.go @@ -0,0 +1,29 @@ +package scheduler + +import ( + "github.com/mesos/mesos-go/auth/callback" + mesos "github.com/mesos/mesos-go/mesosproto" + "github.com/mesos/mesos-go/upid" +) + +type CredentialHandler struct { + pid *upid.UPID // the process to authenticate against (master) + client *upid.UPID // the process to be authenticated (slave / framework) + credential *mesos.Credential +} + +func (h *CredentialHandler) Handle(callbacks ...callback.Interface) error { + for _, cb := range callbacks { + switch cb := cb.(type) { + case *callback.Name: + cb.Set(h.credential.GetPrincipal()) + case *callback.Password: + cb.Set(h.credential.GetSecret()) + case *callback.Interprocess: + cb.Set(*(h.pid), *(h.client)) + default: + return &callback.Unsupported{Callback: cb} + } + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/mock_scheduler.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/mock_scheduler.go new file mode 100644 index 00000000000..9cfe54d343f --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/mock_scheduler.go @@ -0,0 +1,56 @@ +package scheduler + +import ( + log "github.com/golang/glog" + mesos "github.com/mesos/mesos-go/mesosproto" + "github.com/stretchr/testify/mock" +) + +type MockScheduler struct { + mock.Mock +} + +func NewMockScheduler() *MockScheduler { + return &MockScheduler{} +} + +func (sched *MockScheduler) Registered(SchedulerDriver, *mesos.FrameworkID, *mesos.MasterInfo) { + sched.Called() +} + +func (sched *MockScheduler) Reregistered(SchedulerDriver, *mesos.MasterInfo) { + sched.Called() +} + +func (sched *MockScheduler) Disconnected(SchedulerDriver) { + sched.Called() +} + +func (sched *MockScheduler) ResourceOffers(SchedulerDriver, []*mesos.Offer) { + sched.Called() +} + +func (sched *MockScheduler) OfferRescinded(SchedulerDriver, *mesos.OfferID) { + sched.Called() +} + +func (sched *MockScheduler) StatusUpdate(SchedulerDriver, *mesos.TaskStatus) { + sched.Called() +} + +func (sched *MockScheduler) FrameworkMessage(SchedulerDriver, *mesos.ExecutorID, *mesos.SlaveID, string) { + sched.Called() +} + +func (sched *MockScheduler) SlaveLost(SchedulerDriver, *mesos.SlaveID) { + sched.Called() +} + +func (sched *MockScheduler) ExecutorLost(SchedulerDriver, *mesos.ExecutorID, *mesos.SlaveID, int) { + sched.Called() +} + +func (sched *MockScheduler) Error(d SchedulerDriver, msg string) { + log.Error(msg) + sched.Called() +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/plugins.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/plugins.go new file mode 100644 index 00000000000..0054bbdd977 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/plugins.go @@ -0,0 +1,7 @@ +package scheduler + +import ( + _ "github.com/mesos/mesos-go/auth/sasl" + _ "github.com/mesos/mesos-go/auth/sasl/mech/crammd5" + _ "github.com/mesos/mesos-go/detector/zoo" +) diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedcache.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedcache.go new file mode 100644 index 00000000000..5644623223b --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedcache.go @@ -0,0 +1,96 @@ +package scheduler + +import ( + log "github.com/golang/glog" + mesos "github.com/mesos/mesos-go/mesosproto" + "github.com/mesos/mesos-go/upid" + "sync" +) + +type cachedOffer struct { + offer *mesos.Offer + slavePid *upid.UPID +} + +func newCachedOffer(offer *mesos.Offer, slavePid *upid.UPID) *cachedOffer { + return &cachedOffer{offer: offer, slavePid: slavePid} +} + +// schedCache a managed cache with backing maps to store offeres +// and tasked slaves. +type schedCache struct { + lock sync.RWMutex + savedOffers map[string]*cachedOffer // current offers key:OfferID + savedSlavePids map[string]*upid.UPID // Current saved slaves, key:slaveId +} + +func newSchedCache() *schedCache { + return &schedCache{ + savedOffers: make(map[string]*cachedOffer), + savedSlavePids: make(map[string]*upid.UPID), + } +} + +// putOffer stores an offer and the slavePID associated with offer. +func (cache *schedCache) putOffer(offer *mesos.Offer, pid *upid.UPID) { + if offer == nil || pid == nil { + log.V(3).Infoln("WARN: Offer not cached. The offer or pid cannot be nil") + return + } + log.V(3).Infoln("Caching offer ", offer.Id.GetValue(), " with slavePID ", pid.String()) + cache.lock.Lock() + cache.savedOffers[offer.Id.GetValue()] = &cachedOffer{offer: offer, slavePid: pid} + cache.lock.Unlock() +} + +// getOffer returns cached offer +func (cache *schedCache) getOffer(offerId *mesos.OfferID) *cachedOffer { + if offerId == nil { + log.V(3).Infoln("WARN: OfferId == nil, returning nil") + return nil + } + cache.lock.RLock() + defer cache.lock.RUnlock() + return cache.savedOffers[offerId.GetValue()] +} + +// containsOff test cache for offer(offerId) +func (cache *schedCache) containsOffer(offerId *mesos.OfferID) bool { + cache.lock.RLock() + defer cache.lock.RUnlock() + _, ok := cache.savedOffers[offerId.GetValue()] + return ok +} + +func (cache *schedCache) removeOffer(offerId *mesos.OfferID) { + cache.lock.Lock() + delete(cache.savedOffers, offerId.GetValue()) + cache.lock.Unlock() +} + +func (cache *schedCache) putSlavePid(slaveId *mesos.SlaveID, pid *upid.UPID) { + cache.lock.Lock() + cache.savedSlavePids[slaveId.GetValue()] = pid + cache.lock.Unlock() +} + +func (cache *schedCache) getSlavePid(slaveId *mesos.SlaveID) *upid.UPID { + if slaveId == nil { + log.V(3).Infoln("SlaveId == nil, returning empty UPID") + return nil + } + return cache.savedSlavePids[slaveId.GetValue()] +} + +func (cache *schedCache) containsSlavePid(slaveId *mesos.SlaveID) bool { + cache.lock.RLock() + defer cache.lock.RUnlock() + _, ok := cache.savedSlavePids[slaveId.GetValue()] + return ok +} + +func (cache *schedCache) removeSlavePid(slaveId *mesos.SlaveID) { + cache.lock.Lock() + delete(cache.savedSlavePids, slaveId.GetValue()) + cache.lock.Unlock() +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedcache_test.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedcache_test.go new file mode 100644 index 00000000000..4a3a46e5c4b --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedcache_test.go @@ -0,0 +1,215 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package scheduler + +import ( + mesos "github.com/mesos/mesos-go/mesosproto" + util "github.com/mesos/mesos-go/mesosutil" + "github.com/stretchr/testify/assert" + "testing" + + "github.com/mesos/mesos-go/upid" +) + +func TestSchedCacheNew(t *testing.T) { + cache := newSchedCache() + assert.NotNil(t, cache) + assert.NotNil(t, cache.savedOffers) + assert.NotNil(t, cache.savedSlavePids) +} + +func TestSchedCachePutOffer(t *testing.T) { + cache := newSchedCache() + + offer01 := createTestOffer("01") + pid01, err := upid.Parse("slave01@127.0.0.1:5050") + assert.NoError(t, err) + cache.putOffer(offer01, pid01) + + offer02 := createTestOffer("02") + pid02, err := upid.Parse("slave02@127.0.0.1:5050") + assert.NoError(t, err) + cache.putOffer(offer02, pid02) + + assert.Equal(t, len(cache.savedOffers), 2) + cachedOffer1, ok := cache.savedOffers["test-offer-01"] + assert.True(t, ok) + cachedOffer2, ok := cache.savedOffers["test-offer-02"] + assert.True(t, ok) + + assert.NotNil(t, cachedOffer1.offer) + assert.Equal(t, "test-offer-01", cachedOffer1.offer.Id.GetValue()) + assert.NotNil(t, cachedOffer2.offer) + assert.Equal(t, "test-offer-02", cachedOffer2.offer.Id.GetValue()) + + assert.NotNil(t, cachedOffer1.slavePid) + assert.Equal(t, "slave01@127.0.0.1:5050", cachedOffer1.slavePid.String()) + assert.NotNil(t, cachedOffer2.slavePid) + assert.Equal(t, "slave02@127.0.0.1:5050", cachedOffer2.slavePid.String()) + +} + +func TestSchedCacheGetOffer(t *testing.T) { + cache := newSchedCache() + offer01 := createTestOffer("01") + pid01, err := upid.Parse("slave01@127.0.0.1:5050") + assert.NoError(t, err) + offer02 := createTestOffer("02") + pid02, err := upid.Parse("slave02@127.0.0.1:5050") + assert.NoError(t, err) + + cache.putOffer(offer01, pid01) + cache.putOffer(offer02, pid02) + + cachedOffer01 := cache.getOffer(util.NewOfferID("test-offer-01")).offer + cachedOffer02 := cache.getOffer(util.NewOfferID("test-offer-02")).offer + assert.NotEqual(t, offer01, cachedOffer02) + assert.Equal(t, offer01, cachedOffer01) + assert.Equal(t, offer02, cachedOffer02) + +} + +func TestSchedCacheContainsOffer(t *testing.T) { + cache := newSchedCache() + offer01 := createTestOffer("01") + pid01, err := upid.Parse("slave01@127.0.0.1:5050") + assert.NoError(t, err) + offer02 := createTestOffer("02") + pid02, err := upid.Parse("slave02@127.0.0.1:5050") + assert.NoError(t, err) + + cache.putOffer(offer01, pid01) + cache.putOffer(offer02, pid02) + + assert.True(t, cache.containsOffer(util.NewOfferID("test-offer-01"))) + assert.True(t, cache.containsOffer(util.NewOfferID("test-offer-02"))) + assert.False(t, cache.containsOffer(util.NewOfferID("test-offer-05"))) +} + +func TestSchedCacheRemoveOffer(t *testing.T) { + cache := newSchedCache() + offer01 := createTestOffer("01") + pid01, err := upid.Parse("slave01@127.0.0.1:5050") + assert.NoError(t, err) + offer02 := createTestOffer("02") + pid02, err := upid.Parse("slave02@127.0.0.1:5050") + assert.NoError(t, err) + + cache.putOffer(offer01, pid01) + cache.putOffer(offer02, pid02) + cache.removeOffer(util.NewOfferID("test-offer-01")) + + assert.Equal(t, 1, len(cache.savedOffers)) + assert.True(t, cache.containsOffer(util.NewOfferID("test-offer-02"))) + assert.False(t, cache.containsOffer(util.NewOfferID("test-offer-01"))) +} + +func TestSchedCachePutSlavePid(t *testing.T) { + cache := newSchedCache() + + pid01, err := upid.Parse("slave01@127.0.0.1:5050") + assert.NoError(t, err) + pid02, err := upid.Parse("slave02@127.0.0.1:5050") + assert.NoError(t, err) + pid03, err := upid.Parse("slave03@127.0.0.1:5050") + assert.NoError(t, err) + + cache.putSlavePid(util.NewSlaveID("slave01"), pid01) + cache.putSlavePid(util.NewSlaveID("slave02"), pid02) + cache.putSlavePid(util.NewSlaveID("slave03"), pid03) + + assert.Equal(t, len(cache.savedSlavePids), 3) + cachedSlavePid1, ok := cache.savedSlavePids["slave01"] + assert.True(t, ok) + cachedSlavePid2, ok := cache.savedSlavePids["slave02"] + assert.True(t, ok) + cachedSlavePid3, ok := cache.savedSlavePids["slave03"] + assert.True(t, ok) + + assert.True(t, cachedSlavePid1.Equal(pid01)) + assert.True(t, cachedSlavePid2.Equal(pid02)) + assert.True(t, cachedSlavePid3.Equal(pid03)) +} + +func TestSchedCacheGetSlavePid(t *testing.T) { + cache := newSchedCache() + + pid01, err := upid.Parse("slave01@127.0.0.1:5050") + assert.NoError(t, err) + pid02, err := upid.Parse("slave02@127.0.0.1:5050") + assert.NoError(t, err) + + cache.putSlavePid(util.NewSlaveID("slave01"), pid01) + cache.putSlavePid(util.NewSlaveID("slave02"), pid02) + + cachedSlavePid1 := cache.getSlavePid(util.NewSlaveID("slave01")) + cachedSlavePid2 := cache.getSlavePid(util.NewSlaveID("slave02")) + + assert.NotNil(t, cachedSlavePid1) + assert.NotNil(t, cachedSlavePid2) + assert.True(t, pid01.Equal(cachedSlavePid1)) + assert.True(t, pid02.Equal(cachedSlavePid2)) + assert.False(t, pid01.Equal(cachedSlavePid2)) +} + +func TestSchedCacheContainsSlavePid(t *testing.T) { + cache := newSchedCache() + + pid01, err := upid.Parse("slave01@127.0.0.1:5050") + assert.NoError(t, err) + pid02, err := upid.Parse("slave02@127.0.0.1:5050") + assert.NoError(t, err) + + cache.putSlavePid(util.NewSlaveID("slave01"), pid01) + cache.putSlavePid(util.NewSlaveID("slave02"), pid02) + + assert.True(t, cache.containsSlavePid(util.NewSlaveID("slave01"))) + assert.True(t, cache.containsSlavePid(util.NewSlaveID("slave02"))) + assert.False(t, cache.containsSlavePid(util.NewSlaveID("slave05"))) +} + +func TestSchedCacheRemoveSlavePid(t *testing.T) { + cache := newSchedCache() + + pid01, err := upid.Parse("slave01@127.0.0.1:5050") + assert.NoError(t, err) + pid02, err := upid.Parse("slave02@127.0.0.1:5050") + assert.NoError(t, err) + + cache.putSlavePid(util.NewSlaveID("slave01"), pid01) + cache.putSlavePid(util.NewSlaveID("slave02"), pid02) + + assert.True(t, cache.containsSlavePid(util.NewSlaveID("slave01"))) + assert.True(t, cache.containsSlavePid(util.NewSlaveID("slave02"))) + assert.False(t, cache.containsSlavePid(util.NewSlaveID("slave05"))) + + cache.removeSlavePid(util.NewSlaveID("slave01")) + assert.Equal(t, 1, len(cache.savedSlavePids)) + assert.False(t, cache.containsSlavePid(util.NewSlaveID("slave01"))) + +} + +func createTestOffer(idSuffix string) *mesos.Offer { + return util.NewOffer( + util.NewOfferID("test-offer-"+idSuffix), + util.NewFrameworkID("test-framework-"+idSuffix), + util.NewSlaveID("test-slave-"+idSuffix), + "localhost."+idSuffix, + ) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedtype.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedtype.go new file mode 100644 index 00000000000..b7634efa9e1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/schedtype.go @@ -0,0 +1,191 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package scheduler + +import ( + mesos "github.com/mesos/mesos-go/mesosproto" +) + +// Interface for connecting a scheduler to Mesos. This +// interface is used both to manage the scheduler's lifecycle (start +// it, stop it, or wait for it to finish) and to interact with Mesos +// (e.g., launch tasks, kill tasks, etc.). +// See the MesosSchedulerDriver type for a concrete +// impl of a SchedulerDriver. +type SchedulerDriver interface { + // Starts the scheduler driver. This needs to be called before any + // other driver calls are made. + Start() (mesos.Status, error) + + // Stops the scheduler driver. If the 'failover' flag is set to + // false then it is expected that this framework will never + // reconnect to Mesos and all of its executors and tasks can be + // terminated. Otherwise, all executors and tasks will remain + // running (for some framework specific failover timeout) allowing the + // scheduler to reconnect (possibly in the same process, or from a + // different process, for example, on a different machine). + Stop(failover bool) (mesos.Status, error) + + // Aborts the driver so that no more callbacks can be made to the + // scheduler. The semantics of abort and stop have deliberately been + // separated so that code can detect an aborted driver (i.e., via + // the return status of SchedulerDriver::join, see below), and + // instantiate and start another driver if desired (from within the + // same process). Note that 'Stop()' is not automatically called + // inside 'Abort()'. + Abort() (mesos.Status, error) + + // Waits for the driver to be stopped or aborted, possibly + // _blocking_ the current thread indefinitely. The return status of + // this function can be used to determine if the driver was aborted + // (see mesos.proto for a description of Status). + Join() (mesos.Status, error) + + // Starts and immediately joins (i.e., blocks on) the driver. + Run() (mesos.Status, error) + + // Requests resources from Mesos (see mesos.proto for a description + // of Request and how, for example, to request resources + // from specific slaves). Any resources available are offered to the + // framework via Scheduler.ResourceOffers callback, asynchronously. + RequestResources(requests []*mesos.Request) (mesos.Status, error) + + // Launches the given set of tasks. Any resources remaining (i.e., + // not used by the tasks or their executors) will be considered + // declined. The specified filters are applied on all unused + // resources (see mesos.proto for a description of Filters). + // Available resources are aggregated when mutiple offers are + // provided. Note that all offers must belong to the same slave. + // Invoking this function with an empty collection of tasks declines + // offers in their entirety (see Scheduler::declineOffer). + LaunchTasks(offerIDs []*mesos.OfferID, tasks []*mesos.TaskInfo, filters *mesos.Filters) (mesos.Status, error) + + // Kills the specified task. Note that attempting to kill a task is + // currently not reliable. If, for example, a scheduler fails over + // while it was attempting to kill a task it will need to retry in + // the future. Likewise, if unregistered / disconnected, the request + // will be dropped (these semantics may be changed in the future). + KillTask(taskID *mesos.TaskID) (mesos.Status, error) + + // Declines an offer in its entirety and applies the specified + // filters on the resources (see mesos.proto for a description of + // Filters). Note that this can be done at any time, it is not + // necessary to do this within the Scheduler::resourceOffers + // callback. + DeclineOffer(offerID *mesos.OfferID, filters *mesos.Filters) (mesos.Status, error) + + // Removes all filters previously set by the framework (via + // LaunchTasks()). This enables the framework to receive offers from + // those filtered slaves. + ReviveOffers() (mesos.Status, error) + + // Sends a message from the framework to one of its executors. These + // messages are best effort; do not expect a framework message to be + // retransmitted in any reliable fashion. + SendFrameworkMessage(executorID *mesos.ExecutorID, slaveID *mesos.SlaveID, data string) (mesos.Status, error) + + // Allows the framework to query the status for non-terminal tasks. + // This causes the master to send back the latest task status for + // each task in 'statuses', if possible. Tasks that are no longer + // known will result in a TASK_LOST update. If statuses is empty, + // then the master will send the latest status for each task + // currently known. + ReconcileTasks(statuses []*mesos.TaskStatus) (mesos.Status, error) +} + +// Scheduler a type with callback attributes to be provided by frameworks +// schedulers. +// +// Each callback includes a reference to the scheduler driver that was +// used to run this scheduler. The pointer will not change for the +// duration of a scheduler (i.e., from the point you do +// SchedulerDriver.Start() to the point that SchedulerDriver.Stop() +// returns). This is intended for convenience so that a scheduler +// doesn't need to store a reference to the driver itself. +type Scheduler interface { + + // Invoked when the scheduler successfully registers with a Mesos + // master. A unique ID (generated by the master) used for + // distinguishing this framework from others and MasterInfo + // with the ip and port of the current master are provided as arguments. + Registered(SchedulerDriver, *mesos.FrameworkID, *mesos.MasterInfo) + + // Invoked when the scheduler re-registers with a newly elected Mesos master. + // This is only called when the scheduler has previously been registered. + // MasterInfo containing the updated information about the elected master + // is provided as an argument. + Reregistered(SchedulerDriver, *mesos.MasterInfo) + + // Invoked when the scheduler becomes "disconnected" from the master + // (e.g., the master fails and another is taking over). + Disconnected(SchedulerDriver) + + // Invoked when resources have been offered to this framework. A + // single offer will only contain resources from a single slave. + // Resources associated with an offer will not be re-offered to + // _this_ framework until either (a) this framework has rejected + // those resources (see SchedulerDriver::launchTasks) or (b) those + // resources have been rescinded (see Scheduler::offerRescinded). + // Note that resources may be concurrently offered to more than one + // framework at a time (depending on the allocator being used). In + // that case, the first framework to launch tasks using those + // resources will be able to use them while the other frameworks + // will have those resources rescinded (or if a framework has + // already launched tasks with those resources then those tasks will + // fail with a TASK_LOST status and a message saying as much). + ResourceOffers(SchedulerDriver, []*mesos.Offer) + + // Invoked when an offer is no longer valid (e.g., the slave was + // lost or another framework used resources in the offer). If for + // whatever reason an offer is never rescinded (e.g., dropped + // message, failing over framework, etc.), a framwork that attempts + // to launch tasks using an invalid offer will receive TASK_LOST + // status updates for those tasks (see Scheduler::resourceOffers). + OfferRescinded(SchedulerDriver, *mesos.OfferID) + + // Invoked when the status of a task has changed (e.g., a slave is + // lost and so the task is lost, a task finishes and an executor + // sends a status update saying so, etc). Note that returning from + // this callback _acknowledges_ receipt of this status update! If + // for whatever reason the scheduler aborts during this callback (or + // the process exits) another status update will be delivered (note, + // however, that this is currently not true if the slave sending the + // status update is lost/fails during that time). + StatusUpdate(SchedulerDriver, *mesos.TaskStatus) + + // Invoked when an executor sends a message. These messages are best + // effort; do not expect a framework message to be retransmitted in + // any reliable fashion. + FrameworkMessage(SchedulerDriver, *mesos.ExecutorID, *mesos.SlaveID, string) + + // Invoked when a slave has been determined unreachable (e.g., + // machine failure, network partition). Most frameworks will need to + // reschedule any tasks launched on this slave on a new slave. + SlaveLost(SchedulerDriver, *mesos.SlaveID) + + // Invoked when an executor has exited/terminated. Note that any + // tasks running will have TASK_LOST status updates automagically + // generated. + ExecutorLost(SchedulerDriver, *mesos.ExecutorID, *mesos.SlaveID, int) + + // Invoked when there is an unrecoverable error in the scheduler or + // scheduler driver. The driver will be aborted BEFORE invoking this + // callback. + Error(SchedulerDriver, string) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler.go new file mode 100644 index 00000000000..439e9977f87 --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler.go @@ -0,0 +1,1105 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package scheduler + +import ( + "errors" + "fmt" + "math" + "math/rand" + "net" + "os/user" + "sync" + "time" + + "code.google.com/p/go-uuid/uuid" + "github.com/gogo/protobuf/proto" + log "github.com/golang/glog" + "github.com/mesos/mesos-go/auth" + "github.com/mesos/mesos-go/detector" + mesos "github.com/mesos/mesos-go/mesosproto" + util "github.com/mesos/mesos-go/mesosutil" + "github.com/mesos/mesos-go/mesosutil/process" + "github.com/mesos/mesos-go/messenger" + "github.com/mesos/mesos-go/upid" + "golang.org/x/net/context" +) + +const ( + authTimeout = 5 * time.Second // timeout interval for an authentication attempt + registrationRetryIntervalMax = float64(1 * time.Minute) + registrationBackoffFactor = 2 * time.Second +) + +var ( + authenticationCanceledError = errors.New("authentication canceled") +) + +// helper to track authentication progress and to prevent multiple close() ops +// against a signalling chan. it's safe to invoke the func's of this struct +// even if the receiver pointer is nil. +type authenticationAttempt struct { + done chan struct{} + doneOnce sync.Once +} + +func (a *authenticationAttempt) cancel() { + if a != nil { + a.doneOnce.Do(func() { close(a.done) }) + } +} + +func (a *authenticationAttempt) inProgress() bool { + if a != nil { + select { + case <-a.done: + return false + default: + return true + } + } + return false +} + +type DriverConfig struct { + Scheduler Scheduler + Framework *mesos.FrameworkInfo + Master string + Credential *mesos.Credential // optional + WithAuthContext func(context.Context) context.Context // required when Credential != nil + HostnameOverride string // optional + BindingAddress net.IP // optional + BindingPort uint16 // optional + NewMessenger func() (messenger.Messenger, error) // optional +} + +// Concrete implementation of a SchedulerDriver that connects a +// Scheduler with a Mesos master. The MesosSchedulerDriver is +// thread-safe. +// +// Note that scheduler failover is supported in Mesos. After a +// scheduler is registered with Mesos it may failover (to a new +// process on the same machine or across multiple machines) by +// creating a new driver with the ID given to it in +// Scheduler.Registered(). +// +// The driver is responsible for invoking the Scheduler callbacks as +// it communicates with the Mesos master. +// +// Note that blocking on the MesosSchedulerDriver (e.g., via +// MesosSchedulerDriver.Join) doesn't affect the scheduler callbacks +// in anyway because they are handled by a different thread. +// +// TODO(yifan): examples. +// See src/examples/test_framework.cpp for an example of using the +// MesosSchedulerDriver. +type MesosSchedulerDriver struct { + Scheduler Scheduler + MasterPid *upid.UPID + FrameworkInfo *mesos.FrameworkInfo + + lock sync.RWMutex + self *upid.UPID + stopCh chan struct{} + stopped bool + status mesos.Status + messenger messenger.Messenger + masterDetector detector.Master + connected bool + connection uuid.UUID + failoverTimeout float64 + failover bool + cache *schedCache + updates map[string]*mesos.StatusUpdate // Key is a UUID string. + tasks map[string]*mesos.TaskInfo // Key is a UUID string. + credential *mesos.Credential + authenticated bool + authenticating *authenticationAttempt + reauthenticate bool + withAuthContext func(context.Context) context.Context +} + +// Create a new mesos scheduler driver with the given +// scheduler, framework info, +// master address, and credential(optional) +func NewMesosSchedulerDriver(config DriverConfig) (initializedDriver *MesosSchedulerDriver, err error) { + if config.Scheduler == nil { + err = fmt.Errorf("Scheduler callbacks required.") + } else if config.Master == "" { + err = fmt.Errorf("Missing master location URL.") + } else if config.Framework == nil { + err = fmt.Errorf("FrameworkInfo must be provided.") + } else if config.Credential != nil && config.WithAuthContext == nil { + err = fmt.Errorf("WithAuthContext must be provided when Credential != nil") + } + if err != nil { + return + } + + framework := cloneFrameworkInfo(config.Framework) + + // set default userid + if framework.GetUser() == "" { + user, err := user.Current() + if err != nil || user == nil { + if err != nil { + log.Warningf("Failed to obtain username: %v\n", err) + } else { + log.Warningln("Failed to obtain username.") + } + framework.User = proto.String("") + } else { + framework.User = proto.String(user.Username) + } + } + + // default hostname + hostname := util.GetHostname(config.HostnameOverride) + if framework.GetHostname() == "" { + framework.Hostname = proto.String(hostname) + } + + driver := &MesosSchedulerDriver{ + Scheduler: config.Scheduler, + FrameworkInfo: framework, + stopCh: make(chan struct{}), + status: mesos.Status_DRIVER_NOT_STARTED, + stopped: true, + cache: newSchedCache(), + credential: config.Credential, + failover: framework.Id != nil && len(framework.Id.GetValue()) > 0, + withAuthContext: config.WithAuthContext, + } + + if framework.FailoverTimeout != nil && *framework.FailoverTimeout > 0 { + driver.failoverTimeout = *framework.FailoverTimeout * float64(time.Second) + log.V(1).Infof("found failover_timeout = %v", time.Duration(driver.failoverTimeout)) + } + + newMessenger := config.NewMessenger + if newMessenger == nil { + newMessenger = func() (messenger.Messenger, error) { + process := process.New("scheduler") + return messenger.ForHostname(process, hostname, config.BindingAddress, config.BindingPort) + } + } + + // initialize new detector. + if driver.masterDetector, err = detector.New(config.Master); err != nil { + return + } else if driver.messenger, err = newMessenger(); err != nil { + return + } else if err = driver.init(); err != nil { + return + } else { + initializedDriver = driver + } + return +} + +func cloneFrameworkInfo(framework *mesos.FrameworkInfo) *mesos.FrameworkInfo { + if framework == nil { + return nil + } + + clonedInfo := *framework + if clonedInfo.Id != nil { + clonedId := *clonedInfo.Id + clonedInfo.Id = &clonedId + if framework.FailoverTimeout != nil { + clonedInfo.FailoverTimeout = proto.Float64(*framework.FailoverTimeout) + } + if framework.Checkpoint != nil { + clonedInfo.Checkpoint = proto.Bool(*framework.Checkpoint) + } + } + return &clonedInfo +} + +// init initializes the driver. +func (driver *MesosSchedulerDriver) init() error { + log.Infof("Initializing mesos scheduler driver\n") + + // Install handlers. + driver.messenger.Install(driver.frameworkRegistered, &mesos.FrameworkRegisteredMessage{}) + driver.messenger.Install(driver.frameworkReregistered, &mesos.FrameworkReregisteredMessage{}) + driver.messenger.Install(driver.resourcesOffered, &mesos.ResourceOffersMessage{}) + driver.messenger.Install(driver.resourceOfferRescinded, &mesos.RescindResourceOfferMessage{}) + driver.messenger.Install(driver.statusUpdated, &mesos.StatusUpdateMessage{}) + driver.messenger.Install(driver.slaveLost, &mesos.LostSlaveMessage{}) + driver.messenger.Install(driver.frameworkMessageRcvd, &mesos.ExecutorToFrameworkMessage{}) + driver.messenger.Install(driver.frameworkErrorRcvd, &mesos.FrameworkErrorMessage{}) + driver.messenger.Install(driver.handleMasterChanged, &mesos.InternalMasterChangeDetected{}) + driver.messenger.Install(driver.handleAuthenticationResult, &mesos.InternalAuthenticationResult{}) + return nil +} + +// lead master detection callback. +func (driver *MesosSchedulerDriver) handleMasterChanged(from *upid.UPID, pbMsg proto.Message) { + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.Info("Ignoring master change because the driver is aborted.") + return + } else if !from.Equal(driver.self) { + log.Errorf("ignoring master changed message received from upid '%v'", from) + return + } + + // Reconnect every time a master is dected. + if driver.Connected() { + log.V(3).Info("Disconnecting scheduler.") + driver.MasterPid = nil + driver.Scheduler.Disconnected(driver) + } + + msg := pbMsg.(*mesos.InternalMasterChangeDetected) + master := msg.Master + + driver.setConnected(false) + driver.authenticated = false + + if master != nil { + log.Infof("New master %s detected\n", master.GetPid()) + + pid, err := upid.Parse(master.GetPid()) + if err != nil { + panic("Unable to parse Master's PID value.") // this should not happen. + } + + driver.MasterPid = pid // save for downstream ops. + driver.tryAuthentication() + } else { + log.Infoln("No master detected.") + } +} + +func (driver *MesosSchedulerDriver) tryAuthentication() { + if driver.authenticated { + // programming error + panic("already authenticated") + } + + masterPid := driver.MasterPid // save for referencing later in goroutine + if masterPid == nil { + log.Info("skipping authentication attempt because we lost the master") + return + } + + if driver.authenticating.inProgress() { + // authentication is in progress, try to cancel it (we may too late already) + driver.authenticating.cancel() + driver.reauthenticate = true + return + } + + if driver.credential != nil { + // authentication can block and we don't want to hold up the messenger loop + authenticating := &authenticationAttempt{done: make(chan struct{})} + go func() { + defer authenticating.cancel() + result := &mesos.InternalAuthenticationResult{ + //TODO(jdef): is this really needed? + Success: proto.Bool(false), + Completed: proto.Bool(false), + Pid: proto.String(masterPid.String()), + } + // don't reference driver.authenticating here since it may have changed + if err := driver.authenticate(masterPid, authenticating); err != nil { + log.Errorf("Scheduler failed to authenticate: %v\n", err) + if err == auth.AuthenticationFailed { + result.Completed = proto.Bool(true) + } + } else { + result.Completed = proto.Bool(true) + result.Success = proto.Bool(true) + } + driver.messenger.Route(context.TODO(), driver.messenger.UPID(), result) + }() + driver.authenticating = authenticating + } else { + log.Infoln("No credentials were provided. " + + "Attempting to register scheduler without authentication.") + driver.authenticated = true + go driver.doReliableRegistration(float64(registrationBackoffFactor)) + } +} + +func (driver *MesosSchedulerDriver) handleAuthenticationResult(from *upid.UPID, pbMsg proto.Message) { + if driver.Status() != mesos.Status_DRIVER_RUNNING { + log.V(1).Info("ignoring authenticate because driver is not running") + return + } + if !from.Equal(driver.self) { + log.Errorf("ignoring authentication result message received from upid '%v'", from) + return + } + if driver.authenticated { + // programming error + panic("already authenticated") + } + if driver.MasterPid == nil { + log.Infoln("ignoring authentication result because master is lost") + driver.authenticating.cancel() // cancel any in-progress background attempt + + // disable future retries until we get a new master + driver.reauthenticate = false + return + } + msg := pbMsg.(*mesos.InternalAuthenticationResult) + if driver.reauthenticate || !msg.GetCompleted() || driver.MasterPid.String() != msg.GetPid() { + log.Infof("failed to authenticate with master %v: master changed", driver.MasterPid) + driver.authenticating.cancel() // cancel any in-progress background authentication + driver.reauthenticate = false + driver.tryAuthentication() + return + } + if !msg.GetSuccess() { + log.Errorf("master %v refused authentication", driver.MasterPid) + return + } + driver.authenticated = true + go driver.doReliableRegistration(float64(registrationBackoffFactor)) +} + +// ------------------------- Accessors ----------------------- // +func (driver *MesosSchedulerDriver) Status() mesos.Status { + driver.lock.RLock() + defer driver.lock.RUnlock() + return driver.status +} +func (driver *MesosSchedulerDriver) setStatus(stat mesos.Status) { + driver.lock.Lock() + driver.status = stat + driver.lock.Unlock() +} + +func (driver *MesosSchedulerDriver) Stopped() bool { + driver.lock.RLock() + defer driver.lock.RUnlock() + return driver.stopped +} + +func (driver *MesosSchedulerDriver) setStopped(val bool) { + driver.lock.Lock() + driver.stopped = val + driver.lock.Unlock() +} + +func (driver *MesosSchedulerDriver) Connected() bool { + driver.lock.RLock() + defer driver.lock.RUnlock() + return driver.connected +} + +func (driver *MesosSchedulerDriver) setConnected(val bool) { + driver.lock.Lock() + driver.connected = val + if val { + driver.failover = false + } + driver.lock.Unlock() +} + +// ---------------------- Handlers for Events from Master --------------- // +func (driver *MesosSchedulerDriver) frameworkRegistered(from *upid.UPID, pbMsg proto.Message) { + log.V(2).Infoln("Handling scheduler driver framework registered event.") + + msg := pbMsg.(*mesos.FrameworkRegisteredMessage) + masterInfo := msg.GetMasterInfo() + masterPid := masterInfo.GetPid() + frameworkId := msg.GetFrameworkId() + + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.Infof("ignoring FrameworkRegisteredMessage from master %s, driver is aborted", masterPid) + return + } + + if driver.connected { + log.Infoln("ignoring FrameworkRegisteredMessage from master, driver is already connected", masterPid) + return + } + + if driver.stopped { + log.Infof("ignoring FrameworkRegisteredMessage from master %s, driver is stopped", masterPid) + return + } + if !driver.MasterPid.Equal(from) { + log.Warningf("ignoring framework registered message because it was sent from '%v' instead of leading master '%v'", from, driver.MasterPid) + return + } + + log.Infof("Framework registered with ID=%s\n", frameworkId.GetValue()) + driver.FrameworkInfo.Id = frameworkId // generated by master. + + driver.setConnected(true) + driver.connection = uuid.NewUUID() + driver.Scheduler.Registered(driver, frameworkId, masterInfo) +} + +func (driver *MesosSchedulerDriver) frameworkReregistered(from *upid.UPID, pbMsg proto.Message) { + log.V(1).Infoln("Handling Scheduler re-registered event.") + msg := pbMsg.(*mesos.FrameworkReregisteredMessage) + + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.Infoln("Ignoring FrameworkReregisteredMessage from master, driver is aborted!") + return + } + if driver.connected { + log.Infoln("Ignoring FrameworkReregisteredMessage from master,driver is already connected!") + return + } + if !driver.MasterPid.Equal(from) { + log.Warningf("ignoring framework re-registered message because it was sent from '%v' instead of leading master '%v'", from, driver.MasterPid) + return + } + + // TODO(vv) detect if message was from leading-master (sched.cpp) + log.Infof("Framework re-registered with ID [%s] ", msg.GetFrameworkId().GetValue()) + driver.setConnected(true) + driver.connection = uuid.NewUUID() + + driver.Scheduler.Reregistered(driver, msg.GetMasterInfo()) + +} + +func (driver *MesosSchedulerDriver) resourcesOffered(from *upid.UPID, pbMsg proto.Message) { + log.V(1).Infoln("Handling resource offers.") + + msg := pbMsg.(*mesos.ResourceOffersMessage) + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.Infoln("Ignoring ResourceOffersMessage, the driver is aborted!") + return + } + + if !driver.connected { + log.Infoln("Ignoring ResourceOffersMessage, the driver is not connected!") + return + } + + pidStrings := msg.GetPids() + if len(pidStrings) != len(msg.Offers) { + log.Errorln("Ignoring offers, Offer count does not match Slave PID count.") + return + } + + for i, offer := range msg.Offers { + if pid, err := upid.Parse(pidStrings[i]); err == nil { + driver.cache.putOffer(offer, pid) + log.V(1).Infof("Cached offer %s from SlavePID %s", offer.Id.GetValue(), pid) + } else { + log.Warningf("Failed to parse offer PID '%v': %v", pid, err) + } + } + + driver.Scheduler.ResourceOffers(driver, msg.Offers) +} + +func (driver *MesosSchedulerDriver) resourceOfferRescinded(from *upid.UPID, pbMsg proto.Message) { + log.V(1).Infoln("Handling resource offer rescinded.") + + msg := pbMsg.(*mesos.RescindResourceOfferMessage) + + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.Infoln("Ignoring RescindResourceOfferMessage, the driver is aborted!") + return + } + + if !driver.connected { + log.Infoln("Ignoring ResourceOffersMessage, the driver is not connected!") + return + } + + // TODO(vv) check for leading master (see sched.cpp) + + log.V(1).Infoln("Rescinding offer ", msg.OfferId.GetValue()) + driver.cache.removeOffer(msg.OfferId) + driver.Scheduler.OfferRescinded(driver, msg.OfferId) +} + +func (driver *MesosSchedulerDriver) send(upid *upid.UPID, msg proto.Message) error { + //TODO(jdef) should implement timeout here + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + c := make(chan error, 1) + go func() { c <- driver.messenger.Send(ctx, upid, msg) }() + + select { + case <-ctx.Done(): + <-c // wait for Send(...) + return ctx.Err() + case err := <-c: + return err + } +} + +func (driver *MesosSchedulerDriver) statusUpdated(from *upid.UPID, pbMsg proto.Message) { + msg := pbMsg.(*mesos.StatusUpdateMessage) + + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.V(1).Infoln("Ignoring StatusUpdate message, the driver is aborted!") + return + } + if !driver.connected { + log.V(1).Infoln("Ignoring StatusUpdate message, the driver is not connected!") + return + } + if !driver.MasterPid.Equal(from) { + log.Warningf("ignoring status message because it was sent from '%v' instead of leading master '%v'", from, driver.MasterPid) + return + } + + log.V(2).Infoln("Received status update from ", from.String(), " status source:", msg.GetPid()) + + driver.Scheduler.StatusUpdate(driver, msg.Update.GetStatus()) + + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.V(1).Infoln("Not sending StatusUpdate ACK, the driver is aborted!") + return + } + + // Send StatusUpdate Acknowledgement + // Only send ACK if udpate was not from this driver + if !from.Equal(driver.self) && msg.GetPid() != from.String() { + ackMsg := &mesos.StatusUpdateAcknowledgementMessage{ + SlaveId: msg.Update.SlaveId, + FrameworkId: driver.FrameworkInfo.Id, + TaskId: msg.Update.Status.TaskId, + Uuid: msg.Update.Uuid, + } + + log.V(2).Infoln("Sending status update ACK to ", from.String()) + if err := driver.send(driver.MasterPid, ackMsg); err != nil { + log.Errorf("Failed to send StatusUpdate ACK message: %v\n", err) + return + } + } else { + log.V(1).Infoln("Not sending ACK, update is not from slave:", from.String()) + } +} + +func (driver *MesosSchedulerDriver) slaveLost(from *upid.UPID, pbMsg proto.Message) { + log.V(1).Infoln("Handling LostSlave event.") + + msg := pbMsg.(*mesos.LostSlaveMessage) + + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.V(1).Infoln("Ignoring LostSlave message, the driver is aborted!") + return + } + + if !driver.connected { + log.V(1).Infoln("Ignoring LostSlave message, the driver is not connected!") + return + } + + // TODO(VV) - detect leading master (see sched.cpp) + + log.V(2).Infoln("Lost slave ", msg.SlaveId.GetValue()) + driver.cache.removeSlavePid(msg.SlaveId) + + driver.Scheduler.SlaveLost(driver, msg.SlaveId) +} + +func (driver *MesosSchedulerDriver) frameworkMessageRcvd(from *upid.UPID, pbMsg proto.Message) { + log.V(1).Infoln("Handling framework message event.") + + msg := pbMsg.(*mesos.ExecutorToFrameworkMessage) + + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.V(1).Infoln("Ignoring framwork message, the driver is aborted!") + return + } + + log.V(1).Infoln("Received Framwork Message ", msg.String()) + + driver.Scheduler.FrameworkMessage(driver, msg.ExecutorId, msg.SlaveId, string(msg.Data)) +} + +func (driver *MesosSchedulerDriver) frameworkErrorRcvd(from *upid.UPID, pbMsg proto.Message) { + log.V(1).Infoln("Handling framework error event.") + msg := pbMsg.(*mesos.FrameworkErrorMessage) + driver.error(msg.GetMessage(), true) +} + +// ---------------------- Interface Methods ---------------------- // + +// Starts the scheduler driver. +// Returns immediately if an error occurs within start sequence. +func (driver *MesosSchedulerDriver) Start() (mesos.Status, error) { + log.Infoln("Starting the scheduler driver...") + + if stat := driver.Status(); stat != mesos.Status_DRIVER_NOT_STARTED { + return stat, fmt.Errorf("Unable to Start, expecting driver status %s, but is %s:", mesos.Status_DRIVER_NOT_STARTED, stat) + } + + driver.setStopped(true) + driver.setStatus(mesos.Status_DRIVER_NOT_STARTED) + + // Start the messenger. + if err := driver.messenger.Start(); err != nil { + log.Errorf("Scheduler failed to start the messenger: %v\n", err) + return driver.Status(), err + } + + driver.self = driver.messenger.UPID() + driver.setStatus(mesos.Status_DRIVER_RUNNING) + driver.setStopped(false) + + log.Infof("Mesos scheduler driver started with PID=%v", driver.self) + + listener := detector.OnMasterChanged(func(m *mesos.MasterInfo) { + driver.messenger.Route(context.TODO(), driver.self, &mesos.InternalMasterChangeDetected{ + Master: m, + }) + }) + + // register with Detect() AFTER we have a self pid from the messenger, otherwise things get ugly + // because our internal messaging depends on it. detector callbacks are routed over the messenger + // bus, maintaining serial (concurrency-safe) callback execution. + log.V(1).Infof("starting master detector %T: %+v", driver.masterDetector, driver.masterDetector) + driver.masterDetector.Detect(listener) + + log.V(2).Infoln("master detector started") + return driver.Status(), nil +} + +// authenticate against the spec'd master pid using the configured authenticationProvider. +// the authentication process is canceled upon either cancelation of authenticating, or +// else because it timed out (authTimeout). +// +// TODO(jdef) perhaps at some point in the future this will get pushed down into +// the messenger layer (e.g. to use HTTP-based authentication). We'd probably still +// specify the callback.Handler here, along with the user-selected authentication +// provider. Perhaps in the form of some messenger.AuthenticationConfig. +// +func (driver *MesosSchedulerDriver) authenticate(pid *upid.UPID, authenticating *authenticationAttempt) error { + log.Infof("authenticating with master %v", pid) + ctx, cancel := context.WithTimeout(context.Background(), authTimeout) + handler := &CredentialHandler{ + pid: pid, + client: driver.self, + credential: driver.credential, + } + ctx = driver.withAuthContext(ctx) + ctx = auth.WithParentUPID(ctx, *driver.self) + + ch := make(chan error, 1) + go func() { ch <- auth.Login(ctx, handler) }() + select { + case <-ctx.Done(): + <-ch + return ctx.Err() + case <-authenticating.done: + cancel() + <-ch + return authenticationCanceledError + case e := <-ch: + cancel() + return e + } +} + +func (driver *MesosSchedulerDriver) doReliableRegistration(maxBackoff float64) { + for { + if !driver.registerOnce() { + return + } + maxBackoff = math.Min(maxBackoff, registrationRetryIntervalMax) + + // If failover timeout is present, bound the maximum backoff + // by 1/10th of the failover timeout. + if driver.failoverTimeout > 0 { + maxBackoff = math.Min(maxBackoff, driver.failoverTimeout/10.0) + } + + // Determine the delay for next attempt by picking a random + // duration between 0 and 'maxBackoff'. + delay := time.Duration(maxBackoff * rand.Float64()) + + log.V(1).Infof("will retry registration in %v if necessary", delay) + + select { + case <-driver.stopCh: + return + case <-time.After(delay): + maxBackoff *= 2 + } + } +} + +// return true if we should attempt another registration later +func (driver *MesosSchedulerDriver) registerOnce() bool { + + var ( + failover bool + pid *upid.UPID + ) + if func() bool { + driver.lock.RLock() + defer driver.lock.RUnlock() + + if driver.stopped || driver.connected || driver.MasterPid == nil || (driver.credential != nil && !driver.authenticated) { + log.V(1).Infof("skipping registration request: stopped=%v, connected=%v, authenticated=%v", + driver.stopped, driver.connected, driver.authenticated) + return false + } + failover = driver.failover + pid = driver.MasterPid + return true + }() { + // register framework + var message proto.Message + if driver.FrameworkInfo.Id != nil && len(driver.FrameworkInfo.Id.GetValue()) > 0 { + // not the first time, or failing over + log.V(1).Infof("Reregistering with master: %v", pid) + message = &mesos.ReregisterFrameworkMessage{ + Framework: driver.FrameworkInfo, + Failover: proto.Bool(failover), + } + } else { + log.V(1).Infof("Registering with master: %v", pid) + message = &mesos.RegisterFrameworkMessage{ + Framework: driver.FrameworkInfo, + } + } + if err := driver.send(pid, message); err != nil { + log.Errorf("failed to send RegisterFramework message: %v", err) + if _, err = driver.Stop(failover); err != nil { + log.Errorf("failed to stop scheduler driver: %v", err) + } + } + return true + } + return false +} + +//Join blocks until the driver is stopped. +//Should follow a call to Start() +func (driver *MesosSchedulerDriver) Join() (mesos.Status, error) { + if stat := driver.Status(); stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to Join, expecting driver status %s, but is %s", mesos.Status_DRIVER_RUNNING, stat) + } + <-driver.stopCh // wait for stop signal + return driver.Status(), nil +} + +//Run starts and joins driver process and waits to be stopped or aborted. +func (driver *MesosSchedulerDriver) Run() (mesos.Status, error) { + stat, err := driver.Start() + + if err != nil { + return driver.Stop(false) + } + + if stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to Run, expecting driver status %s, but is %s:", mesos.Status_DRIVER_RUNNING, driver.status) + } + + log.Infoln("Scheduler driver running. Waiting to be stopped.") + return driver.Join() +} + +//Stop stops the driver. +func (driver *MesosSchedulerDriver) Stop(failover bool) (mesos.Status, error) { + log.Infoln("Stopping the scheduler driver") + if stat := driver.Status(); stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to Stop, expected driver status %s, but is %s", mesos.Status_DRIVER_RUNNING, stat) + } + + if driver.connected && failover { + // unregister the framework + message := &mesos.UnregisterFrameworkMessage{ + FrameworkId: driver.FrameworkInfo.Id, + } + if err := driver.send(driver.MasterPid, message); err != nil { + log.Errorf("Failed to send UnregisterFramework message while stopping driver: %v\n", err) + return driver.stop(mesos.Status_DRIVER_ABORTED) + } + } + + // stop messenger + return driver.stop(mesos.Status_DRIVER_STOPPED) +} + +func (driver *MesosSchedulerDriver) stop(stopStatus mesos.Status) (mesos.Status, error) { + // stop messenger + err := driver.messenger.Stop() + defer func() { + select { + case <-driver.stopCh: + // already closed + default: + close(driver.stopCh) + } + }() + + driver.setStatus(stopStatus) + driver.setStopped(true) + driver.connected = false + + if err != nil { + return stopStatus, err + } + + return stopStatus, nil +} + +func (driver *MesosSchedulerDriver) Abort() (stat mesos.Status, err error) { + defer driver.masterDetector.Cancel() + log.Infof("Aborting framework [%+v]", driver.FrameworkInfo.Id) + if driver.connected { + _, err = driver.Stop(true) + } else { + driver.messenger.Stop() + } + stat = mesos.Status_DRIVER_ABORTED + driver.setStatus(stat) + return +} + +func (driver *MesosSchedulerDriver) LaunchTasks(offerIds []*mesos.OfferID, tasks []*mesos.TaskInfo, filters *mesos.Filters) (mesos.Status, error) { + if stat := driver.Status(); stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to LaunchTasks, expected driver status %s, but got %s", mesos.Status_DRIVER_RUNNING, stat) + } + + // Launch tasks + if !driver.connected { + log.Infoln("Ignoring LaunchTasks message, disconnected from master.") + // Send statusUpdate with status=TASK_LOST for each task. + // See sched.cpp L#823 + for _, task := range tasks { + driver.pushLostTask(task, "Master is disconnected") + } + return driver.Status(), fmt.Errorf("Not connected to master. Tasks marked as lost.") + } + + okTasks := make([]*mesos.TaskInfo, 0, len(tasks)) + + // Set TaskInfo.executor.framework_id, if it's missing. + for _, task := range tasks { + if task.Executor != nil && task.Executor.FrameworkId == nil { + task.Executor.FrameworkId = driver.FrameworkInfo.Id + } + okTasks = append(okTasks, task) + } + + for _, offerId := range offerIds { + for _, task := range okTasks { + // Keep only the slave PIDs where we run tasks so we can send + // framework messages directly. + if driver.cache.containsOffer(offerId) { + if driver.cache.getOffer(offerId).offer.SlaveId.Equal(task.SlaveId) { + // cache the tasked slave, for future communication + pid := driver.cache.getOffer(offerId).slavePid + driver.cache.putSlavePid(task.SlaveId, pid) + } else { + log.Warningf("Attempting to launch task %s with the wrong slaveId offer %s\n", task.TaskId.GetValue(), task.SlaveId.GetValue()) + } + } else { + log.Warningf("Attempting to launch task %s with unknown offer %s\n", task.TaskId.GetValue(), offerId.GetValue()) + } + } + + driver.cache.removeOffer(offerId) // if offer + } + + // launch tasks + message := &mesos.LaunchTasksMessage{ + FrameworkId: driver.FrameworkInfo.Id, + OfferIds: offerIds, + Tasks: okTasks, + Filters: filters, + } + + if err := driver.send(driver.MasterPid, message); err != nil { + for _, task := range tasks { + driver.pushLostTask(task, "Unable to launch tasks: "+err.Error()) + } + log.Errorf("Failed to send LaunchTask message: %v\n", err) + return driver.Status(), err + } + + return driver.Status(), nil +} + +func (driver *MesosSchedulerDriver) pushLostTask(taskInfo *mesos.TaskInfo, why string) { + msg := &mesos.StatusUpdateMessage{ + Update: &mesos.StatusUpdate{ + FrameworkId: driver.FrameworkInfo.Id, + Status: &mesos.TaskStatus{ + TaskId: taskInfo.TaskId, + State: mesos.TaskState_TASK_LOST.Enum(), + Message: proto.String(why), + }, + SlaveId: taskInfo.SlaveId, + ExecutorId: taskInfo.Executor.ExecutorId, + Timestamp: proto.Float64(float64(time.Now().Unix())), + Uuid: []byte(uuid.NewUUID()), + }, + } + + // put it on internal chanel + // will cause handler to push to attached Scheduler + driver.statusUpdated(driver.self, msg) +} + +func (driver *MesosSchedulerDriver) KillTask(taskId *mesos.TaskID) (mesos.Status, error) { + if stat := driver.Status(); stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to KillTask, expecting driver status %s, but got %s", mesos.Status_DRIVER_RUNNING, stat) + } + + if !driver.connected { + log.Infoln("Ignoring kill task message, disconnected from master.") + return driver.Status(), fmt.Errorf("Not connected to master") + } + + message := &mesos.KillTaskMessage{ + FrameworkId: driver.FrameworkInfo.Id, + TaskId: taskId, + } + + if err := driver.send(driver.MasterPid, message); err != nil { + log.Errorf("Failed to send KillTask message: %v\n", err) + return driver.Status(), err + } + + return driver.Status(), nil +} + +func (driver *MesosSchedulerDriver) RequestResources(requests []*mesos.Request) (mesos.Status, error) { + if stat := driver.Status(); stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to RequestResources, expecting driver status %s, but got %s", mesos.Status_DRIVER_RUNNING, stat) + } + + if !driver.connected { + log.Infoln("Ignoring request resource message, disconnected from master.") + return driver.status, fmt.Errorf("Not connected to master") + } + + message := &mesos.ResourceRequestMessage{ + FrameworkId: driver.FrameworkInfo.Id, + Requests: requests, + } + + if err := driver.send(driver.MasterPid, message); err != nil { + log.Errorf("Failed to send ResourceRequest message: %v\n", err) + return driver.status, err + } + + return driver.status, nil +} + +func (driver *MesosSchedulerDriver) DeclineOffer(offerId *mesos.OfferID, filters *mesos.Filters) (mesos.Status, error) { + // launching an empty task list will decline the offer + return driver.LaunchTasks([]*mesos.OfferID{offerId}, []*mesos.TaskInfo{}, filters) +} + +func (driver *MesosSchedulerDriver) ReviveOffers() (mesos.Status, error) { + if stat := driver.Status(); stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to ReviveOffers, expecting driver status %s, but got %s", mesos.Status_DRIVER_RUNNING, stat) + } + if !driver.connected { + log.Infoln("Ignoring revive offers message, disconnected from master.") + return driver.Status(), fmt.Errorf("Not connected to master.") + } + + message := &mesos.ReviveOffersMessage{ + FrameworkId: driver.FrameworkInfo.Id, + } + if err := driver.send(driver.MasterPid, message); err != nil { + log.Errorf("Failed to send ReviveOffers message: %v\n", err) + return driver.Status(), err + } + + return driver.Status(), nil +} + +func (driver *MesosSchedulerDriver) SendFrameworkMessage(executorId *mesos.ExecutorID, slaveId *mesos.SlaveID, data string) (mesos.Status, error) { + if stat := driver.Status(); stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to SendFrameworkMessage, expecting driver status %s, but got %s", mesos.Status_DRIVER_RUNNING, stat) + } + if !driver.connected { + log.Infoln("Ignoring send framework message, disconnected from master.") + return driver.Status(), fmt.Errorf("Not connected to master") + } + + message := &mesos.FrameworkToExecutorMessage{ + SlaveId: slaveId, + FrameworkId: driver.FrameworkInfo.Id, + ExecutorId: executorId, + Data: []byte(data), + } + // Use list of cached slaveIds from previous offers. + // Send frameworkMessage directly to cached slave, otherwise to master. + if driver.cache.containsSlavePid(slaveId) { + slavePid := driver.cache.getSlavePid(slaveId) + if slavePid.Equal(driver.self) { + return driver.Status(), nil + } + if err := driver.send(slavePid, message); err != nil { + log.Errorf("Failed to send framework to executor message: %v\n", err) + return driver.Status(), err + } + } else { + // slavePid not cached, send to master. + if err := driver.send(driver.MasterPid, message); err != nil { + log.Errorf("Failed to send framework to executor message: %v\n", err) + return driver.Status(), err + } + } + + return driver.Status(), nil +} + +func (driver *MesosSchedulerDriver) ReconcileTasks(statuses []*mesos.TaskStatus) (mesos.Status, error) { + if stat := driver.Status(); stat != mesos.Status_DRIVER_RUNNING { + return stat, fmt.Errorf("Unable to ReconcileTasks, expecting driver status %s, but got %s", mesos.Status_DRIVER_RUNNING, stat) + } + if !driver.connected { + log.Infoln("Ignoring send Reconcile Tasks message, disconnected from master.") + return driver.Status(), fmt.Errorf("Not connected to master.") + } + + message := &mesos.ReconcileTasksMessage{ + FrameworkId: driver.FrameworkInfo.Id, + Statuses: statuses, + } + if err := driver.send(driver.MasterPid, message); err != nil { + log.Errorf("Failed to send reconcile tasks message: %v\n", err) + return driver.Status(), err + } + + return driver.Status(), nil +} + +func (driver *MesosSchedulerDriver) error(err string, abortDriver bool) { + if abortDriver { + if driver.Status() == mesos.Status_DRIVER_ABORTED { + log.V(3).Infoln("Ignoring error message, the driver is aborted!") + return + } + + log.Infoln("Aborting driver, got error '", err, "'") + + driver.Abort() + } + + log.V(3).Infof("Sending error '%v'", err) + driver.Scheduler.Error(driver, err) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler_intgr_test.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler_intgr_test.go new file mode 100644 index 00000000000..fc4137c2b9e --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler_intgr_test.go @@ -0,0 +1,442 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package scheduler + +import ( + "io/ioutil" + "net/http" + "reflect" + "sync" + "testing" + "time" + + "github.com/gogo/protobuf/proto" + log "github.com/golang/glog" + mesos "github.com/mesos/mesos-go/mesosproto" + util "github.com/mesos/mesos-go/mesosutil" + "github.com/mesos/mesos-go/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +// testScuduler is used for testing Schduler callbacks. +type testScheduler struct { + ch chan bool + wg *sync.WaitGroup + s *SchedulerIntegrationTestSuite +} + +// convenience +func (sched *testScheduler) T() *testing.T { + return sched.s.T() +} + +func (sched *testScheduler) Registered(dr SchedulerDriver, fw *mesos.FrameworkID, mi *mesos.MasterInfo) { + log.Infoln("Sched.Registered() called.") + sched.s.Equal(fw.GetValue(), sched.s.registeredFrameworkId.GetValue(), "driver did not register the expected framework ID") + sched.s.Equal(mi.GetIp(), uint32(123456)) + sched.ch <- true +} + +func (sched *testScheduler) Reregistered(dr SchedulerDriver, mi *mesos.MasterInfo) { + log.Infoln("Sched.Reregistered() called") + sched.s.Equal(mi.GetIp(), uint32(123456)) + sched.ch <- true +} + +func (sched *testScheduler) Disconnected(dr SchedulerDriver) { + log.Infoln("Shed.Disconnected() called") +} + +func (sched *testScheduler) ResourceOffers(dr SchedulerDriver, offers []*mesos.Offer) { + log.Infoln("Sched.ResourceOffers called.") + sched.s.NotNil(offers) + sched.s.Equal(len(offers), 1) + sched.ch <- true +} + +func (sched *testScheduler) OfferRescinded(dr SchedulerDriver, oid *mesos.OfferID) { + log.Infoln("Sched.OfferRescinded() called.") + sched.s.NotNil(oid) + sched.s.Equal("test-offer-001", oid.GetValue()) + sched.ch <- true +} + +func (sched *testScheduler) StatusUpdate(dr SchedulerDriver, stat *mesos.TaskStatus) { + log.Infoln("Sched.StatusUpdate() called.") + sched.s.NotNil(stat) + sched.s.Equal("test-task-001", stat.GetTaskId().GetValue()) + sched.wg.Done() + log.Infof("Status update done with waitGroup %v \n", sched.wg) +} + +func (sched *testScheduler) SlaveLost(dr SchedulerDriver, slaveId *mesos.SlaveID) { + log.Infoln("Sched.SlaveLost() called.") + sched.s.NotNil(slaveId) + sched.s.Equal(slaveId.GetValue(), "test-slave-001") + sched.ch <- true +} + +func (sched *testScheduler) FrameworkMessage(dr SchedulerDriver, execId *mesos.ExecutorID, slaveId *mesos.SlaveID, data string) { + log.Infoln("Sched.FrameworkMessage() called.") + sched.s.NotNil(slaveId) + sched.s.Equal(slaveId.GetValue(), "test-slave-001") + sched.s.NotNil(execId) + sched.s.NotNil(data) + sched.s.Equal("test-data-999", string(data)) + sched.ch <- true +} + +func (sched *testScheduler) ExecutorLost(SchedulerDriver, *mesos.ExecutorID, *mesos.SlaveID, int) { + log.Infoln("Sched.ExecutorLost called") +} + +func (sched *testScheduler) Error(dr SchedulerDriver, err string) { + log.Infoln("Sched.Error() called.") + sched.s.Equal("test-error-999", err) + sched.ch <- true +} + +func (sched *testScheduler) waitForCallback(timeout time.Duration) bool { + if timeout == 0 { + timeout = 2 * time.Second + } + select { + case <-sched.ch: + //callback complete + return true + case <-time.After(timeout): + sched.T().Fatalf("timed out after waiting %v for callback", timeout) + } + return false +} + +func newTestScheduler(s *SchedulerIntegrationTestSuite) *testScheduler { + return &testScheduler{ch: make(chan bool), s: s} +} + +type mockServerConfigurator func(frameworkId *mesos.FrameworkID, suite *SchedulerIntegrationTestSuite) + +type SchedulerIntegrationTestSuiteCore struct { + SchedulerTestSuiteCore + server *testutil.MockMesosHttpServer + driver *MesosSchedulerDriver + sched *testScheduler + config mockServerConfigurator + validator http.HandlerFunc + registeredFrameworkId *mesos.FrameworkID +} + +type SchedulerIntegrationTestSuite struct { + suite.Suite + SchedulerIntegrationTestSuiteCore +} + +// sets up a mock Mesos HTTP master listener, scheduler, and scheduler driver for testing. +// attempts to wait for a registered or re-registered callback on the suite.sched. +func (suite *SchedulerIntegrationTestSuite) configure(frameworkId *mesos.FrameworkID) bool { + t := suite.T() + // start mock master server to handle connection + suite.server = testutil.NewMockMasterHttpServer(t, func(rsp http.ResponseWriter, req *http.Request) { + log.Infoln("MockMaster - rcvd ", req.RequestURI) + if suite.validator != nil { + suite.validator(rsp, req) + } else { + ioutil.ReadAll(req.Body) + defer req.Body.Close() + rsp.WriteHeader(http.StatusAccepted) + } + }) + + t.Logf("test HTTP server listening on %v", suite.server.Addr) + suite.sched = newTestScheduler(suite) + suite.sched.ch = make(chan bool, 10) // big enough that it doesn't block callback processing + + suite.driver = newTestSchedulerDriver(suite.T(), suite.sched, suite.framework, suite.server.Addr, nil) + + suite.config(frameworkId, suite) + + stat, err := suite.driver.Start() + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) + + ok := waitForConnected(t, suite.driver, 2*time.Second) + if ok { + ok = suite.sched.waitForCallback(0) // registered or re-registered callback + } + return ok +} + +func (suite *SchedulerIntegrationTestSuite) configureServerWithRegisteredFramework() bool { + // suite.framework is used to initialize the FrameworkInfo of + // the driver, so if we clear the Id then we'll expect a registration message + id := suite.framework.Id + suite.framework.Id = nil + suite.registeredFrameworkId = id + return suite.configure(id) +} + +var defaultMockServerConfigurator = mockServerConfigurator(func(frameworkId *mesos.FrameworkID, suite *SchedulerIntegrationTestSuite) { + t := suite.T() + masterInfo := util.NewMasterInfo("master", 123456, 1234) + suite.server.On("/master/mesos.internal.RegisterFrameworkMessage").Do(func(rsp http.ResponseWriter, req *http.Request) { + if suite.validator != nil { + t.Logf("validating registration request") + suite.validator(rsp, req) + } else { + ioutil.ReadAll(req.Body) + defer req.Body.Close() + rsp.WriteHeader(http.StatusAccepted) + } + // this is what the mocked scheduler is expecting to receive + suite.driver.frameworkRegistered(suite.driver.MasterPid, &mesos.FrameworkRegisteredMessage{ + FrameworkId: frameworkId, + MasterInfo: masterInfo, + }) + }) + suite.server.On("/master/mesos.internal.ReregisterFrameworkMessage").Do(func(rsp http.ResponseWriter, req *http.Request) { + if suite.validator != nil { + suite.validator(rsp, req) + } else { + ioutil.ReadAll(req.Body) + defer req.Body.Close() + rsp.WriteHeader(http.StatusAccepted) + } + // this is what the mocked scheduler is expecting to receive + suite.driver.frameworkReregistered(suite.driver.MasterPid, &mesos.FrameworkReregisteredMessage{ + FrameworkId: frameworkId, + MasterInfo: masterInfo, + }) + }) +}) + +func (s *SchedulerIntegrationTestSuite) newMockClient() *testutil.MockMesosClient { + return testutil.NewMockMesosClient(s.T(), s.server.PID) +} + +func (s *SchedulerIntegrationTestSuite) SetupTest() { + s.SchedulerTestSuiteCore.SetupTest() + s.config = defaultMockServerConfigurator +} + +func (s *SchedulerIntegrationTestSuite) TearDownTest() { + if s.server != nil { + s.server.Close() + } + if s.driver != nil && s.driver.Status() == mesos.Status_DRIVER_RUNNING { + s.driver.Abort() + } +} + +// ---------------------------------- Tests ---------------------------------- // + +func TestSchedulerIntegrationSuite(t *testing.T) { + suite.Run(t, new(SchedulerIntegrationTestSuite)) +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverRegisterFrameworkMessage() { + t := suite.T() + + id := suite.framework.Id + suite.framework.Id = nil + validated := make(chan struct{}) + var closeOnce sync.Once + suite.validator = http.HandlerFunc(func(rsp http.ResponseWriter, req *http.Request) { + t.Logf("RCVD request %s", req.URL) + + data, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatalf("Missing message data from request") + } + defer req.Body.Close() + + if "/master/mesos.internal.RegisterFrameworkMessage" != req.RequestURI { + rsp.WriteHeader(http.StatusAccepted) + return + } + + defer closeOnce.Do(func() { close(validated) }) + + message := new(mesos.RegisterFrameworkMessage) + err = proto.Unmarshal(data, message) + if err != nil { + t.Fatal("Problem unmarshaling expected RegisterFrameworkMessage") + } + + suite.NotNil(message) + info := message.GetFramework() + suite.NotNil(info) + suite.Equal(suite.framework.GetName(), info.GetName()) + suite.True(reflect.DeepEqual(suite.framework.GetId(), info.GetId())) + rsp.WriteHeader(http.StatusOK) + }) + ok := suite.configure(id) + suite.True(ok, "failed to establish running test server and driver") + select { + case <-time.After(1 * time.Second): + t.Fatalf("failed to complete validation of framework registration message") + case <-validated: + // noop + } +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverFrameworkRegisteredEvent() { + ok := suite.configureServerWithRegisteredFramework() + suite.True(ok, "failed to establish running test server and driver") +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverFrameworkReregisteredEvent() { + ok := suite.configure(suite.framework.Id) + suite.True(ok, "failed to establish running test server and driver") +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverResourceOffersEvent() { + ok := suite.configureServerWithRegisteredFramework() + suite.True(ok, "failed to establish running test server and driver") + + // Send a event to this SchedulerDriver (via http) to test handlers. + offer := util.NewOffer( + util.NewOfferID("test-offer-001"), + suite.registeredFrameworkId, + util.NewSlaveID("test-slave-001"), + "test-localhost", + ) + pbMsg := &mesos.ResourceOffersMessage{ + Offers: []*mesos.Offer{offer}, + Pids: []string{"test-offer-001@test-slave-001:5051"}, + } + + c := suite.newMockClient() + c.SendMessage(suite.driver.self, pbMsg) + suite.sched.waitForCallback(0) +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverRescindOfferEvent() { + ok := suite.configureServerWithRegisteredFramework() + suite.True(ok, "failed to establish running test server and driver") + + // Send a event to this SchedulerDriver (via http) to test handlers. + pbMsg := &mesos.RescindResourceOfferMessage{ + OfferId: util.NewOfferID("test-offer-001"), + } + + c := suite.newMockClient() + c.SendMessage(suite.driver.self, pbMsg) + suite.sched.waitForCallback(0) +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverStatusUpdatedEvent() { + t := suite.T() + var wg sync.WaitGroup + wg.Add(2) + suite.config = mockServerConfigurator(func(frameworkId *mesos.FrameworkID, suite *SchedulerIntegrationTestSuite) { + defaultMockServerConfigurator(frameworkId, suite) + suite.server.On("/master/mesos.internal.StatusUpdateAcknowledgementMessage").Do(func(rsp http.ResponseWriter, req *http.Request) { + log.Infoln("Master cvd ACK") + data, _ := ioutil.ReadAll(req.Body) + defer req.Body.Close() + assert.NotNil(t, data) + wg.Done() + log.Infof("MockMaster - Done with wait group %v \n", wg) + }) + suite.sched.wg = &wg + }) + + ok := suite.configureServerWithRegisteredFramework() + suite.True(ok, "failed to establish running test server and driver") + + // Send a event to this SchedulerDriver (via http) to test handlers. + pbMsg := &mesos.StatusUpdateMessage{ + Update: util.NewStatusUpdate( + suite.registeredFrameworkId, + util.NewTaskStatus(util.NewTaskID("test-task-001"), mesos.TaskState_TASK_STARTING), + float64(time.Now().Unix()), + []byte("test-abcd-ef-3455-454-001"), + ), + Pid: proto.String(suite.driver.self.String()), + } + pbMsg.Update.SlaveId = &mesos.SlaveID{Value: proto.String("test-slave-001")} + + c := suite.newMockClient() + c.SendMessage(suite.driver.self, pbMsg) + wg.Wait() +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverLostSlaveEvent() { + ok := suite.configureServerWithRegisteredFramework() + suite.True(ok, "failed to establish running test server and driver") + + // Send a event to this SchedulerDriver (via http) to test handlers. offer := util.NewOffer( + pbMsg := &mesos.LostSlaveMessage{ + SlaveId: util.NewSlaveID("test-slave-001"), + } + + c := suite.newMockClient() + c.SendMessage(suite.driver.self, pbMsg) + suite.sched.waitForCallback(0) +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverFrameworkMessageEvent() { + ok := suite.configureServerWithRegisteredFramework() + suite.True(ok, "failed to establish running test server and driver") + + // Send a event to this SchedulerDriver (via http) to test handlers. offer := util.NewOffer( + pbMsg := &mesos.ExecutorToFrameworkMessage{ + SlaveId: util.NewSlaveID("test-slave-001"), + FrameworkId: suite.registeredFrameworkId, + ExecutorId: util.NewExecutorID("test-executor-001"), + Data: []byte("test-data-999"), + } + + c := suite.newMockClient() + c.SendMessage(suite.driver.self, pbMsg) + suite.sched.waitForCallback(0) +} + +func waitForConnected(t *testing.T, driver *MesosSchedulerDriver, timeout time.Duration) bool { + connected := make(chan struct{}) + go func() { + defer close(connected) + for !driver.Connected() { + time.Sleep(200 * time.Millisecond) + } + }() + select { + case <-time.After(timeout): + t.Fatalf("driver failed to establish connection within %v", timeout) + return false + case <-connected: + return true + } +} + +func (suite *SchedulerIntegrationTestSuite) TestSchedulerDriverFrameworkErrorEvent() { + ok := suite.configureServerWithRegisteredFramework() + suite.True(ok, "failed to establish running test server and driver") + + // Send an error event to this SchedulerDriver (via http) to test handlers. + pbMsg := &mesos.FrameworkErrorMessage{ + Message: proto.String("test-error-999"), + } + + c := suite.newMockClient() + c.SendMessage(suite.driver.self, pbMsg) + suite.sched.waitForCallback(0) + suite.Equal(mesos.Status_DRIVER_ABORTED, suite.driver.Status()) +} diff --git a/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler_unit_test.go b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler_unit_test.go new file mode 100644 index 00000000000..423459f6a8a --- /dev/null +++ b/Godeps/_workspace/src/github.com/mesos/mesos-go/scheduler/scheduler_unit_test.go @@ -0,0 +1,653 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package scheduler + +import ( + "fmt" + "os" + "os/user" + "sync" + "testing" + "time" + + "github.com/gogo/protobuf/proto" + log "github.com/golang/glog" + "github.com/mesos/mesos-go/detector" + "github.com/mesos/mesos-go/detector/zoo" + mesos "github.com/mesos/mesos-go/mesosproto" + util "github.com/mesos/mesos-go/mesosutil" + "github.com/mesos/mesos-go/messenger" + "github.com/mesos/mesos-go/upid" + "github.com/samuel/go-zookeeper/zk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" +) + +var ( + registerMockDetectorOnce sync.Once +) + +func ensureMockDetectorRegistered() { + registerMockDetectorOnce.Do(func() { + var s *SchedulerTestSuite + err := s.registerMockDetector("testing://") + if err != nil { + log.Error(err) + } + }) +} + +type MockDetector struct { + mock.Mock + address string +} + +func (m *MockDetector) Detect(listener detector.MasterChanged) error { + if listener != nil { + if pid, err := upid.Parse("master(2)@" + m.address); err != nil { + return err + } else { + go listener.OnMasterChanged(detector.CreateMasterInfo(pid)) + } + } + return nil +} + +func (m *MockDetector) Done() <-chan struct{} { + return nil +} + +func (m *MockDetector) Cancel() {} + +type SchedulerTestSuiteCore struct { + master string + masterUpid string + masterId string + frameworkID string + framework *mesos.FrameworkInfo +} + +type SchedulerTestSuite struct { + suite.Suite + SchedulerTestSuiteCore +} + +func (s *SchedulerTestSuite) registerMockDetector(prefix string) error { + address := "" + if s != nil { + address = s.master + } else { + address = "127.0.0.1:8080" + } + return detector.Register(prefix, detector.PluginFactory(func(spec string) (detector.Master, error) { + return &MockDetector{address: address}, nil + })) +} + +func (s *SchedulerTestSuiteCore) SetupTest() { + s.master = "127.0.0.1:8080" + s.masterUpid = "master(2)@" + s.master + s.masterId = "some-master-id-uuid" + s.frameworkID = "some-framework-id-uuid" + s.framework = util.NewFrameworkInfo( + "test-user", + "test-name", + util.NewFrameworkID(s.frameworkID), + ) +} + +func TestSchedulerSuite(t *testing.T) { + t.Logf("running scheduler test suite..") + suite.Run(t, new(SchedulerTestSuite)) +} + +func newTestSchedulerDriver(t *testing.T, sched Scheduler, framework *mesos.FrameworkInfo, master string, cred *mesos.Credential) *MesosSchedulerDriver { + dconfig := DriverConfig{ + Scheduler: sched, + Framework: framework, + Master: master, + Credential: cred, + } + driver, err := NewMesosSchedulerDriver(dconfig) + if err != nil { + t.Fatal(err) + } + return driver +} + +func TestSchedulerDriverNew(t *testing.T) { + masterAddr := "localhost:5050" + driver := newTestSchedulerDriver(t, NewMockScheduler(), &mesos.FrameworkInfo{}, masterAddr, nil) + user, _ := user.Current() + assert.Equal(t, user.Username, driver.FrameworkInfo.GetUser()) + host, _ := os.Hostname() + assert.Equal(t, host, driver.FrameworkInfo.GetHostname()) +} + +func TestSchedulerDriverNew_WithPid(t *testing.T) { + masterAddr := "master@127.0.0.1:5050" + mUpid, err := upid.Parse(masterAddr) + assert.NoError(t, err) + driver := newTestSchedulerDriver(t, NewMockScheduler(), &mesos.FrameworkInfo{}, masterAddr, nil) + driver.handleMasterChanged(driver.self, &mesos.InternalMasterChangeDetected{Master: &mesos.MasterInfo{Pid: proto.String(mUpid.String())}}) + assert.True(t, driver.MasterPid.Equal(mUpid), fmt.Sprintf("expected upid %+v instead of %+v", mUpid, driver.MasterPid)) + assert.NoError(t, err) +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverNew_WithZkUrl() { + masterAddr := "zk://127.0.0.1:5050/mesos" + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, masterAddr, nil) + md, err := zoo.NewMockMasterDetector(masterAddr) + suite.NoError(err) + suite.NotNil(md) + driver.masterDetector = md // override internal master detector + + md.ScheduleConnEvent(zk.StateConnected) + + done := make(chan struct{}) + driver.masterDetector.Detect(detector.OnMasterChanged(func(m *mesos.MasterInfo) { + suite.NotNil(m) + suite.NotEqual(m.GetPid, suite.masterUpid) + close(done) + })) + + //TODO(vlad) revisit, detector not responding. + + //NOTE(jdef) this works for me, I wonder if the timeouts are too short, or if + //GOMAXPROCS settings are affecting the result? + + // md.ScheduleSessEvent(zk.EventNodeChildrenChanged) + // select { + // case <-done: + // case <-time.After(time.Millisecond * 1000): + // suite.T().Errorf("Timed out waiting for children event.") + // } +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverNew_WithFrameworkInfo_Override() { + suite.framework.Hostname = proto.String("local-host") + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, "127.0.0.1:5050", nil) + suite.Equal(driver.FrameworkInfo.GetUser(), "test-user") + suite.Equal("local-host", driver.FrameworkInfo.GetHostname()) +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverStartOK() { + sched := NewMockScheduler() + + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), sched, suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + stat, err := driver.Start() + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) + suite.False(driver.Stopped()) +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverStartWithMessengerFailure() { + sched := NewMockScheduler() + sched.On("Error").Return() + + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(fmt.Errorf("Failed to start messenger")) + messenger.On("Stop").Return() + + driver := newTestSchedulerDriver(suite.T(), sched, suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + stat, err := driver.Start() + suite.Error(err) + suite.True(driver.Stopped()) + suite.True(!driver.Connected()) + suite.Equal(mesos.Status_DRIVER_NOT_STARTED, driver.Status()) + suite.Equal(mesos.Status_DRIVER_NOT_STARTED, stat) + +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverStartWithRegistrationFailure() { + sched := NewMockScheduler() + sched.On("Error").Return() + + // Set expections and return values. + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Stop").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), sched, suite.framework, suite.master, nil) + + driver.messenger = messenger + suite.True(driver.Stopped()) + + // reliable registration loops until the driver is stopped, connected, etc.. + stat, err := driver.Start() + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) + + time.Sleep(5 * time.Second) // wait a bit, registration should be looping... + + suite.False(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + // stop the driver, should not panic! + driver.Stop(false) // not failing over + suite.True(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_STOPPED, driver.Status()) + + messenger.AssertExpectations(suite.T()) +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverJoinUnstarted() { + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + suite.True(driver.Stopped()) + + stat, err := driver.Join() + suite.Error(err) + suite.Equal(mesos.Status_DRIVER_NOT_STARTED, stat) +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverJoinOK() { + // Set expections and return values. + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + stat, err := driver.Start() + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) + suite.False(driver.Stopped()) + + testCh := make(chan mesos.Status) + go func() { + stat, _ := driver.Join() + testCh <- stat + }() + + close(driver.stopCh) // manually stopping + stat = <-testCh // when Stop() is called, stat will be DRIVER_STOPPED. +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverRun() { + // Set expections and return values. + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + go func() { + stat, err := driver.Run() + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_STOPPED, stat) + }() + time.Sleep(time.Millisecond * 1) + + suite.False(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + // close it all. + driver.setStatus(mesos.Status_DRIVER_STOPPED) + close(driver.stopCh) + time.Sleep(time.Millisecond * 1) +} + +func (suite *SchedulerTestSuite) TestSchedulerDriverStopUnstarted() { + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + suite.True(driver.Stopped()) + + stat, err := driver.Stop(true) + suite.NotNil(err) + suite.True(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_NOT_STARTED, stat) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverStopOK() { + // Set expections and return values. + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + go func() { + stat, err := driver.Run() + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_STOPPED, stat) + }() + time.Sleep(time.Millisecond * 1) + + suite.False(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + driver.Stop(false) + time.Sleep(time.Millisecond * 1) + + suite.True(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_STOPPED, driver.Status()) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverAbort() { + // Set expections and return values. + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + go func() { + stat, err := driver.Run() + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_ABORTED, stat) + }() + time.Sleep(time.Millisecond * 1) + driver.setConnected(true) // simulated + + suite.False(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + stat, err := driver.Abort() + time.Sleep(time.Millisecond * 1) + suite.NoError(err) + suite.True(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_ABORTED, stat) + suite.Equal(mesos.Status_DRIVER_ABORTED, driver.Status()) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverLunchTasksUnstarted() { + sched := NewMockScheduler() + sched.On("Error").Return() + + // Set expections and return values. + messenger := messenger.NewMockedMessenger() + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), sched, suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + stat, err := driver.LaunchTasks( + []*mesos.OfferID{&mesos.OfferID{}}, + []*mesos.TaskInfo{}, + &mesos.Filters{}, + ) + suite.Error(err) + suite.Equal(mesos.Status_DRIVER_NOT_STARTED, stat) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverLaunchTasksWithError() { + sched := NewMockScheduler() + sched.On("StatusUpdate").Return(nil) + sched.On("Error").Return() + + msgr := messenger.NewMockedMessenger() + msgr.On("Start").Return(nil) + msgr.On("Send").Return(nil) + msgr.On("UPID").Return(&upid.UPID{}) + msgr.On("Stop").Return(nil) + msgr.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), sched, suite.framework, suite.master, nil) + driver.messenger = msgr + suite.True(driver.Stopped()) + + go func() { + driver.Run() + }() + time.Sleep(time.Millisecond * 1) + driver.setConnected(true) // simulated + suite.False(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + // to trigger error + msgr2 := messenger.NewMockedMessenger() + msgr2.On("Start").Return(nil) + msgr2.On("UPID").Return(&upid.UPID{}) + msgr2.On("Send").Return(fmt.Errorf("Unable to send message")) + msgr2.On("Stop").Return(nil) + msgr.On("Route").Return(nil) + driver.messenger = msgr2 + + // setup an offer + offer := util.NewOffer( + util.NewOfferID("test-offer-001"), + suite.framework.Id, + util.NewSlaveID("test-slave-001"), + "test-slave(1)@localhost:5050", + ) + + pid, err := upid.Parse("test-slave(1)@localhost:5050") + suite.NoError(err) + driver.cache.putOffer(offer, pid) + + // launch task + task := util.NewTaskInfo( + "simple-task", + util.NewTaskID("simpe-task-1"), + util.NewSlaveID("test-slave-001"), + []*mesos.Resource{util.NewScalarResource("mem", 400)}, + ) + task.Command = util.NewCommandInfo("pwd") + task.Executor = util.NewExecutorInfo(util.NewExecutorID("test-exec"), task.Command) + tasks := []*mesos.TaskInfo{task} + + stat, err := driver.LaunchTasks( + []*mesos.OfferID{offer.Id}, + tasks, + &mesos.Filters{}, + ) + suite.Error(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) + +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverLaunchTasks() { + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + go func() { + driver.Run() + }() + time.Sleep(time.Millisecond * 1) + driver.setConnected(true) // simulated + suite.False(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + task := util.NewTaskInfo( + "simple-task", + util.NewTaskID("simpe-task-1"), + util.NewSlaveID("slave-1"), + []*mesos.Resource{util.NewScalarResource("mem", 400)}, + ) + task.Command = util.NewCommandInfo("pwd") + tasks := []*mesos.TaskInfo{task} + + stat, err := driver.LaunchTasks( + []*mesos.OfferID{&mesos.OfferID{}}, + tasks, + &mesos.Filters{}, + ) + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverKillTask() { + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + go func() { + driver.Run() + }() + time.Sleep(time.Millisecond * 1) + driver.setConnected(true) // simulated + suite.False(driver.Stopped()) + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + stat, err := driver.KillTask(util.NewTaskID("test-task-1")) + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverRequestResources() { + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + driver.Start() + driver.setConnected(true) // simulated + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + stat, err := driver.RequestResources( + []*mesos.Request{ + &mesos.Request{ + SlaveId: util.NewSlaveID("test-slave-001"), + Resources: []*mesos.Resource{ + util.NewScalarResource("test-res-001", 33.00), + }, + }, + }, + ) + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverDeclineOffers() { + // see LaunchTasks test +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverReviveOffers() { + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + driver.Start() + driver.setConnected(true) // simulated + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + stat, err := driver.ReviveOffers() + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverSendFrameworkMessage() { + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + driver.Start() + driver.setConnected(true) // simulated + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + stat, err := driver.SendFrameworkMessage( + util.NewExecutorID("test-exec-001"), + util.NewSlaveID("test-slave-001"), + "Hello!", + ) + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) +} + +func (suite *SchedulerTestSuite) TestSchdulerDriverReconcileTasks() { + messenger := messenger.NewMockedMessenger() + messenger.On("Start").Return(nil) + messenger.On("UPID").Return(&upid.UPID{}) + messenger.On("Send").Return(nil) + messenger.On("Stop").Return(nil) + messenger.On("Route").Return(nil) + + driver := newTestSchedulerDriver(suite.T(), NewMockScheduler(), suite.framework, suite.master, nil) + driver.messenger = messenger + suite.True(driver.Stopped()) + + driver.Start() + driver.setConnected(true) // simulated + suite.Equal(mesos.Status_DRIVER_RUNNING, driver.Status()) + + stat, err := driver.ReconcileTasks( + []*mesos.TaskStatus{ + util.NewTaskStatus(util.NewTaskID("test-task-001"), mesos.TaskState_TASK_FINISHED), + }, + ) + suite.NoError(err) + suite.Equal(mesos.Status_DRIVER_RUNNING, stat) +}