Use bind filer for mounts

The bind filter supports bind-like mounts and volume mounts. It also
allows us to have read-only mounts.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2022-10-12 00:03:11 +03:00
parent d591bb0421
commit 36dc2782c4
4 changed files with 493 additions and 73 deletions

View File

@ -0,0 +1,327 @@
package mount
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"syscall"
"unicode/utf16"
"unsafe"
"golang.org/x/sys/windows"
)
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go
//sys BfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter?
//sys BfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) = bindfltapi.BfRemoveMapping?
//sys BfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) = bindfltapi.BfGetMappings?
// BfSetupFilter flags. See:
// https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240
const (
BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001
// Generates a merged binding, mapping target entries to the virtualization root.
BINDFLT_FLAG_MERGED_BIND_MAPPING uint32 = 0x00000002
// Use the binding mapping attached to the mapped-in job object (silo) instead of the default global mapping.
BINDFLT_FLAG_USE_CURRENT_SILO_MAPPING uint32 = 0x00000004
BINDFLT_FLAG_REPARSE_ON_FILES uint32 = 0x00000008
// Skips checks on file/dir creation inside a non-merged, read-only mapping.
// Only usable when READ_ONLY_MAPPING is set.
BINDFLT_FLAG_SKIP_SHARING_CHECK uint32 = 0x00000010
BINDFLT_FLAG_CLOUD_FILES_ECPS uint32 = 0x00000020
// Tells bindflt to fail mapping with STATUS_INVALID_PARAMETER if a mapping produces
// multiple targets.
BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 0x00000040
// Turns on caching by asserting that the backing store for name mappings is immutable.
BINDFLT_FLAG_IMMUTABLE_BACKING uint32 = 0x00000080
BINDFLT_FLAG_PREVENT_CASE_SENSITIVE_BINDING uint32 = 0x00000100
// Tells bindflt to fail with STATUS_OBJECT_PATH_NOT_FOUND when a mapping is being added
// but its parent paths (ancestors) have not already been added.
BINDFLT_FLAG_EMPTY_VIRT_ROOT uint32 = 0x00000200
BINDFLT_FLAG_NO_REPARSE_ON_ROOT uint32 = 0x10000000
BINDFLT_FLAG_BATCHED_REMOVE_MAPPINGS uint32 = 0x20000000
)
const (
BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001
BINDFLT_GET_MAPPINGS_FLAG_SILO uint32 = 0x00000002
BINDFLT_GET_MAPPINGS_FLAG_USER uint32 = 0x00000004
)
// ApplyFileBinding creates a global mount of the source in root, with an optional
// read only flag.
// The bind filter allows us to create mounts of directories and volumes. By default it allows
// us to mount multiple sources inside a single root, acting as an overlay. Files from the
// second source will superscede the first source that was mounted.
// This function disables this behavior and sets the BINDFLT_FLAG_NO_MULTIPLE_TARGETS flag
// on the mount.
func ApplyFileBinding(root, source string, readOnly bool) error {
// The parent directory needs to exist for the bind to work. MkdirAll stats and
// returns nil if the directory exists internally so we should be fine to mkdirall
// every time.
if err := os.MkdirAll(filepath.Dir(root), 0); err != nil {
return err
}
if strings.Contains(source, "Volume{") && !strings.HasSuffix(source, "\\") {
// Add trailing slash to volumes, otherwise we get an error when binding it to
// a folder.
source = source + "\\"
}
rootPtr, err := windows.UTF16PtrFromString(root)
if err != nil {
return err
}
targetPtr, err := windows.UTF16PtrFromString(source)
if err != nil {
return err
}
flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS
if readOnly {
flags |= BINDFLT_FLAG_READ_ONLY_MAPPING
}
// Set the job handle to 0 to create a global mount.
if err := BfSetupFilter(
0,
flags,
rootPtr,
targetPtr,
nil,
0,
); err != nil {
return fmt.Errorf("failed to bind target %q to root %q: %w", source, root, err)
}
return nil
}
func RemoveFileBinding(root string) error {
rootPtr, err := windows.UTF16PtrFromString(root)
if err != nil {
return fmt.Errorf("converting path to utf-16: %w", err)
}
if err := BfRemoveMapping(0, rootPtr); err != nil {
return fmt.Errorf("removing file binding: %w", err)
}
return nil
}
// mappingEntry holds information about where in the response buffer we can
// find information about the virtual root (the mount point) and the targets (sources)
// that get mounted, as well as the flags used to bind the targets to the virtual root.
type mappingEntry struct {
VirtRootLength uint32
VirtRootOffset uint32
Flags uint32
NumberOfTargets uint32
TargetEntriesOffset uint32
}
type mappingTargetEntry struct {
TargetRootLength uint32
TargetRootOffset uint32
}
// getMappingsResponseHeader represents the first 12 bytes of the BfGetMappings() response.
// It gives us the size of the buffer, the status of the call and the number of mappings.
// A response
type getMappingsResponseHeader struct {
Size uint32
Status uint32
MappingCount uint32
}
type BindMapping struct {
MountPoint string
Flags uint32
Targets []string
}
func decodeEntry(buffer []byte) (string, error) {
name := make([]uint16, len(buffer)/2)
err := binary.Read(bytes.NewReader(buffer), binary.LittleEndian, &name)
if err != nil {
return "", fmt.Errorf("decoding name: %w", err)
}
return string(utf16.Decode(name)), nil
}
func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) {
if len(buffer) < offset+count*6 {
return nil, fmt.Errorf("invalid buffer")
}
targets := make([]string, count)
for i := 0; i < count; i++ {
entryBuf := buffer[offset+i*8 : offset+i*8+8]
tgt := *(*mappingTargetEntry)(unsafe.Pointer(&entryBuf[0]))
if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) {
return nil, fmt.Errorf("invalid buffer")
}
decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : uint32(tgt.TargetRootOffset)+uint32(tgt.TargetRootLength)])
if err != nil {
return nil, fmt.Errorf("decoding name: %w", err)
}
decoded, err = getFinalPath(decoded)
if err != nil {
return nil, fmt.Errorf("fetching final path: %w", err)
}
targets[i] = decoded
}
return targets, nil
}
func getFinalPath(pth string) (string, error) {
// BfGetMappings returns VOLUME_NAME_NT paths like \Device\HarddiskVolume2\ProgramData.
// These can be accessed by prepending \\.\GLOBALROOT to the path. We use this to get the
// DOS paths for these files.
if strings.HasPrefix(pth, `\Device`) {
pth = `\\.\GLOBALROOT` + pth
}
han, err := getFileHandle(pth)
if err != nil {
return "", fmt.Errorf("fetching file handle: %w", err)
}
buf := make([]uint16, 100)
var flags uint32 = 0x0
for {
n, err := windows.GetFinalPathNameByHandle(windows.Handle(han), &buf[0], uint32(len(buf)), flags)
if err != nil {
// if we mounted a volume that does not also have a drive letter assigned, attempting to
// fetch the VOLUME_NAME_DOS will fail with os.ErrNotExist. Attempt to get the VOLUME_NAME_GUID.
if errors.Is(err, os.ErrNotExist) && flags != 0x1 {
flags = 0x1
continue
}
return "", fmt.Errorf("getting final path name: %w", err)
}
if n < uint32(len(buf)) {
break
}
buf = make([]uint16, n)
}
finalPath := syscall.UTF16ToString(buf)
// We got VOLUME_NAME_DOS, we need to strip away some leading slashes.
// Leave unchanged if we ended up requesting VOLUME_NAME_GUID
if len(finalPath) > 4 && finalPath[:4] == `\\?\` && flags == 0x0 {
finalPath = finalPath[4:]
if len(finalPath) > 3 && finalPath[:3] == `UNC` {
// return path like \\server\share\...
finalPath = `\` + finalPath[3:]
}
}
return finalPath, nil
}
func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, error) {
if len(buffer) < int(entry.VirtRootOffset)+int(entry.VirtRootLength) {
return BindMapping{}, fmt.Errorf("invalid buffer")
}
src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+uint32(entry.VirtRootLength)])
if err != nil {
return BindMapping{}, fmt.Errorf("decoding entry: %w", err)
}
targets, err := getTargetsFromBuffer(buffer, int(entry.TargetEntriesOffset), int(entry.NumberOfTargets))
if err != nil {
return BindMapping{}, fmt.Errorf("fetching targets: %w", err)
}
src, err = getFinalPath(src)
if err != nil {
return BindMapping{}, fmt.Errorf("fetching final path: %w", err)
}
return BindMapping{
Flags: entry.Flags,
Targets: targets,
MountPoint: src,
}, nil
}
func getFileHandle(pth string) (syscall.Handle, error) {
info, err := os.Lstat(pth)
if err != nil {
return 0, fmt.Errorf("accessing file: %w", err)
}
p, err := syscall.UTF16PtrFromString(pth)
if err != nil {
return 0, err
}
attrs := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS)
if info.Mode()&os.ModeSymlink != 0 {
attrs |= syscall.FILE_FLAG_OPEN_REPARSE_POINT
}
h, err := syscall.CreateFile(p, 0, 0, nil, syscall.OPEN_EXISTING, attrs, 0)
if err != nil {
return 0, err
}
return h, nil
}
// GetBindMappings returns a list of bind mappings that have their root on a
// particular volume. The volumePath parameter can be any path that exists on
// a volume. For example, if a number of mappings are created in C:\ProgramData\test,
// to get a list of those mappings, the volumePath parameter would have to be set to
// C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child
// path that exists.
func GetBindMappings(volumePath string) ([]BindMapping, error) {
rootPtr, err := windows.UTF16PtrFromString(volumePath)
if err != nil {
return nil, err
}
var flags uint32 = BINDFLT_GET_MAPPINGS_FLAG_VOLUME
// allocate a large buffer for results
var outBuffSize uint32 = 256 * 1024
buf := make([]byte, outBuffSize)
if err := BfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, uintptr(unsafe.Pointer(&buf[0]))); err != nil {
return nil, err
}
if outBuffSize < 12 {
return nil, fmt.Errorf("invalid buffer returned")
}
result := buf[:outBuffSize]
// The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{}
headerBuffer := result[:12]
// The alternative to using unsafe and casting it to the above defined structures, is to manually
// parse the fields. Not too terrible, but not sure it'd worth the trouble.
header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0]))
if header.MappingCount == 0 {
// no mappings
return []BindMapping{}, nil
}
mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)]
// Get a pointer to the first mapping in the slice
mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0]))
// Get slice of mappings
mappings := unsafe.Slice(mappingsPointer, header.MappingCount)
mappingEntries := make([]BindMapping, header.MappingCount)
for i := 0; i < int(header.MappingCount); i++ {
bindMapping, err := getBindMappingFromBuffer(result, mappings[i])
if err != nil {
return nil, fmt.Errorf("fetching bind mappings: %w", err)
}
mappingEntries[i] = bindMapping
}
return mappingEntries, nil
}

View File

@ -36,8 +36,16 @@ var (
// Mount to the provided target.
func (m *Mount) mount(target string) error {
readOnly := false
for _, option := range m.Options {
if option == "ro" {
readOnly = true
break
}
}
if m.Type == "bind" {
if err := m.bindMount(target); err != nil {
if err := ApplyFileBinding(target, m.Source, readOnly); err != nil {
return fmt.Errorf("failed to bind-mount to %s: %w", target, err)
}
return nil
@ -81,12 +89,12 @@ func (m *Mount) mount(target string) error {
return fmt.Errorf("failed to get volume path for layer %s: %w", m.Source, err)
}
if err = setVolumeMountPoint(target, volume); err != nil {
if err = ApplyFileBinding(target, volume, readOnly); err != nil {
return fmt.Errorf("failed to set volume mount path for layer %s: %w", m.Source, err)
}
defer func() {
if err != nil {
deleteVolumeMountPoint(target)
RemoveFileBinding(target)
}
}()
@ -120,50 +128,60 @@ func (m *Mount) GetParentPaths() ([]string, error) {
// Unmount the mount at the provided path
func Unmount(mount string, flags int) error {
var err error
mount = filepath.Clean(mount)
// Helpfully, both reparse points and symlinks look like links to Go
// Less-helpfully, ReadLink cannot return \\?\Volume{GUID} for a volume mount,
// and ends up returning the directory we gave it for some reason.
if mountTarget, err := os.Readlink(mount); err != nil {
// Not a mount point.
// This isn't an error, per the EINVAL handling in the Linux version
return nil
} else if mount != filepath.Clean(mountTarget) {
// Directory symlink
if err := bindUnmount(mount); err != nil {
return fmt.Errorf("failed to bind-unmount from %s: %w", mount, err)
}
return nil
}
layerPathb, err := os.ReadFile(mount + ":" + sourceStreamName)
// This should expand paths like ADMINI~1 and PROGRA~1 to long names.
mount, err = getFinalPath(mount)
if err != nil {
return fmt.Errorf("failed to retrieve source for layer %s: %w", mount, err)
return fmt.Errorf("fetching real path: %w", err)
}
layerPath := string(layerPathb)
if err := deleteVolumeMountPoint(mount); err != nil {
return fmt.Errorf("failed failed to release volume mount path for layer %s: %w", mount, err)
mounts, err := GetBindMappings(mount)
if err != nil {
return fmt.Errorf("fetching bind mappings: %w", err)
}
var (
home, layerID = filepath.Split(layerPath)
di = hcsshim.DriverInfo{
HomeDir: home,
found := false
for _, mnt := range mounts {
if mnt.MountPoint == mount {
found = true
break
}
)
if err := hcsshim.UnprepareLayer(di, layerID); err != nil {
return fmt.Errorf("failed to unprepare layer %s: %w", mount, err)
}
if !found {
// not a mount point
return nil
}
if err := hcsshim.DeactivateLayer(di, layerID); err != nil {
return fmt.Errorf("failed to deactivate layer %s: %w", mount, err)
adsFile := mount + ":" + sourceStreamName
var layerPath string
if _, err := os.Lstat(adsFile); err == nil {
layerPathb, err := os.ReadFile(mount + ":" + sourceStreamName)
if err != nil {
return fmt.Errorf("failed to retrieve source for layer %s: %w", mount, err)
}
layerPath = string(layerPathb)
}
if err := RemoveFileBinding(mount); err != nil {
return fmt.Errorf("removing mount: %w", err)
}
if layerPath != "" {
var (
home, layerID = filepath.Split(layerPath)
di = hcsshim.DriverInfo{
HomeDir: home,
}
)
if err := hcsshim.UnprepareLayer(di, layerID); err != nil {
return fmt.Errorf("failed to unprepare layer %s: %w", mount, err)
}
if err := hcsshim.DeactivateLayer(di, layerID); err != nil {
return fmt.Errorf("failed to deactivate layer %s: %w", mount, err)
}
}
return nil
}
@ -183,21 +201,13 @@ func UnmountRecursive(mount string, flags int) error {
}
func (m *Mount) bindMount(target string) error {
readOnly := false
for _, option := range m.Options {
if option == "ro" {
return fmt.Errorf("read-only bind mount: %w", ErrNotImplementOnWindows)
readOnly = true
break
}
}
if err := os.Remove(target); err != nil {
return err
}
// TODO: We don't honour the Read-Only flag.
// It's possible that Windows simply lacks this.
return os.Symlink(m.Source, target)
}
func bindUnmount(target string) error {
return os.Remove(target)
return ApplyFileBinding(target, m.Source, readOnly)
}

93
mount/zsyscall_windows.go Normal file
View File

@ -0,0 +1,93 @@
//go:build windows
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
package mount
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll")
procBfGetMappings = modbindfltapi.NewProc("BfGetMappings")
procBfRemoveMapping = modbindfltapi.NewProc("BfRemoveMapping")
procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter")
)
func BfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) {
hr = procBfGetMappings.Find()
if hr != nil {
return
}
r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(outBuffer))
if int32(r0) < 0 {
if r0&0x1fff0000 == 0x00070000 {
r0 &= 0xffff
}
hr = syscall.Errno(r0)
}
return
}
func BfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) {
hr = procBfRemoveMapping.Find()
if hr != nil {
return
}
r0, _, _ := syscall.Syscall(procBfRemoveMapping.Addr(), 2, uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), 0)
if int32(r0) < 0 {
if r0&0x1fff0000 == 0x00070000 {
r0 &= 0xffff
}
hr = syscall.Errno(r0)
}
return
}
func BfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) {
hr = procBfSetupFilter.Find()
if hr != nil {
return
}
r0, _, _ := syscall.Syscall6(procBfSetupFilter.Addr(), 6, uintptr(jobHandle), uintptr(flags), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(virtTargetPath)), uintptr(unsafe.Pointer(virtExceptions)), uintptr(virtExceptionPathCount))
if int32(r0) < 0 {
if r0&0x1fff0000 == 0x00070000 {
r0 &= 0xffff
}
hr = syscall.Errno(r0)
}
return
}

View File

@ -358,11 +358,7 @@ func (s *snapshotter) createSnapshot(ctx context.Context, kind snapshots.Kind, k
return fmt.Errorf("failed to create snapshot: %w", err)
}
if kind != snapshots.KindActive {
return nil
}
log.G(ctx).Debug("createSnapshot active")
log.G(ctx).Debug("createSnapshot")
// Create the new snapshot dir
snDir := s.getSnapshotDir(newSnapshot.ID)
if err = os.MkdirAll(snDir, 0700); err != nil {
@ -393,28 +389,22 @@ func (s *snapshotter) createSnapshot(ctx context.Context, kind snapshots.Kind, k
}
var sizeInBytes uint64
if kind == snapshots.KindActive {
if sizeGBstr, ok := snapshotInfo.Labels[rootfsSizeInGBLabel]; ok {
log.G(ctx).Warnf("%q label is deprecated, please use %q instead.", rootfsSizeInGBLabel, rootfsSizeInBytesLabel)
if sizeGBstr, ok := snapshotInfo.Labels[rootfsSizeInGBLabel]; ok {
log.G(ctx).Warnf("%q label is deprecated, please use %q instead.", rootfsSizeInGBLabel, rootfsSizeInBytesLabel)
sizeInGB, err := strconv.ParseUint(sizeGBstr, 10, 32)
if err != nil {
return fmt.Errorf("failed to parse label %q=%q: %w", rootfsSizeInGBLabel, sizeGBstr, err)
}
sizeInBytes = sizeInGB * 1024 * 1024 * 1024
sizeInGB, err := strconv.ParseUint(sizeGBstr, 10, 32)
if err != nil {
return fmt.Errorf("failed to parse label %q=%q: %w", rootfsSizeInGBLabel, sizeGBstr, err)
}
sizeInBytes = sizeInGB * 1024 * 1024 * 1024
}
// Prefer the newer label in bytes over the deprecated Windows specific GB variant.
if sizeBytesStr, ok := snapshotInfo.Labels[rootfsSizeInBytesLabel]; ok {
sizeInBytes, err = strconv.ParseUint(sizeBytesStr, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse label %q=%q: %w", rootfsSizeInBytesLabel, sizeBytesStr, err)
}
// Prefer the newer label in bytes over the deprecated Windows specific GB variant.
if sizeBytesStr, ok := snapshotInfo.Labels[rootfsSizeInBytesLabel]; ok {
sizeInBytes, err = strconv.ParseUint(sizeBytesStr, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse label %q=%q: %w", rootfsSizeInBytesLabel, sizeBytesStr, err)
}
} else {
// A view is just a read-write snapshot with a _really_ small sandbox, since we cannot actually
// make a read-only mount or junction point. https://superuser.com/q/881544/112473
sizeInBytes = 1024 * 1024 * 1024
}
var makeUVMScratch bool