diff --git a/conn_darwin.go b/conn_darwin.go new file mode 100644 index 0000000..2e0e2fc --- /dev/null +++ b/conn_darwin.go @@ -0,0 +1,62 @@ +//go:build darwin +// +build darwin + +package vsock + +import ( + "context" + + "github.com/mdlayher/socket" + "golang.org/x/sys/unix" +) + +// A conn is the net.Conn implementation for connection-oriented VM sockets. +// We can use socket.Conn directly on Linux to implement all of the necessary +// methods. +type conn = socket.Conn + +// dial is the entry point for Dial on Linux. +func dial(ctx context.Context, cid, port uint32, _ *Config) (*Conn, error) { + // TODO(mdlayher): Config default nil check and initialize. Pass options to + // socket.Config where necessary. + + c, err := socket.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0, "vsock", nil) + if err != nil { + return nil, err + } + + sa := &unix.SockaddrVM{CID: cid, Port: port} + rsa, err := c.Connect(ctx, sa) + if err != nil { + _ = c.Close() + return nil, err + } + + // TODO(mdlayher): getpeername(2) appears to return nil in the GitHub CI + // environment, so in the event of a nil sockaddr, fall back to the previous + // method of synthesizing the remote address. + if rsa == nil { + rsa = sa + } + + lsa, err := c.Getsockname() + if err != nil { + _ = c.Close() + return nil, err + } + + lsavm := lsa.(*unix.SockaddrVM) + rsavm := rsa.(*unix.SockaddrVM) + + return &Conn{ + c: c, + local: &Addr{ + ContextID: lsavm.CID, + Port: lsavm.Port, + }, + remote: &Addr{ + ContextID: rsavm.CID, + Port: rsavm.Port, + }, + }, nil +} diff --git a/conn_linux.go b/conn_linux.go index 46902d4..efaac46 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -15,7 +15,7 @@ import ( type conn = socket.Conn // dial is the entry point for Dial on Linux. -func dial(cid, port uint32, _ *Config) (*Conn, error) { +func dial(ctx context.Context, cid, port uint32, _ *Config) (*Conn, error) { // TODO(mdlayher): Config default nil check and initialize. Pass options to // socket.Config where necessary. @@ -25,7 +25,7 @@ func dial(cid, port uint32, _ *Config) (*Conn, error) { } sa := &unix.SockaddrVM{CID: cid, Port: port} - rsa, err := c.Connect(context.Background(), sa) + rsa, err := c.Connect(ctx, sa) if err != nil { _ = c.Close() return nil, err diff --git a/fd_darwin.go b/fd_darwin.go new file mode 100644 index 0000000..4955766 --- /dev/null +++ b/fd_darwin.go @@ -0,0 +1,37 @@ +package vsock + +import ( + "fmt" + + "golang.org/x/sys/unix" +) + +// contextID retrieves the local context ID for this system. +func contextID() (uint32, error) { + if fd, err := unix.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0); err != nil { + return 2, nil + } else { + defer unix.Close(fd) + + cid, err := unix.IoctlGetInt(fd, unix.IOCTL_VM_SOCKETS_GET_LOCAL_CID) + + return uint32(cid), err + } +} + +// isErrno determines if an error a matches UNIX error number. +func isErrno(err error, errno int) bool { + switch errno { + case ebadf: + return err == unix.EBADF + case enotconn: + return err == unix.ENOTCONN + default: + panicf("vsock: isErrno called with unhandled error number parameter: %d", errno) + return false + } +} + +func panicf(format string, a ...interface{}) { + panic(fmt.Sprintf(format, a...)) +} diff --git a/listener_darwin.go b/listener_darwin.go new file mode 100644 index 0000000..e6f7f8e --- /dev/null +++ b/listener_darwin.go @@ -0,0 +1,133 @@ +//go:build darwin +// +build darwin + +package vsock + +import ( + "context" + "net" + "os" + "time" + + "github.com/mdlayher/socket" + "golang.org/x/sys/unix" +) + +var _ net.Listener = &listener{} + +// A listener is the net.Listener implementation for connection-oriented +// VM sockets. +type listener struct { + c *socket.Conn + addr *Addr +} + +// Addr and Close implement the net.Listener interface for listener. +func (l *listener) Addr() net.Addr { return l.addr } +func (l *listener) Close() error { return l.c.Close() } +func (l *listener) SetDeadline(t time.Time) error { return l.c.SetDeadline(t) } + +// Accept accepts a single connection from the listener, and sets up +// a net.Conn backed by conn. +func (l *listener) Accept() (net.Conn, error) { + c, rsa, err := l.c.Accept(context.Background(), 0) + if err != nil { + return nil, err + } + + savm := rsa.(*unix.SockaddrVM) + remote := &Addr{ + ContextID: savm.CID, + Port: savm.Port, + } + + return &Conn{ + c: c, + local: l.addr, + remote: remote, + }, nil +} + +// name is the socket name passed to package socket. +const name = "vsock" + +// listen is the entry point for Listen on Linux. +func listen(cid, port uint32, _ *Config) (*Listener, error) { + // TODO(mdlayher): Config default nil check and initialize. Pass options to + // socket.Config where necessary. + + c, err := socket.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0, name, nil) + if err != nil { + return nil, err + } + + // Be sure to close the Conn if any of the system calls fail before we + // return the Conn to the caller. + + if port == 0 { + port = unix.VMADDR_PORT_ANY + } + + if err := c.Bind(&unix.SockaddrVM{CID: cid, Port: port}); err != nil { + _ = c.Close() + return nil, err + } + + if err := c.Listen(unix.SOMAXCONN); err != nil { + _ = c.Close() + return nil, err + } + + l, err := newListener(c) + if err != nil { + _ = c.Close() + return nil, err + } + + return l, nil +} + +// fileListener is the entry point for FileListener on Linux. +func fileListener(f *os.File) (*Listener, error) { + c, err := socket.FileConn(f, name) + if err != nil { + return nil, err + } + + l, err := newListener(c) + if err != nil { + _ = c.Close() + return nil, err + } + + return l, nil +} + +// newListener creates a Listener from a raw socket.Conn. +func newListener(c *socket.Conn) (*Listener, error) { + lsa, err := c.Getsockname() + if err != nil { + return nil, err + } + + // Now that the library can also accept arbitrary os.Files, we have to + // verify the address family so we don't accidentally create a + // *vsock.Listener backed by TCP or some other socket type. + lsavm, ok := lsa.(*unix.SockaddrVM) + if !ok { + // All errors should wrapped with os.SyscallError. + return nil, os.NewSyscallError("listen", unix.EINVAL) + } + + addr := &Addr{ + ContextID: lsavm.CID, + Port: lsavm.Port, + } + + return &Listener{ + l: &listener{ + c: c, + addr: addr, + }, + }, nil +} diff --git a/vsock.go b/vsock.go index 1cc0520..6c71d8d 100644 --- a/vsock.go +++ b/vsock.go @@ -1,10 +1,12 @@ package vsock import ( + "context" "fmt" "io" "net" "os" + "runtime" "strings" "syscall" "time" @@ -53,6 +55,10 @@ const ( opWrite = "write" ) +// errUnimplemented is returned by all functions on platforms that +// cannot make use of VM sockets. +var errUnimplemented = fmt.Errorf("vsock: not implemented on %s", runtime.GOOS) + // TODO(mdlayher): plumb through socket.Config.NetNS if it makes sense. // Config contains options for a Conn or Listener. @@ -175,7 +181,21 @@ func (l *Listener) opError(op string, err error) error { // When the connection is no longer needed, Close must be called to free // resources. func Dial(contextID, port uint32, cfg *Config) (*Conn, error) { - c, err := dial(contextID, port, cfg) + return dial(context.Background(), contextID, port, cfg) +} + +// DialWithContext connects to the address on the named network using +// the provided context. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// See func Dial for a description of the contextID and port +// parameters. +func DialWithContext(ctx context.Context, contextID, port uint32, cfg *Config) (*Conn, error) { + c, err := dial(ctx, contextID, port, cfg) if err != nil { // No local address, but we have a remote address we can return. return nil, opError(opDial, err, nil, &Addr{ diff --git a/vsock_darwin_test.go b/vsock_darwin_test.go new file mode 100644 index 0000000..f966a0c --- /dev/null +++ b/vsock_darwin_test.go @@ -0,0 +1,267 @@ +package vsock + +import ( + "errors" + "io" + "net" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" +) + +func Test_opError(t *testing.T) { + // The default op for empty op fields. + const defaultOp = "read" + + var ( + // Unfortunate, but string matching it is for now. + errClosed = errors.New("use of closed network connection") + + local = &Addr{ + ContextID: Host, + Port: 1024, + } + + remote = &Addr{ + ContextID: 3, + Port: 2048, + } + ) + + tests := []struct { + name string + op string + err error + local net.Addr + remote net.Addr + want error + }{ + { + name: "nil error", + }, + { + name: "unknown", + err: errors.New("foo"), + want: &net.OpError{ + Err: errors.New("foo"), + }, + }, + { + name: "EOF", + err: io.EOF, + want: io.EOF, + }, + { + name: "ENOTCONN", + err: unix.ENOTCONN, + want: io.EOF, + }, + { + name: "PathError ENOTCONN", + err: &os.PathError{ + Err: unix.ENOTCONN, + }, + want: io.EOF, + }, + { + name: "ErrClosed", + err: os.ErrClosed, + want: &net.OpError{ + Err: errClosed, + }, + }, + { + name: "EBADF", + err: unix.EBADF, + want: &net.OpError{ + Err: errClosed, + }, + }, + { + name: "string use of closed", + err: errors.New("use of closed file"), + want: &net.OpError{ + Err: errClosed, + }, + }, + { + name: "op close", + op: opClose, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opClose, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op dial", + op: opDial, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opDial, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op raw-read", + op: opRawRead, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opRawRead, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op raw-write", + op: opRawWrite, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opRawWrite, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op read", + op: opRead, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opRead, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op write", + op: opWrite, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opWrite, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op accept", + op: opAccept, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opAccept, + Addr: local, + Err: errClosed, + }, + }, + { + name: "op listen", + op: opListen, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opListen, + Addr: local, + Err: errClosed, + }, + }, + { + name: "op raw-control", + op: opRawControl, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opRawControl, + Addr: local, + Err: errClosed, + }, + }, + { + name: "op set", + op: opSet, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opSet, + Addr: local, + Err: errClosed, + }, + }, + { + name: "op syscall-conn", + op: opSyscallConn, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opSyscallConn, + Addr: local, + Err: errClosed, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op := tt.op + if op == "" { + op = defaultOp + } + + err := opError(op, tt.err, tt.local, tt.remote) + if err == nil { + if tt.want != nil { + t.Fatal("expected an output error, but none occurred") + } + + return + } + + // Populate sane defaults to save some typing. + want := tt.want + if nerr, ok := tt.want.(*net.OpError); ok { + if nerr.Op == "" { + nerr.Op = defaultOp + } + + if nerr.Net == "" { + nerr.Net = network + } + + want = nerr + } + + if diff := cmp.Diff(want, err, cmp.Comparer(errorsEqual)); diff != "" { + t.Fatalf("unexpected error (-want +got):\n%s", diff) + } + }) + } +} + +func errorsEqual(x, y error) bool { + if x == nil || y == nil { + return x == nil && y == nil + } + + return x.Error() == y.Error() +} diff --git a/vsock_others.go b/vsock_others.go index 5c1e88e..fe25690 100644 --- a/vsock_others.go +++ b/vsock_others.go @@ -1,21 +1,15 @@ -//go:build !linux -// +build !linux +//go:build !linux && !darwin +// +build !linux,!darwin package vsock import ( - "fmt" "net" "os" - "runtime" "syscall" "time" ) -// errUnimplemented is returned by all functions on platforms that -// cannot make use of VM sockets. -var errUnimplemented = fmt.Errorf("vsock: not implemented on %s", runtime.GOOS) - func fileListener(_ *os.File) (*Listener, error) { return nil, errUnimplemented } func listen(_, _ uint32, _ *Config) (*Listener, error) { return nil, errUnimplemented } diff --git a/vsock_others_test.go b/vsock_others_test.go index 4d33df9..5d88715 100644 --- a/vsock_others_test.go +++ b/vsock_others_test.go @@ -1,5 +1,5 @@ -//go:build !linux -// +build !linux +//go:build !linux && !darwin +// +build !linux,!darwin package vsock