containerd/vendor/github.com/Microsoft/go-winio/internal/socket/socket.go
Daniel Canter acb2964d65 go.mod: Bump go-winio to v0.6.0
Some minor improvements, but biggest for here is ErrPipeListenerClosed
is no longer an errors.New where the string matches the text of the now
exported net.ErrClosed in the stdlib, but is just assigned to net.ErrClosed
directly. This should allow us to get rid of the string check for "use of closed
network connection" here now.

Signed-off-by: Daniel Canter <dcanter@microsoft.com>
2022-09-28 10:30:38 -07:00

180 lines
4.9 KiB
Go

//go:build windows
package socket
import (
"errors"
"fmt"
"net"
"sync"
"syscall"
"unsafe"
"github.com/Microsoft/go-winio/pkg/guid"
"golang.org/x/sys/windows"
)
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
// socketError is the failure sentinel (all 32 bits set, i.e. 0xFFFFFFFF) that the //sys
// declarations for getsockname, getpeername and bind name as their failretval.
const socketError = uintptr(^uint32(0))
// Sentinel errors exported by this package; compare against them with errors.Is.
var (
// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?

// ErrBufferSize reports a buffer-size problem.
// NOTE(review): not returned anywhere in this part of the file; presumably produced by
// RawSockaddr implementations -- confirm against callers.
ErrBufferSize = errors.New("buffer size")
// ErrAddrFamily reports an address-family problem.
// NOTE(review): exact trigger is not visible here -- confirm against callers.
ErrAddrFamily = errors.New("address family")
// ErrInvalidPointer reports an invalid pointer.
// NOTE(review): exact trigger is not visible here -- confirm against callers.
ErrInvalidPointer = errors.New("invalid pointer")
// ErrSocketClosed wraps net.ErrClosed, so errors.Is(err, net.ErrClosed) also matches it.
ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
)
// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
// GetSockName stores the local address of socket s in the [RawSockaddr] rsa.
//
// The destination pointer and its size are obtained from rsa.Sockaddr; a failure
// there is returned wrapped. If the buffer provided by rsa is too small,
// [windows.WSAEFAULT] is returned.
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
	addr, size, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	// getsockname fails with WSAEFAULT on an undersized buffer but does not report the
	// size it needed through &size, so the only possible remedy would be to keep
	// doubling the buffer; no retry is attempted here.
	return getsockname(s, addr, &size)
}
// GetPeerName stores the address of the peer that socket s is connected to in the
// [RawSockaddr] rsa.
//
// See [GetSockName] for more information.
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
	addr, size, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	return getpeername(s, addr, &size)
}
// Bind associates socket s with the address held in the [RawSockaddr] rsa by calling
// ws2_32.bind.
//
// The address pointer and length are obtained from rsa.Sockaddr; a failure there is
// returned wrapped.
func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
	addr, size, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	return bind(s, addr, size)
}
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
// their sockaddr interface, so they cannot be used with HvsockAddr.
// Replicate functionality here from
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
// runtime via a WSAIoctl call:
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
// runtimeFunc holds a function pointer that is resolved at runtime from a GUID.
// Load performs the lookup at most once and caches both the address and the error.
type runtimeFunc struct {
// id identifies the function to look up; it is passed as the input buffer to WSAIoctl.
id guid.GUID
// once guards the one-time lookup performed by Load.
once sync.Once
// addr receives the function pointer written by WSAIoctl.
addr uintptr
// err records the outcome of the lookup and is returned by every Load call.
err error
}
// Load resolves the function pointer identified by f.id on first use and caches the outcome.
//
// A temporary TCP socket is created solely to issue the SIO_GET_EXTENSION_FUNCTION_POINTER
// ioctl; the resulting pointer is written to f.addr. The lookup runs at most once (guarded
// by f.once), and every call returns the error recorded by that single attempt.
func (f *runtimeFunc) Load() error {
	f.once.Do(func() {
		var s windows.Handle
		s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
		if f.err != nil {
			return
		}
		// Sockets must be released with closesocket rather than CloseHandle so that Winsock
		// frees every resource associated with the socket, not just the handle.
		defer windows.Closesocket(s) //nolint:errcheck

		var n uint32
		f.err = windows.WSAIoctl(s,
			windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
			(*byte)(unsafe.Pointer(&f.id)),
			uint32(unsafe.Sizeof(f.id)),
			(*byte)(unsafe.Pointer(&f.addr)),
			uint32(unsafe.Sizeof(f.addr)),
			&n,
			nil, //overlapped
			0,   //completionRoutine
		)
	})
	return f.err
}
var (
// todo: add `AcceptEx` and `GetAcceptExSockaddrs`

// WSAID_CONNECTEX is the GUID used as the id of connectExFunc, i.e. the value handed to
// WSAIoctl to look up the ConnectEx function pointer.
WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
Data1: 0x25a207b9,
Data2: 0xddf3,
Data3: 0x4660,
Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
}
// connectExFunc is the lazily loaded ConnectEx function pointer; ConnectEx calls Load on
// it and connectEx invokes the address it holds.
connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
)
// ConnectEx connects socket fd to the address held in the [RawSockaddr] rsa using the
// ConnectEx extension function, whose pointer is loaded on first use.
//
// sendBuf, sendDataLen, bytesSent and overlapped are forwarded unchanged to the underlying
// call. An error is returned if the function pointer cannot be loaded, if rsa.Sockaddr
// fails, or if the call itself reports failure.
func ConnectEx(
	fd windows.Handle,
	rsa RawSockaddr,
	sendBuf *byte,
	sendDataLen uint32,
	bytesSent *uint32,
	overlapped *windows.Overlapped,
) error {
	if err := connectExFunc.Load(); err != nil {
		return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
	}

	ptr, n, err := rsa.Sockaddr()
	if err != nil {
		// Wrap with the same context GetSockName, GetPeerName and Bind use; %w keeps the
		// underlying error reachable through errors.Is / errors.As.
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
}
// connectEx invokes the function at connectExFunc.addr with the seven arguments of the
// ConnectEx prototype:
//
//	BOOL LpfnConnectex(
//	  [in]           SOCKET s,
//	  [in]           const sockaddr *name,
//	  [in]           int namelen,
//	  [in, optional] PVOID lpSendBuffer,
//	  [in]           DWORD dwSendDataLength,
//	  [out]          LPDWORD lpdwBytesSent,
//	  [in]           LPOVERLAPPED lpOverlapped
//	)
//
// A zero return value is treated as failure: the reported errno is returned, or
// syscall.EINVAL when no errno was set. A non-zero return value yields nil.
func connectEx(
	s windows.Handle,
	name unsafe.Pointer,
	namelen int32,
	sendBuf *byte,
	sendDataLen uint32,
	bytesSent *uint32,
	overlapped *windows.Overlapped,
) (err error) {
	// todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
	ok, _, errno := syscall.Syscall9(
		connectExFunc.addr,
		7, // number of meaningful arguments; the trailing zeros are padding
		uintptr(s),
		uintptr(name),
		uintptr(namelen),
		uintptr(unsafe.Pointer(sendBuf)),
		uintptr(sendDataLen),
		uintptr(unsafe.Pointer(bytesSent)),
		uintptr(unsafe.Pointer(overlapped)),
		0,
		0,
	)

	switch {
	case ok != 0:
		return nil
	case errno != 0:
		return errno
	default:
		return syscall.EINVAL
	}
}