diff --git a/conn_linux.go b/conn_linux.go index 6029d54..0453edf 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -16,7 +16,7 @@ import ( type conn = socket.Conn // dial is the entry point for Dial on Linux. -func dial(cid, port uint32, _ *Config) (*Conn, error) { +func dial(ctx context.Context, cid, port uint32, _ *Config) (*Conn, error) { // TODO(mdlayher): Config default nil check and initialize. Pass options to // socket.Config where necessary. @@ -26,7 +26,7 @@ func dial(cid, port uint32, _ *Config) (*Conn, error) { } sa := &unix.SockaddrVM{CID: cid, Port: port} - rsa, err := c.Connect(context.Background(), sa) + rsa, err := c.Connect(ctx, sa) if err != nil { _ = c.Close() return nil, err diff --git a/vsock.go b/vsock.go index 7876393..7960177 100644 --- a/vsock.go +++ b/vsock.go @@ -1,6 +1,7 @@ package vsock import ( + "context" "errors" "fmt" "io" @@ -176,7 +177,21 @@ func (l *Listener) opError(op string, err error) error { // When the connection is no longer needed, Close must be called to free // resources. func Dial(contextID, port uint32, cfg *Config) (*Conn, error) { - c, err := dial(contextID, port, cfg) + return DialContext(context.Background(), contextID, port, cfg) +} + +// DialContext connects to the address on the named network using +// the provided context. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// See func Dial for a description of the contextID and port +// parameters. +func DialContext(ctx context.Context, contextID, port uint32, cfg *Config) (*Conn, error) { + c, err := dial(ctx, contextID, port, cfg) if err != nil { // No local address, but we have a remote address we can return. return nil, opError(opDial, err, nil, &Addr{