
Some minor improvements, but the biggest one here is that ErrPipeListenerClosed is no longer an errors.New whose string matches the text of the now-exported net.ErrClosed in the stdlib; it is instead assigned to net.ErrClosed directly. This should allow us to get rid of the string check for "use of closed network connection" here. Signed-off-by: Daniel Canter <dcanter@microsoft.com>
180 lines
4.9 KiB
Go
180 lines
4.9 KiB
Go
//go:build windows
|
|
|
|
package socket
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"syscall"
|
|
"unsafe"
|
|
|
|
"github.com/Microsoft/go-winio/pkg/guid"
|
|
"golang.org/x/sys/windows"
|
|
)
|
|
|
|
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
|
|
|
|
//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
|
|
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
|
|
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
|
|
|
|
// socketError is the value the ws2_32 wrappers declared above compare their
// raw return value against (failretval==socketError) to detect failure: all
// 32 low bits set.
const socketError uintptr = 0xFFFFFFFF
|
|
|
|
// Sentinel errors exported by this package.
var (
	// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?

	// ErrBufferSize reports a buffer-size problem.
	// NOTE(review): presumably returned when a sockaddr buffer has an unexpected size; confirm against callers.
	ErrBufferSize = errors.New("buffer size")
	// ErrAddrFamily reports an address-family problem.
	// NOTE(review): presumably returned on an address family mismatch; confirm against callers.
	ErrAddrFamily = errors.New("address family")
	// ErrInvalidPointer reports an invalid pointer.
	ErrInvalidPointer = errors.New("invalid pointer")
	// ErrSocketClosed wraps [net.ErrClosed], so errors.Is(err, net.ErrClosed) also matches it.
	ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
)
|
|
|
|
// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
|
|
|
|
// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
|
|
// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
|
|
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
|
|
ptr, l, err := rsa.Sockaddr()
|
|
if err != nil {
|
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
|
}
|
|
|
|
// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
|
|
// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
|
|
return getsockname(s, ptr, &l)
|
|
}
|
|
|
|
// GetPeerName returns the remote address the socket is connected to.
|
|
//
|
|
// See [GetSockName] for more information.
|
|
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
|
|
ptr, l, err := rsa.Sockaddr()
|
|
if err != nil {
|
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
|
}
|
|
|
|
return getpeername(s, ptr, &l)
|
|
}
|
|
|
|
func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
|
|
ptr, l, err := rsa.Sockaddr()
|
|
if err != nil {
|
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
|
}
|
|
|
|
return bind(s, ptr, l)
|
|
}
|
|
|
|
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
|
|
// their sockaddr interface, so they cannot be used with HvsockAddr
|
|
// Replicate functionality here from
|
|
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
|
|
|
|
// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
|
|
// runtime via a WSAIoctl call:
|
|
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
|
|
|
|
type runtimeFunc struct {
|
|
id guid.GUID
|
|
once sync.Once
|
|
addr uintptr
|
|
err error
|
|
}
|
|
|
|
func (f *runtimeFunc) Load() error {
|
|
f.once.Do(func() {
|
|
var s windows.Handle
|
|
s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
|
|
if f.err != nil {
|
|
return
|
|
}
|
|
defer windows.CloseHandle(s) //nolint:errcheck
|
|
|
|
var n uint32
|
|
f.err = windows.WSAIoctl(s,
|
|
windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
|
|
(*byte)(unsafe.Pointer(&f.id)),
|
|
uint32(unsafe.Sizeof(f.id)),
|
|
(*byte)(unsafe.Pointer(&f.addr)),
|
|
uint32(unsafe.Sizeof(f.addr)),
|
|
&n,
|
|
nil, //overlapped
|
|
0, //completionRoutine
|
|
)
|
|
})
|
|
return f.err
|
|
}
|
|
|
|
var (
|
|
// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
|
|
WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
|
|
Data1: 0x25a207b9,
|
|
Data2: 0xddf3,
|
|
Data3: 0x4660,
|
|
Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
|
|
}
|
|
|
|
connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
|
|
)
|
|
|
|
func ConnectEx(
|
|
fd windows.Handle,
|
|
rsa RawSockaddr,
|
|
sendBuf *byte,
|
|
sendDataLen uint32,
|
|
bytesSent *uint32,
|
|
overlapped *windows.Overlapped,
|
|
) error {
|
|
if err := connectExFunc.Load(); err != nil {
|
|
return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
|
|
}
|
|
ptr, n, err := rsa.Sockaddr()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
|
|
}
|
|
|
|
// BOOL LpfnConnectex(
|
|
// [in] SOCKET s,
|
|
// [in] const sockaddr *name,
|
|
// [in] int namelen,
|
|
// [in, optional] PVOID lpSendBuffer,
|
|
// [in] DWORD dwSendDataLength,
|
|
// [out] LPDWORD lpdwBytesSent,
|
|
// [in] LPOVERLAPPED lpOverlapped
|
|
// )
|
|
|
|
func connectEx(
|
|
s windows.Handle,
|
|
name unsafe.Pointer,
|
|
namelen int32,
|
|
sendBuf *byte,
|
|
sendDataLen uint32,
|
|
bytesSent *uint32,
|
|
overlapped *windows.Overlapped,
|
|
) (err error) {
|
|
// todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
|
|
r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
|
|
7,
|
|
uintptr(s),
|
|
uintptr(name),
|
|
uintptr(namelen),
|
|
uintptr(unsafe.Pointer(sendBuf)),
|
|
uintptr(sendDataLen),
|
|
uintptr(unsafe.Pointer(bytesSent)),
|
|
uintptr(unsafe.Pointer(overlapped)),
|
|
0,
|
|
0)
|
|
if r1 == 0 {
|
|
if e1 != 0 {
|
|
err = error(e1)
|
|
} else {
|
|
err = syscall.EINVAL
|
|
}
|
|
}
|
|
return err
|
|
}
|