diff --git a/api/next.txtpb b/api/next.txtpb index e78ce00..fee2964 100644 --- a/api/next.txtpb +++ b/api/next.txtpb @@ -850,6 +850,611 @@ file: { is_syntax_unspecified: false } } +file: { + name: "proto/nerdbox/services/socketforward/v1/socketforward.proto" + package: "nerdbox.services.socketforward.v1" + dependency: "google/protobuf/empty.proto" + message_type: { + name: "BindRequest" + field: { + name: "sockets" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".nerdbox.services.socketforward.v1.Socket" + json_name: "sockets" + } + } + message_type: { + name: "Socket" + field: { + name: "forward_id" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + json_name: "forwardId" + } + field: { + name: "socket_path" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + json_name: "socketPath" + } + } + message_type: { + name: "ConnectRequest" + field: { + name: "stream_id" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + json_name: "streamId" + } + field: { + name: "forward_id" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + json_name: "forwardId" + } + } + message_type: { + name: "ConnectResult" + field: { + name: "stream_id" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + json_name: "streamId" + } + field: { + name: "error" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + json_name: "error" + } + } + service: { + name: "SocketForward" + method: { + name: "Bind" + input_type: ".nerdbox.services.socketforward.v1.BindRequest" + output_type: ".google.protobuf.Empty" + } + method: { + name: "Accept" + input_type: ".nerdbox.services.socketforward.v1.ConnectResult" + output_type: ".nerdbox.services.socketforward.v1.ConnectRequest" + client_streaming: true + server_streaming: true + } + } + options: { + go_package: "github.com/containerd/nerdbox/api/services/socketforward/v1;socketforward" + } + source_code_info: { + location: { + span: 16 + span: 0 + span: 94 + span: 1 + } + location: { + path: 12 + span: 
16 + span: 0 + span: 18 + leading_detached_comments: "\nCopyright The containerd Authors.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n" + } + location: { + path: 2 + span: 18 + span: 0 + span: 42 + } + location: { + path: 3 + path: 0 + span: 20 + span: 0 + span: 37 + } + location: { + path: 8 + span: 22 + span: 0 + span: 96 + } + location: { + path: 8 + path: 11 + span: 22 + span: 0 + span: 96 + } + location: { + path: 6 + path: 0 + span: 38 + span: 0 + span: 52 + span: 1 + leading_comments: " SocketForward provides UNIX domain socket forwarding across the VM boundary\n over vsock streams. Container processes connect to listener sockets created\n by vminitd; connections are relayed to the corresponding host-side socket.\n\n The ttrpc server runs inside the VM (vminitd) and the ttrpc client runs on\n the host (shim). All RPCs are initiated by the host:\n\n - Bind: the host tells the VM which sockets to set up. The VM creates\n UNIX listener sockets in the root mount ns so that crun can bind-mount\n them into the container. This must be called before the container is\n created.\n - Accept: the host opens a bidirectional streaming channel. 
The VM sends\n ConnectRequest messages when a process inside the container connects to\n a forwarded socket, and the host replies with ConnectResult messages.\n" + } + location: { + path: 6 + path: 0 + path: 1 + span: 38 + span: 8 + span: 21 + } + location: { + path: 6 + path: 0 + path: 2 + path: 0 + span: 43 + span: 8 + span: 62 + leading_comments: " Bind sets up the socket forward entries on the VM side. The VM creates\n UNIX listener sockets at the paths given in socket_path. This RPC returns\n only after all socket files have been created, so the caller can safely\n proceed with container creation (crun bind mounts).\n" + } + location: { + path: 6 + path: 0 + path: 2 + path: 0 + path: 1 + span: 43 + span: 12 + span: 16 + } + location: { + path: 6 + path: 0 + path: 2 + path: 0 + path: 2 + span: 43 + span: 17 + span: 28 + } + location: { + path: 6 + path: 0 + path: 2 + path: 0 + path: 3 + span: 43 + span: 39 + span: 60 + } + location: { + path: 6 + path: 0 + path: 2 + path: 1 + span: 51 + span: 8 + span: 73 + leading_comments: " Accept is a bidirectional streaming RPC used to coordinate forwarded\n connections. The VM sends a ConnectRequest when a container process\n connects to a forwarded socket; the host resolves the forward_id,\n dials the target host socket, opens a vsock stream, and sends back a\n ConnectResult reporting success or failure. 
On failure the VM closes\n the pending container connection immediately.\n" + } + location: { + path: 6 + path: 0 + path: 2 + path: 1 + path: 1 + span: 51 + span: 12 + span: 18 + } + location: { + path: 6 + path: 0 + path: 2 + path: 1 + path: 5 + span: 51 + span: 19 + span: 25 + } + location: { + path: 6 + path: 0 + path: 2 + path: 1 + path: 2 + span: 51 + span: 26 + span: 39 + } + location: { + path: 6 + path: 0 + path: 2 + path: 1 + path: 6 + span: 51 + span: 50 + span: 56 + } + location: { + path: 6 + path: 0 + path: 2 + path: 1 + path: 3 + span: 51 + span: 57 + span: 71 + } + location: { + path: 4 + path: 0 + span: 56 + span: 0 + span: 58 + span: 1 + leading_comments: " BindRequest is sent by the host before container creation to set up socket\n forwards on the VM side.\n" + } + location: { + path: 4 + path: 0 + path: 1 + span: 56 + span: 8 + span: 19 + } + location: { + path: 4 + path: 0 + path: 2 + path: 0 + span: 57 + span: 8 + span: 36 + } + location: { + path: 4 + path: 0 + path: 2 + path: 0 + path: 4 + span: 57 + span: 8 + span: 16 + } + location: { + path: 4 + path: 0 + path: 2 + path: 0 + path: 6 + span: 57 + span: 17 + span: 23 + } + location: { + path: 4 + path: 0 + path: 2 + path: 0 + path: 1 + span: 57 + span: 24 + span: 31 + } + location: { + path: 4 + path: 0 + path: 2 + path: 0 + path: 3 + span: 57 + span: 34 + span: 35 + } + location: { + path: 4 + path: 1 + span: 61 + span: 0 + span: 70 + span: 1 + leading_comments: " Socket describes a single forwarded UNIX socket.\n" + } + location: { + path: 4 + path: 1 + path: 1 + span: 61 + span: 8 + span: 14 + } + location: { + path: 4 + path: 1 + path: 2 + path: 0 + span: 64 + span: 8 + span: 30 + leading_comments: " Opaque identifier for this forward. 
Each side uses it to resolve the local\n socket path from its own configuration.\n" + } + location: { + path: 4 + path: 1 + path: 2 + path: 0 + path: 5 + span: 64 + span: 8 + span: 14 + } + location: { + path: 4 + path: 1 + path: 2 + path: 0 + path: 1 + span: 64 + span: 15 + span: 25 + } + location: { + path: 4 + path: 1 + path: 2 + path: 0 + path: 3 + span: 64 + span: 28 + span: 29 + } + location: { + path: 4 + path: 1 + path: 2 + path: 1 + span: 69 + span: 8 + span: 31 + leading_comments: " Path of the UNIX listener socket in the VM's root filesystem (e.g.\n /run/socketfwd/{forward_id}.sock). vminitd creates the socket here;\n the shim has already rewritten the uds mount to a bind mount from\n this path to the user-specified destination inside the container.\n" + } + location: { + path: 4 + path: 1 + path: 2 + path: 1 + path: 5 + span: 69 + span: 8 + span: 14 + } + location: { + path: 4 + path: 1 + path: 2 + path: 1 + path: 1 + span: 69 + span: 15 + span: 26 + } + location: { + path: 4 + path: 1 + path: 2 + path: 1 + path: 3 + span: 69 + span: 29 + span: 30 + } + location: { + path: 4 + path: 2 + span: 74 + span: 0 + span: 82 + span: 1 + leading_comments: " ConnectRequest is sent by the VM on the Accept stream to notify the host of\n a new container-initiated connection.\n" + } + location: { + path: 4 + path: 2 + path: 1 + span: 74 + span: 8 + span: 22 + } + location: { + path: 4 + path: 2 + path: 2 + path: 0 + span: 77 + span: 8 + span: 29 + leading_comments: " ID of the vsock stream. 
The host must open a stream with this ID after\n a successful ConnectResult so that the relay can start.\n" + } + location: { + path: 4 + path: 2 + path: 2 + path: 0 + path: 5 + span: 77 + span: 8 + span: 14 + } + location: { + path: 4 + path: 2 + path: 2 + path: 0 + path: 1 + span: 77 + span: 15 + span: 24 + } + location: { + path: 4 + path: 2 + path: 2 + path: 0 + path: 3 + span: 77 + span: 27 + span: 28 + } + location: { + path: 4 + path: 2 + path: 2 + path: 1 + span: 81 + span: 8 + span: 30 + leading_comments: " Identifier of the socket forward entry. The host uses this to resolve the\n target host socket path from its own configuration rather than trusting a\n path supplied by the VM.\n" + } + location: { + path: 4 + path: 2 + path: 2 + path: 1 + path: 5 + span: 81 + span: 8 + span: 14 + } + location: { + path: 4 + path: 2 + path: 2 + path: 1 + path: 1 + span: 81 + span: 15 + span: 25 + } + location: { + path: 4 + path: 2 + path: 2 + path: 1 + path: 3 + span: 81 + span: 28 + span: 29 + } + location: { + path: 4 + path: 3 + span: 88 + span: 0 + span: 94 + span: 1 + leading_comments: " ConnectResult is sent by the host on the Accept stream in response to each\n ConnectRequest. It reports whether the host successfully dialed the target\n socket. 
On failure the VM closes the pending container connection\n immediately.\n" + } + location: { + path: 4 + path: 3 + path: 1 + span: 88 + span: 8 + span: 21 + } + location: { + path: 4 + path: 3 + path: 2 + path: 0 + span: 90 + span: 8 + span: 29 + leading_comments: " ID of the vsock stream from the corresponding ConnectRequest.\n" + } + location: { + path: 4 + path: 3 + path: 2 + path: 0 + path: 5 + span: 90 + span: 8 + span: 14 + } + location: { + path: 4 + path: 3 + path: 2 + path: 0 + path: 1 + span: 90 + span: 15 + span: 24 + } + location: { + path: 4 + path: 3 + path: 2 + path: 0 + path: 3 + span: 90 + span: 27 + span: 28 + } + location: { + path: 4 + path: 3 + path: 2 + path: 1 + span: 93 + span: 8 + span: 25 + leading_comments: " Non-empty if the host failed to dial the target socket. The VM closes\n the pending connection when error is set.\n" + } + location: { + path: 4 + path: 3 + path: 2 + path: 1 + path: 5 + span: 93 + span: 8 + span: 14 + } + location: { + path: 4 + path: 3 + path: 2 + path: 1 + path: 1 + span: 93 + span: 15 + span: 20 + } + location: { + path: 4 + path: 3 + path: 2 + path: 1 + path: 3 + span: 93 + span: 23 + span: 24 + } + } + syntax: "proto3" + buf_extension: { + is_import: false + is_syntax_unspecified: false + } +} file: { name: "proto/nerdbox/services/system/v1/info.proto" package: "containerd.vminitd.services.system.v1" diff --git a/api/proto/nerdbox/services/socketforward/v1/socketforward.proto b/api/proto/nerdbox/services/socketforward/v1/socketforward.proto new file mode 100644 index 0000000..b8f373e --- /dev/null +++ b/api/proto/nerdbox/services/socketforward/v1/socketforward.proto @@ -0,0 +1,95 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +syntax = "proto3"; + +package nerdbox.services.socketforward.v1; + +import "google/protobuf/empty.proto"; + +option go_package = "github.com/containerd/nerdbox/api/services/socketforward/v1;socketforward"; + +// SocketForward provides UNIX domain socket forwarding across the VM boundary +// over vsock streams. Container processes connect to listener sockets created +// by vminitd; connections are relayed to the corresponding host-side socket. +// +// The ttrpc server runs inside the VM (vminitd) and the ttrpc client runs on +// the host (shim). All RPCs are initiated by the host: +// +// - Bind: the host tells the VM which sockets to set up. The VM creates +// UNIX listener sockets in the root mount ns so that crun can bind-mount +// them into the container. This must be called before the container is +// created. +// - Accept: the host opens a bidirectional streaming channel. The VM sends +// ConnectRequest messages when a process inside the container connects to +// a forwarded socket, and the host replies with ConnectResult messages. +service SocketForward { + // Bind sets up the socket forward entries on the VM side. The VM creates + // UNIX listener sockets at the paths given in socket_path. This RPC returns + // only after all socket files have been created, so the caller can safely + // proceed with container creation (crun bind mounts). + rpc Bind(BindRequest) returns (google.protobuf.Empty); + + // Accept is a bidirectional streaming RPC used to coordinate forwarded + // connections. 
The VM sends a ConnectRequest when a container process + // connects to a forwarded socket; the host resolves the forward_id, + // dials the target host socket, opens a vsock stream, and sends back a + // ConnectResult reporting success or failure. On failure the VM closes + // the pending container connection immediately. + rpc Accept(stream ConnectResult) returns (stream ConnectRequest); +} + +// BindRequest is sent by the host before container creation to set up socket +// forwards on the VM side. +message BindRequest { + repeated Socket sockets = 1; +} + +// Socket describes a single forwarded UNIX socket. +message Socket { + // Opaque identifier for this forward. Each side uses it to resolve the local + // socket path from its own configuration. + string forward_id = 1; + // Path of the UNIX listener socket in the VM's root filesystem (e.g. + // /run/socketfwd/{forward_id}.sock). vminitd creates the socket here; + // the shim has already rewritten the uds mount to a bind mount from + // this path to the user-specified destination inside the container. + string socket_path = 2; +} + +// ConnectRequest is sent by the VM on the Accept stream to notify the host of +// a new container-initiated connection. +message ConnectRequest { + // ID of the vsock stream. The host must open a stream with this ID after + // a successful ConnectResult so that the relay can start. + string stream_id = 1; + // Identifier of the socket forward entry. The host uses this to resolve the + // target host socket path from its own configuration rather than trusting a + // path supplied by the VM. + string forward_id = 2; +} + +// ConnectResult is sent by the host on the Accept stream in response to each +// ConnectRequest. It reports whether the host successfully dialed the target +// socket. On failure the VM closes the pending container connection +// immediately. +message ConnectResult { + // ID of the vsock stream from the corresponding ConnectRequest. 
+ string stream_id = 1; + // Non-empty if the host failed to dial the target socket. The VM closes + // the pending connection when error is set. + string error = 2; +} diff --git a/api/services/socketforward/v1/doc.go b/api/services/socketforward/v1/doc.go new file mode 100644 index 0000000..ddb2060 --- /dev/null +++ b/api/services/socketforward/v1/doc.go @@ -0,0 +1,18 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Package socketforward defines the socketforward service. +package socketforward diff --git a/api/services/socketforward/v1/socketforward.pb.go b/api/services/socketforward/v1/socketforward.pb.go new file mode 100644 index 0000000..5a9fc66 --- /dev/null +++ b/api/services/socketforward/v1/socketforward.pb.go @@ -0,0 +1,431 @@ +// +//Copyright The containerd Authors. +// +//Licensed under the Apache License, Version 2.0 (the "License"); +//you may not use this file except in compliance with the License. +//You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +//Unless required by applicable law or agreed to in writing, software +//distributed under the License is distributed on an "AS IS" BASIS, +//WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +//See the License for the specific language governing permissions and +//limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.28.1 +// protoc (unknown) +// source: proto/nerdbox/services/socketforward/v1/socketforward.proto + +package socketforward + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// BindRequest is sent by the host before container creation to set up socket +// forwards on the VM side. +type BindRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Sockets []*Socket `protobuf:"bytes,1,rep,name=sockets,proto3" json:"sockets,omitempty"` +} + +func (x *BindRequest) Reset() { + *x = BindRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *BindRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BindRequest) ProtoMessage() {} + +func (x *BindRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BindRequest.ProtoReflect.Descriptor instead. 
+func (*BindRequest) Descriptor() ([]byte, []int) { + return file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescGZIP(), []int{0} +} + +func (x *BindRequest) GetSockets() []*Socket { + if x != nil { + return x.Sockets + } + return nil +} + +// Socket describes a single forwarded UNIX socket. +type Socket struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Opaque identifier for this forward. Each side uses it to resolve the local + // socket path from its own configuration. + ForwardID string `protobuf:"bytes,1,opt,name=forward_id,json=forwardId,proto3" json:"forward_id,omitempty"` + // Path of the UNIX listener socket in the VM's root filesystem (e.g. + // /run/socketfwd/{forward_id}.sock). vminitd creates the socket here; + // the shim has already rewritten the uds mount to a bind mount from + // this path to the user-specified destination inside the container. + SocketPath string `protobuf:"bytes,2,opt,name=socket_path,json=socketPath,proto3" json:"socket_path,omitempty"` +} + +func (x *Socket) Reset() { + *x = Socket{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Socket) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Socket) ProtoMessage() {} + +func (x *Socket) ProtoReflect() protoreflect.Message { + mi := &file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Socket.ProtoReflect.Descriptor instead. 
+func (*Socket) Descriptor() ([]byte, []int) { + return file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescGZIP(), []int{1} +} + +func (x *Socket) GetForwardID() string { + if x != nil { + return x.ForwardID + } + return "" +} + +func (x *Socket) GetSocketPath() string { + if x != nil { + return x.SocketPath + } + return "" +} + +// ConnectRequest is sent by the VM on the Accept stream to notify the host of +// a new container-initiated connection. +type ConnectRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // ID of the vsock stream. The host must open a stream with this ID after + // a successful ConnectResult so that the relay can start. + StreamID string `protobuf:"bytes,1,opt,name=stream_id,json=streamId,proto3" json:"stream_id,omitempty"` + // Identifier of the socket forward entry. The host uses this to resolve the + // target host socket path from its own configuration rather than trusting a + // path supplied by the VM. + ForwardID string `protobuf:"bytes,2,opt,name=forward_id,json=forwardId,proto3" json:"forward_id,omitempty"` +} + +func (x *ConnectRequest) Reset() { + *x = ConnectRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ConnectRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ConnectRequest) ProtoMessage() {} + +func (x *ConnectRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ConnectRequest.ProtoReflect.Descriptor instead. 
+func (*ConnectRequest) Descriptor() ([]byte, []int) { + return file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescGZIP(), []int{2} +} + +func (x *ConnectRequest) GetStreamID() string { + if x != nil { + return x.StreamID + } + return "" +} + +func (x *ConnectRequest) GetForwardID() string { + if x != nil { + return x.ForwardID + } + return "" +} + +// ConnectResult is sent by the host on the Accept stream in response to each +// ConnectRequest. It reports whether the host successfully dialed the target +// socket. On failure the VM closes the pending container connection +// immediately. +type ConnectResult struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // ID of the vsock stream from the corresponding ConnectRequest. + StreamID string `protobuf:"bytes,1,opt,name=stream_id,json=streamId,proto3" json:"stream_id,omitempty"` + // Non-empty if the host failed to dial the target socket. The VM closes + // the pending connection when error is set. + Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"` +} + +func (x *ConnectResult) Reset() { + *x = ConnectResult{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ConnectResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ConnectResult) ProtoMessage() {} + +func (x *ConnectResult) ProtoReflect() protoreflect.Message { + mi := &file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ConnectResult.ProtoReflect.Descriptor instead. 
+func (*ConnectResult) Descriptor() ([]byte, []int) { + return file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescGZIP(), []int{3} +} + +func (x *ConnectResult) GetStreamID() string { + if x != nil { + return x.StreamID + } + return "" +} + +func (x *ConnectResult) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +var File_proto_nerdbox_services_socketforward_v1_socketforward_proto protoreflect.FileDescriptor + +var file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDesc = []byte{ + 0x0a, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6e, 0x65, 0x72, 0x64, 0x62, 0x6f, 0x78, 0x2f, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2f, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x66, + 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, + 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x21, 0x6e, + 0x65, 0x72, 0x64, 0x62, 0x6f, 0x78, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, + 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2e, 0x76, 0x31, + 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x52, 0x0a, + 0x0b, 0x42, 0x69, 0x6e, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x43, 0x0a, 0x07, + 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, + 0x6e, 0x65, 0x72, 0x64, 0x62, 0x6f, 0x78, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, + 0x2e, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2e, 0x76, + 0x31, 0x2e, 0x53, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x07, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, + 0x73, 0x22, 0x48, 0x0a, 0x06, 0x53, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x66, + 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 
0x01, 0x28, 0x09, 0x52, + 0x09, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, + 0x63, 0x6b, 0x65, 0x74, 0x5f, 0x70, 0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0a, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x50, 0x61, 0x74, 0x68, 0x22, 0x4c, 0x0a, 0x0e, 0x43, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, + 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x66, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x49, 0x64, 0x22, 0x42, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, + 0x6e, 0x65, 0x63, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x32, 0xd2, 0x01, + 0x0a, 0x0d, 0x53, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x12, + 0x4e, 0x0a, 0x04, 0x42, 0x69, 0x6e, 0x64, 0x12, 0x2e, 0x2e, 0x6e, 0x65, 0x72, 0x64, 0x62, 0x6f, + 0x78, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x73, 0x6f, 0x63, 0x6b, 0x65, + 0x74, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x69, 0x6e, 0x64, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, + 0x71, 0x0a, 0x06, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x12, 0x30, 0x2e, 0x6e, 0x65, 0x72, 0x64, + 0x62, 0x6f, 0x78, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x73, 0x6f, 0x63, + 0x6b, 0x65, 
0x74, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x31, 0x2e, 0x6e, 0x65, + 0x72, 0x64, 0x62, 0x6f, 0x78, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x73, + 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2e, 0x76, 0x31, 0x2e, + 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x28, 0x01, + 0x30, 0x01, 0x42, 0x4b, 0x5a, 0x49, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, + 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x64, 0x2f, 0x6e, 0x65, 0x72, 0x64, + 0x62, 0x6f, 0x78, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, + 0x2f, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2f, 0x76, + 0x31, 0x3b, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescOnce sync.Once + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescData = file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDesc +) + +func file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescGZIP() []byte { + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescOnce.Do(func() { + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescData = protoimpl.X.CompressGZIP(file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescData) + }) + return file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDescData +} + +var file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_proto_nerdbox_services_socketforward_v1_socketforward_proto_goTypes = []interface{}{ + (*BindRequest)(nil), // 0: 
nerdbox.services.socketforward.v1.BindRequest + (*Socket)(nil), // 1: nerdbox.services.socketforward.v1.Socket + (*ConnectRequest)(nil), // 2: nerdbox.services.socketforward.v1.ConnectRequest + (*ConnectResult)(nil), // 3: nerdbox.services.socketforward.v1.ConnectResult + (*emptypb.Empty)(nil), // 4: google.protobuf.Empty +} +var file_proto_nerdbox_services_socketforward_v1_socketforward_proto_depIdxs = []int32{ + 1, // 0: nerdbox.services.socketforward.v1.BindRequest.sockets:type_name -> nerdbox.services.socketforward.v1.Socket + 0, // 1: nerdbox.services.socketforward.v1.SocketForward.Bind:input_type -> nerdbox.services.socketforward.v1.BindRequest + 3, // 2: nerdbox.services.socketforward.v1.SocketForward.Accept:input_type -> nerdbox.services.socketforward.v1.ConnectResult + 4, // 3: nerdbox.services.socketforward.v1.SocketForward.Bind:output_type -> google.protobuf.Empty + 2, // 4: nerdbox.services.socketforward.v1.SocketForward.Accept:output_type -> nerdbox.services.socketforward.v1.ConnectRequest + 3, // [3:5] is the sub-list for method output_type + 1, // [1:3] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_proto_nerdbox_services_socketforward_v1_socketforward_proto_init() } +func file_proto_nerdbox_services_socketforward_v1_socketforward_proto_init() { + if File_proto_nerdbox_services_socketforward_v1_socketforward_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*BindRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[1].Exporter = func(v interface{}, i int) 
interface{} { + switch v := v.(*Socket); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ConnectRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ConnectResult); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_proto_nerdbox_services_socketforward_v1_socketforward_proto_goTypes, + DependencyIndexes: file_proto_nerdbox_services_socketforward_v1_socketforward_proto_depIdxs, + MessageInfos: file_proto_nerdbox_services_socketforward_v1_socketforward_proto_msgTypes, + }.Build() + File_proto_nerdbox_services_socketforward_v1_socketforward_proto = out.File + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_rawDesc = nil + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_goTypes = nil + file_proto_nerdbox_services_socketforward_v1_socketforward_proto_depIdxs = nil +} diff --git a/api/services/socketforward/v1/socketforward_ttrpc.pb.go b/api/services/socketforward/v1/socketforward_ttrpc.pb.go new file mode 100644 index 0000000..47af9c6 --- /dev/null +++ b/api/services/socketforward/v1/socketforward_ttrpc.pb.go @@ -0,0 +1,116 @@ +// Code generated by protoc-gen-go-ttrpc. DO NOT EDIT. 
+// source: proto/nerdbox/services/socketforward/v1/socketforward.proto +package socketforward + +import ( + context "context" + ttrpc "github.com/containerd/ttrpc" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +type TTRPCSocketForwardService interface { + Bind(context.Context, *BindRequest) (*emptypb.Empty, error) + Accept(context.Context, TTRPCSocketForward_AcceptServer) error +} + +type TTRPCSocketForward_AcceptServer interface { + Send(*ConnectRequest) error + Recv() (*ConnectResult, error) + ttrpc.StreamServer +} + +type ttrpcsocketforwardAcceptServer struct { + ttrpc.StreamServer +} + +func (x *ttrpcsocketforwardAcceptServer) Send(m *ConnectRequest) error { + return x.StreamServer.SendMsg(m) +} + +func (x *ttrpcsocketforwardAcceptServer) Recv() (*ConnectResult, error) { + m := new(ConnectResult) + if err := x.StreamServer.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func RegisterTTRPCSocketForwardService(srv *ttrpc.Server, svc TTRPCSocketForwardService) { + srv.RegisterService("nerdbox.services.socketforward.v1.SocketForward", &ttrpc.ServiceDesc{ + Methods: map[string]ttrpc.Method{ + "Bind": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + var req BindRequest + if err := unmarshal(&req); err != nil { + return nil, err + } + return svc.Bind(ctx, &req) + }, + }, + Streams: map[string]ttrpc.Stream{ + "Accept": { + Handler: func(ctx context.Context, stream ttrpc.StreamServer) (interface{}, error) { + return nil, svc.Accept(ctx, &ttrpcsocketforwardAcceptServer{stream}) + }, + StreamingClient: true, + StreamingServer: true, + }, + }, + }) +} + +type TTRPCSocketForwardClient interface { + Bind(context.Context, *BindRequest) (*emptypb.Empty, error) + Accept(context.Context) (TTRPCSocketForward_AcceptClient, error) +} + +type ttrpcsocketforwardClient struct { + client *ttrpc.Client +} + +func NewTTRPCSocketForwardClient(client *ttrpc.Client) TTRPCSocketForwardClient { + return 
&ttrpcsocketforwardClient{ + client: client, + } +} + +func (c *ttrpcsocketforwardClient) Bind(ctx context.Context, req *BindRequest) (*emptypb.Empty, error) { + var resp emptypb.Empty + if err := c.client.Call(ctx, "nerdbox.services.socketforward.v1.SocketForward", "Bind", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +func (c *ttrpcsocketforwardClient) Accept(ctx context.Context) (TTRPCSocketForward_AcceptClient, error) { + stream, err := c.client.NewStream(ctx, &ttrpc.StreamDesc{ + StreamingClient: true, + StreamingServer: true, + }, "nerdbox.services.socketforward.v1.SocketForward", "Accept", nil) + if err != nil { + return nil, err + } + x := &ttrpcsocketforwardAcceptClient{stream} + return x, nil +} + +type TTRPCSocketForward_AcceptClient interface { + Send(*ConnectResult) error + Recv() (*ConnectRequest, error) + ttrpc.ClientStream +} + +type ttrpcsocketforwardAcceptClient struct { + ttrpc.ClientStream +} + +func (x *ttrpcsocketforwardAcceptClient) Send(m *ConnectResult) error { + return x.ClientStream.SendMsg(m) +} + +func (x *ttrpcsocketforwardAcceptClient) Recv() (*ConnectRequest, error) { + m := new(ConnectRequest) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} diff --git a/cmd/vminitd/main.go b/cmd/vminitd/main.go index 012ee88..3e76bc9 100644 --- a/cmd/vminitd/main.go +++ b/cmd/vminitd/main.go @@ -50,6 +50,7 @@ import ( _ "github.com/containerd/nerdbox/plugins/services/transfer" _ "github.com/containerd/nerdbox/plugins/vminit/events" + _ "github.com/containerd/nerdbox/plugins/vminit/socketforward" _ "github.com/containerd/nerdbox/plugins/vminit/streaming" _ "github.com/containerd/nerdbox/plugins/vminit/task" ) diff --git a/docs/socket-forwarding.md b/docs/socket-forwarding.md new file mode 100644 index 0000000..58f31e7 --- /dev/null +++ b/docs/socket-forwarding.md @@ -0,0 +1,116 @@ +# UNIX Socket Forwarding + +Nerdbox supports forwarding UNIX domain sockets across the VM boundary, 
+allowing processes inside a container to connect to host-side services. +Sockets are relayed over vsock streams managed by vminitd. + +## Configuration + +Socket forwards are specified as OCI mounts with type `uds`. The shim +intercepts these mounts before the OCI runtime sees them. + +| Mount field | Description | +|---------------|----------------------------------------------------------| +| `type` | Must be `uds` | +| `source` | Path of the UNIX socket on the host (the target service) | +| `destination` | Path of the UNIX socket inside the container | + +A **forward identifier** (`forward_id`) is derived from +`SHA256(container_id + ":" + destination)` and serves as an opaque key +used in host-to-VM communication so the VM never needs to know (or supply) +host-side paths (see [Security model](#security-model)). Including the +container ID ensures that the same socket mount on two different containers +within the same VM gets a distinct identifier. + +The VM listener socket is created at `/run/socketfwd/{forward_id}.sock` +(the `vm_path`). The shim rewrites each `uds` mount to a bind mount from +`vm_path` to `destination` before passing the spec to crun. + +### Example: Forward a host socket into a container + +Make the host's Docker socket available inside the container at +`/var/run/docker.sock`: + +```json +{ + "type": "uds", + "source": "/var/run/docker.sock", + "destination": "/var/run/docker.sock" +} +``` + +## How It Works + +Socket forwarding is coordinated between the shim (host side) and a vminitd +plugin (VM side) using the `SocketForward` ttrpc service. + +### Setup (Bind) + +After the VM is started, but before the container is created, the shim calls +the `Bind` RPC with the list of socket forward entries. For each entry the +VM-side plugin creates a UNIX listener socket at `vm_path` (`/run/socketfwd/{forward_id}.sock`). 
+The shim has already rewritten the `uds` mount to a bind mount from `vm_path` +to `destination`, so once crun processes the spec the socket appears inside +the container at the user-specified `destination`. + +`Bind` returns only after all socket files have been created, so the shim +can safely proceed with container creation. + +### Forwarding flow (Accept) + +1. The `Bind` RPC (see above) has already created a UNIX listener at + `vm_path`, with a bind mount so the socket appears inside the container + at `destination`. +2. When a container process connects, the vminitd plugin generates a unique + `stream_id` and sends a `ConnectRequest` containing the `stream_id` + and `forward_id` through the `Accept` streaming RPC. +3. The shim receives the notification and resolves `forward_id` to a + host-side socket path using its own local configuration. It then dials + that host socket and opens a vsock stream with the given `stream_id`. +4. The vminitd plugin matches the vsock stream to the pending connection and + relays data bidirectionally. + +``` +Container process ──► listener (/run/socketfwd/{forward_id}.sock) + │ ▲ + bind mount to destination + │ + ConnectRequest + (stream_id + forward_id) + via Accept stream + │ + ▼ + shim receives notification + resolve forward_id ──► host_path + │ + dial host_path ──► Host socket + open vsock stream + │ + ▼ + bidirectional relay +``` + +### Caveats + +The container-side `connect` syscall always succeeds immediately (it connects +to the vminitd listener socket). If the shim subsequently fails to dial the +host socket — for example because the host service is not running — the +connection is closed immediately after the container-side `connect` returns. +From the container process's perspective this looks like the peer closed the +connection right away. + +## Security model + +The VM is treated as a security boundary. The shim (host side) never trusts +paths or other security-sensitive values supplied by the VM. 
Concretely: + +- The `host_path` is **never sent to the VM**. The `Bind` RPC sends only + the `forward_id` and `vm_path`. +- When the VM notifies the shim about a new connection (via `Accept`), it + sends a `forward_id` -- not a host path. The shim resolves the identifier + to a host path from its own configuration, which was derived from the OCI + mounts at container creation time. +- A `ConnectRequest` with an unknown `forward_id` is rejected and logged. + +This design ensures that even a compromised VM cannot cause the shim to +connect to arbitrary host sockets. diff --git a/docs/vsock-streaming.md b/docs/vsock-streaming.md index ad14905..a2eda44 100644 --- a/docs/vsock-streaming.md +++ b/docs/vsock-streaming.md @@ -140,12 +140,14 @@ below), the streaming plugin also exposes a `StreamGetter` in a `vsockStream` providing `Send(typeurl.Any)` / `Recv() typeurl.Any` with length-prefixed protobuf framing. -## Coordination pattern +## Coordination patterns The host and guest must agree on a stream ID before either side can use the connection. Since the stream channel itself has no signaling mechanism, the ID is always exchanged over the **TTRPC control channel** (port 1025). +### Host-initiated streams + The host generates a stream ID, opens the stream via `StartStream`, and then sends the ID to the guest through a TTRPC request. The guest calls `streams.Get(id)` to claim the connection. @@ -161,8 +163,28 @@ RPC(streamID) ----TTRPC---> receive RPC <-- bidirectional I/O --> ``` -Currently all stream users follow this pattern -- stream IDs are embedded -in TTRPC request fields (e.g. `stream://` URIs for container stdio). +This is the most common pattern. Container stdio uses it. + +### Guest-initiated streams + +The guest generates a stream ID and notifies the host through a +bidirectional streaming TTRPC RPC. The host then opens the stream using the +ID it received. 
+ +``` +Host Guest +---- ----- + streamID = generate() + send streamID via streaming RPC +receive streamID <--TTRPC--- +conn = StartStream(streamID) + (stream registered by accept loop) + conn = streams.Get(streamID) +<-- bidirectional I/O --> +``` + +Socket forwarding uses this pattern: the guest sends a `ConnectRequest` +through the `Accept` stream, and the host opens the vsock stream in response. ## Existing stream users @@ -212,6 +234,14 @@ Inside the VM, the `vsockStream` wrapper `Send` / `Recv` with the same framing, providing a `streaming.Stream` interface to guest-side plugins. +### Socket forwarding + +UNIX domain sockets can be forwarded across the VM boundary. Each forwarded +connection uses a dedicated vsock stream for the data relay, using the +guest-initiated pattern: the VM notifies the host via the `Accept` stream, +then the host opens a vsock stream to complete the relay. See +[Socket Forwarding](socket-forwarding.md) for full details. + ## Adding a new stream type To add a new feature that relays data through vsock streams: diff --git a/internal/shim/task/service.go b/internal/shim/task/service.go index fc57b6a..eba21ac 100644 --- a/internal/shim/task/service.go +++ b/internal/shim/task/service.go @@ -85,9 +85,34 @@ func NewTaskService(ctx context.Context, sb sandbox.Sandbox, publisher shim.Publ type container struct { ioShutdown func(context.Context) error + // forwarder is the UNIX socket forwarder for this specific container. + forwarder *socketForwarder + execShutdowns map[string]func(context.Context) error } +// shutdown shuts down the container's IO streams, socket forwarding, and all +// exec IO streams. 
+func (c *container) shutdown(ctx context.Context) error { + var errs []error + if c.ioShutdown != nil { + if err := c.ioShutdown(ctx); err != nil { + errs = append(errs, fmt.Errorf("io shutdown: %w", err)) + } + } + if c.forwarder != nil { + if err := c.forwarder.shutdown(); err != nil { + errs = append(errs, fmt.Errorf("socket forward shutdown: %w", err)) + } + } + for execID, ioShutdown := range c.execShutdowns { + if err := ioShutdown(ctx); err != nil { + errs = append(errs, fmt.Errorf("exec %q io shutdown: %w", execID, err)) + } + } + return errors.Join(errs...) +} + // service is the shim implementation of a remote shim over GRPC type service struct { mu sync.Mutex @@ -115,15 +140,8 @@ func (s *service) shutdown(ctx context.Context) error { var errs []error for id, c := range s.containers { - if c.ioShutdown != nil { - if err := c.ioShutdown(ctx); err != nil { - errs = append(errs, fmt.Errorf("container %q io shutdown: %w", id, err)) - } - } - for execID, ioShutdown := range c.execShutdowns { - if err := ioShutdown(ctx); err != nil { - errs = append(errs, fmt.Errorf("container %q exec %q io shutdown: %w", id, execID, err)) - } + if err := c.shutdown(ctx); err != nil { + errs = append(errs, fmt.Errorf("container %q shutdown: %w", id, err)) } } @@ -167,6 +185,7 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * resCfg resourceConfig dumpInfoCfg dumpInfoConfig bm bindMounter + sfpr = socketForwardsProvider{containerID: r.ID} ) // Load the OCI bundle and apply transformers to get the bundle that'll be // set up on the VM side. @@ -176,6 +195,7 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * ctrNetCfg.fromBundle, resCfg.FromBundle, dumpInfoCfg.FromBundle, + sfpr.FromBundle, func(ctx context.Context, b *bundle.Bundle) error { // If there are no VM networks, try falling back to host's resolv.conf (for TSI). 
return addResolvConf(ctx, b, len(nwpr.nws) == 0) @@ -325,12 +345,28 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * return nil, errgrpc.ToGRPC(err) } + // Bind socket forwards on the VM before container creation so + // that crun can bind-mount the listener sockets into the container. + if err := bindSockets(ctx, s.sb, sfpr.entries); err != nil { + if err := ioShutdown(ctx); err != nil { + log.G(ctx).WithError(err).Error("failed to shutdown io after socket forwarding failure") + } + return nil, errgrpc.ToGRPC(err) + } + // setupTime is the total time to setup the VM and everything needed // to proxy the create task request. This measures the overall // overhead of creating the container inside the VM. setupTime := time.Since(presetup) - vr := &taskAPI.CreateTaskRequest{ + preCreate := time.Now() + c := &container{ + ioShutdown: ioShutdown, + execShutdowns: make(map[string]func(context.Context) error), + } + + tc := taskAPI.NewTTRPCTaskClient(vmc) + resp, err := tc.Create(ctx, &taskAPI.CreateTaskRequest{ ID: r.ID, Bundle: br.Bundle, Rootfs: m, @@ -339,25 +375,26 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * Stdout: cio.Stdout, Stderr: cio.Stderr, Options: r.Options, + }) + if err != nil { + log.G(ctx).WithError(err).Error("failed to create task") + if err := c.shutdown(ctx); err != nil { + log.G(ctx).WithError(err).Error("failed to shutdown container after create failure") + } + return nil, errgrpc.ToGRPC(err) } - preCreate := time.Now() - c := &container{ - ioShutdown: ioShutdown, - execShutdowns: make(map[string]func(context.Context) error), - } - tc := taskAPI.NewTTRPCTaskClient(vmc) - resp, err := tc.Create(ctx, vr) + // Start the Accept stream after the container has been created so the + // host can relay forwarded connections from the VM. 
+ fwder, err := startSocketForwarding(context.Background(), s.sb, r.ID, sfpr.entries) if err != nil { - log.G(ctx).WithError(err).Error("failed to create task") - if c.ioShutdown != nil { - // TODO: stop this - if err := c.ioShutdown(ctx); err != nil { - log.G(ctx).WithError(err).Error("failed to shutdown io after create failure") - } + log.G(ctx).WithError(err).Error("failed to start socket forwarding") + if err := c.shutdown(ctx); err != nil { + log.G(ctx).WithError(err).Error("failed to shutdown container after socket forwarding failure") } return nil, errgrpc.ToGRPC(err) } + c.forwarder = fwder log.G(ctx).WithFields(log.Fields{ "t_boot": bootTime, @@ -425,15 +462,8 @@ func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAP delete(c.execShutdowns, r.ExecID) } } else { - if c.ioShutdown != nil { - if err := c.ioShutdown(ctx); err != nil { - log.G(ctx).WithError(err).Error("failed to shutdown io after delete") - } - } - for execID, ioShutdown := range c.execShutdowns { - if err := ioShutdown(ctx); err != nil { - log.G(ctx).WithError(err).WithField("exec", execID).Error("failed to shutdown exec io after delete") - } + if err := c.shutdown(ctx); err != nil { + log.G(ctx).WithError(err).Error("failed to shutdown container after delete") } delete(s.containers, r.ID) } diff --git a/internal/shim/task/socketforward.go b/internal/shim/task/socketforward.go new file mode 100644 index 0000000..b83361d --- /dev/null +++ b/internal/shim/task/socketforward.go @@ -0,0 +1,323 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. +*/ + +package task + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "io" + "net" + "strings" + + "github.com/containerd/log" + "github.com/opencontainers/runtime-spec/specs-go" + + socketforward "github.com/containerd/nerdbox/api/services/socketforward/v1" + "github.com/containerd/nerdbox/internal/shim/sandbox" + "github.com/containerd/nerdbox/internal/shim/task/bundle" +) + +// socketForwardEntry describes a single UDS socket forward. The host-side +// socket is the target service; the container-side path is the listener +// socket created by vminitd so that container processes can connect to it. +type socketForwardEntry struct { + // id is an opaque identifier for this forward, derived from + // SHA256(containerID + ":" + destination). It is passed to vminitd so + // the VM can reference the forward without knowing the host-side path. + // The container ID is included so that the same socket mount on two + // different containers within the same VM gets a distinct identifier. + id string + // hostPath is the path of the UNIX socket on the host. This value is + // only used by the shim (host side) and is never sent to the VM. + hostPath string + // vmPath is the path of the UNIX listener socket in the VM's root + // filesystem, set to /run/socketfwd/{id}.sock. It is the source of the + // OCI bind mount that makes the socket visible inside the container at + // containerPath. vminitd creates the socket here before container + // creation so that crun can complete the bind mount. + vmPath string + // containerPath is the user-specified destination: where the socket + // appears inside the container rootfs. + containerPath string +} + +// socketForwardsProvider parses OCI mounts with type "uds" and generates the +// init args and sandbox options needed for socket forwarding. 
+type socketForwardsProvider struct { + // containerID must be set before calling FromBundle. + containerID string + entries []socketForwardEntry +} + +// FromBundle rewrites UDS mounts (type "uds") in the OCI bundle spec to +// ordinary bind mounts that crun can process. Non-UDS mounts are left +// untouched. +// +// UDS mount fields: +// - Source: host-side UNIX socket path (the target service) +// - Destination: container-side UNIX socket path +// - Options: none +// +// Each UDS mount is replaced in-place with a bind mount whose source is +// vmPath (/run/socketfwd/{id}.sock) and whose destination is the original +// container-side path. vminitd creates the listener socket at vmPath before +// container creation so that crun can complete the bind mount. +func (p *socketForwardsProvider) FromBundle(ctx context.Context, b *bundle.Bundle) error { + for i, m := range b.Spec.Mounts { + if m.Type != "uds" { + continue + } + + entry, err := parseUDSMount(p.containerID, m) + if err != nil { + return fmt.Errorf("parsing uds mount for %q: %w", m.Destination, err) + } + + b.Spec.Mounts[i] = specs.Mount{ + Destination: entry.containerPath, + Type: "bind", + Source: entry.vmPath, + Options: []string{"bind"}, + } + + log.G(ctx).WithFields(log.Fields{ + "id": entry.id, + "source": entry.hostPath, + "destination": entry.containerPath, + "vm_path": entry.vmPath, + }).Debug("socketforward: added bind mount for UDS forward") + + p.entries = append(p.entries, entry) + } + return nil +} + +func parseUDSMount(containerID string, m specs.Mount) (socketForwardEntry, error) { + if m.Source == "" { + return socketForwardEntry{}, fmt.Errorf("source (host path) is required") + } + if m.Destination == "" { + return socketForwardEntry{}, fmt.Errorf("destination (container path) is required") + } + if len(m.Options) > 0 { + return socketForwardEntry{}, fmt.Errorf("unknown option %q", strings.Join(m.Options, ", ")) + } + + hash := sha256.Sum256([]byte(containerID + ":" + m.Destination)) + 
return socketForwardEntry{ + id: fmt.Sprintf("%x", hash), + hostPath: m.Source, + vmPath: fmt.Sprintf("/run/socketfwd/%x.sock", hash), + containerPath: m.Destination, + }, nil +} + +// bindSockets calls the Bind RPC on the VM to set up socket forward +// listener sockets. This must be called before container creation so that +// crun can bind-mount the listener sockets into the container. +func bindSockets(ctx context.Context, sb sandbox.Sandbox, entries []socketForwardEntry) error { + if len(entries) == 0 { + return nil + } + + vmc, err := sb.Client() + if err != nil { + return fmt.Errorf("getting ttrpc client for socket forwarding: %w", err) + } + + sockets := make([]*socketforward.Socket, 0, len(entries)) + for _, e := range entries { + sockets = append(sockets, &socketforward.Socket{ + ForwardID: e.id, + SocketPath: e.vmPath, + }) + } + + sfClient := socketforward.NewTTRPCSocketForwardClient(vmc) + if _, err := sfClient.Bind(ctx, &socketforward.BindRequest{ + Sockets: sockets, + }); err != nil { + return fmt.Errorf("bind RPC: %w", err) + } + return nil +} + +// socketForwarder manages active UDS socket forwarding for a single container. +// It is started after the container is created and runs for the container +// lifetime. +type socketForwarder struct { + sb sandbox.Sandbox + containerID string + // entries maps a forward identifier to its entry. + entries map[string]socketForwardEntry + // closeCh is closed when the socket forwarder is shutting down. + closeCh chan struct{} +} + +// startSocketForwarding creates a socketForwarder and starts the Accept loop +// for forwarded connection notifications from the VM. 
+func startSocketForwarding(ctx context.Context, sb sandbox.Sandbox, containerID string, entries []socketForwardEntry) (*socketForwarder, error) { + fwd := &socketForwarder{ + sb: sb, + containerID: containerID, + entries: make(map[string]socketForwardEntry, len(entries)), + closeCh: make(chan struct{}), + } + + vmc, err := sb.Client() + if err != nil { + return nil, fmt.Errorf("getting ttrpc client for socket forwarding: %w", err) + } + sfClient := socketforward.NewTTRPCSocketForwardClient(vmc) + + for _, e := range entries { + fwd.entries[e.id] = e + } + + go fwd.acceptLoop(ctx, sb, sfClient) + + return fwd, nil +} + +// acceptLoop connects to the Accept stream and dispatches forwarded +// connection notifications from the VM. +func (fwd *socketForwarder) acceptLoop(ctx context.Context, sb sandbox.Sandbox, sfClient socketforward.TTRPCSocketForwardClient) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + stream, err := sfClient.Accept(ctx) + if err != nil { + log.G(ctx).WithError(err).Error("socketforward: Accept RPC failed") + return + } + + results := make(chan *socketforward.ConnectResult, 64) + go func() { + for { + select { + case r, ok := <-results: + if !ok { + return + } + if err := stream.Send(r); err != nil { + log.G(ctx).WithError(err).Error("socketforward: sending ConnectResult") + } + case <-ctx.Done(): + return + } + } + }() + + go func() { + for { + req, err := stream.Recv() + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, context.Canceled) { + log.G(ctx).WithError(err).Error("socketforward: Accept stream error") + } + return + } + select { + case results <- fwd.handleConnection(ctx, sb, req): + case <-fwd.closeCh: + return + } + + } + }() + + <-fwd.closeCh +} + +func (fwd *socketForwarder) handleConnection(ctx context.Context, sb sandbox.Sandbox, req *socketforward.ConnectRequest) *socketforward.ConnectResult { + // Resolve the host path from local state. 
The VM only supplies the + // forward identifier; the host never trusts a path from the VM. + entry, ok := fwd.entries[req.ForwardID] + if !ok { + log.G(ctx).WithField("forward_id", req.ForwardID).Error("socketforward: unknown forward ID from VM") + return &socketforward.ConnectResult{ + StreamID: req.StreamID, + Error: fmt.Sprintf("unknown forward ID: %s", req.ForwardID), + } + } + + log.G(ctx).WithFields(log.Fields{ + "stream_id": req.StreamID, + "forward_id": req.ForwardID, + "host_path": entry.hostPath, + }).Debug("socketforward: new forwarded connection") + + // Dial the host-side target socket. + hostConn, err := net.Dial("unix", entry.hostPath) + if err != nil { + log.G(ctx).WithError(err).WithField("host_path", entry.hostPath).Error("socketforward: failed to dial host socket") + return &socketforward.ConnectResult{ + StreamID: req.StreamID, + Error: "failed to dial host socket", + } + } + + // Open a vsock stream so the VM side can associate it with the pending + // connection and start relaying. + vsockConn, err := sb.StartStream(ctx, req.StreamID) + if err != nil { + log.G(ctx).WithError(err).WithField("stream_id", req.StreamID).Error("socketforward: failed to open vsock stream") + hostConn.Close() + return &socketforward.ConnectResult{StreamID: req.StreamID, Error: err.Error()} + } + + go relay(hostConn, vsockConn, fwd.closeCh) + + return &socketforward.ConnectResult{StreamID: req.StreamID} +} + +// relay copies data bidirectionally between two connections until one side +// closes or errors. Both connections are closed when done. 
+func relay(a, b io.ReadWriteCloser, closeCh <-chan struct{}) { + done := make(chan struct{}, 2) + cp := func(dst io.Writer, src io.Reader) { + io.Copy(dst, src) + done <- struct{}{} + } + go cp(a, b) + go cp(b, a) + go func() { + <-closeCh + a.Close() + b.Close() + }() + + <-done // one side finished + a.Close() + b.Close() + <-done // wait for the other +} + +func (fwd *socketForwarder) shutdown() error { + if fwd == nil { + return nil + } + select { + case <-fwd.closeCh: + // already closed + default: + close(fwd.closeCh) + } + return nil +} diff --git a/internal/shim/task/socketforward_test.go b/internal/shim/task/socketforward_test.go new file mode 100644 index 0000000..7e66090 --- /dev/null +++ b/internal/shim/task/socketforward_test.go @@ -0,0 +1,136 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package task + +import ( + "context" + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/containerd/nerdbox/internal/shim/task/bundle" +) + +func TestSocketForwardsProviderFromBundle(t *testing.T) { + ctx := context.Background() + + testcases := []struct { + name string + cid string + mounts []specs.Mount + wantErr string + wantMounts []specs.Mount + wantEntries []socketForwardEntry + }{ + { + name: "empty source", + cid: "c1", + mounts: []specs.Mount{ + {Type: "uds", Source: "", Destination: "/run/docker.sock"}, + }, + wantErr: "source (host path) is required", + }, + { + name: "empty destination", + cid: "c1", + mounts: []specs.Mount{ + {Type: "uds", Source: "/var/run/docker.sock", Destination: ""}, + }, + wantErr: "destination (container path) is required", + }, + { + name: "non-empty options", + cid: "c1", + mounts: []specs.Mount{ + { + Type: "uds", + Source: "/var/run/docker.sock", + Destination: "/run/docker.sock", + Options: []string{"first", "second"}, + }, + }, + wantErr: `unknown option "first, second"`, + }, + { + name: "uds mount rewritten to bind", + cid: "container-abc", + mounts: []specs.Mount{ + { + Type: "uds", + Source: "/var/run/foo.sock", + Destination: "/var/run/bar.sock", + }, + { + Type: "uds", + Source: "/var/run/abc.sock", + Destination: "/var/run/def.sock", + }, + }, + wantMounts: []specs.Mount{ + { + Type: "bind", + Source: "/run/socketfwd/954a6df32e91bb55e6fcd9df9f90728e56b4f87aa92b22fe3b63df33f18a3188.sock", + Destination: "/var/run/bar.sock", + Options: []string{"bind"}, + }, + { + Type: "bind", + Source: "/run/socketfwd/79a1c3c374d3573c0e7f1e7ca567d3531b3aefd3c202b8811df64e77fdaeab0c.sock", + Destination: "/var/run/def.sock", + Options: []string{"bind"}, + }, + }, + wantEntries: []socketForwardEntry{ + { + id: "954a6df32e91bb55e6fcd9df9f90728e56b4f87aa92b22fe3b63df33f18a3188", + hostPath: "/var/run/foo.sock", + vmPath: 
"/run/socketfwd/954a6df32e91bb55e6fcd9df9f90728e56b4f87aa92b22fe3b63df33f18a3188.sock", + containerPath: "/var/run/bar.sock", + }, + { + id: "79a1c3c374d3573c0e7f1e7ca567d3531b3aefd3c202b8811df64e77fdaeab0c", + hostPath: "/var/run/abc.sock", + vmPath: "/run/socketfwd/79a1c3c374d3573c0e7f1e7ca567d3531b3aefd3c202b8811df64e77fdaeab0c.sock", + containerPath: "/var/run/def.sock", + }, + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + mounts := append([]specs.Mount(nil), tc.mounts...) + b := &bundle.Bundle{Spec: specs.Spec{Mounts: mounts}} + p := &socketForwardsProvider{containerID: tc.cid} + + err := p.FromBundle(ctx, b) + + if tc.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.wantMounts, b.Spec.Mounts) + assert.Equal(t, tc.wantEntries, p.entries) + }) + } +} diff --git a/internal/vminit/socketforward/socketforward.go b/internal/vminit/socketforward/socketforward.go new file mode 100644 index 0000000..c6b6477 --- /dev/null +++ b/internal/vminit/socketforward/socketforward.go @@ -0,0 +1,297 @@ +//go:build linux + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package socketforward + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "os" + "path/filepath" + "sync" + "time" + + "github.com/containerd/log" + "github.com/containerd/ttrpc" + "google.golang.org/protobuf/types/known/emptypb" + + socketforward "github.com/containerd/nerdbox/api/services/socketforward/v1" + "github.com/containerd/nerdbox/internal/vminit/stream" +) + +// GenerateStreamID creates a pseudo-random stream ID with the given prefix. +// The format is "{prefix}-{nanosecond}-{random}" to minimize collisions. +func GenerateStreamID(prefix string) (string, error) { + var b [4]byte + if _, err := rand.Read(b[:]); err != nil { + return "", fmt.Errorf("generating stream ID: %w", err) + } + return fmt.Sprintf("%s-%d-%s", prefix, time.Now().UnixNano(), base64.RawURLEncoding.EncodeToString(b[:])), nil +} + +// Service implements the VM-side socket forwarding ttrpc service. +// The VM creates UNIX listener sockets inside the VM and notifies the host +// when a container process connects, so the host can open a vsock stream +// and relay data to the target host-side socket. +type Service struct { + streams stream.Manager + + mu sync.Mutex + listeners []net.Listener + // pending maps stream_id to a channel that receives the host's dial + // result (nil on success, non-nil on failure) for each in-flight + // connection. Entries are added by handleConnection before sending the + // ConnectRequest and removed when the ConnectResult arrives or the + // context is cancelled. + pending map[string]chan error + + // notify delivers ConnectRequest messages to the Accept stream so the + // host shim can set up the vsock relay for each new connection. + notify chan *socketforward.ConnectRequest +} + +// NewService creates a new socket forwarding service. 
+func NewService(streams stream.Manager) *Service { + return &Service{ + streams: streams, + pending: make(map[string]chan error), + notify: make(chan *socketforward.ConnectRequest, 64), + } +} + +// RegisterTTRPC registers the socket forwarding service with the ttrpc server. +func (s *Service) RegisterTTRPC(server *ttrpc.Server) error { + socketforward.RegisterTTRPCSocketForwardService(server, s) + return nil +} + +// Bind sets up socket forward entries on the VM side. For each entry it +// creates a UNIX listener socket at the given socket_path. This method +// returns only after all socket files have been created. +func (s *Service) Bind(ctx context.Context, req *socketforward.BindRequest) (*emptypb.Empty, error) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, sock := range req.Sockets { + if err := s.bind(ctx, sock.ForwardID, sock.SocketPath); err != nil { + return nil, fmt.Errorf("binding socket forward listener at %s: %w", sock.SocketPath, err) + } + } + + return &emptypb.Empty{}, nil +} + +// Accept is a bidirectional streaming RPC used to coordinate forwarded +// connections. The VM sends a ConnectRequest when a container process +// connects to a forwarded socket; the host resolves the forward_id, +// dials the target host socket, opens a vsock stream, and sends back a +// ConnectResult reporting success or failure. On failure the VM closes +// the pending container connection immediately. +func (s *Service) Accept(ctx context.Context, srv socketforward.TTRPCSocketForward_AcceptServer) error { + log.G(ctx).Debug("socketforward: Accept started") + + // Receive ConnectResult messages from the host and dispatch them to the + // goroutines waiting in handleConnection. 
+ go func() { + for { + result, err := srv.Recv() + if err != nil { + return + } + s.mu.Lock() + ch, ok := s.pending[result.StreamID] + if ok { + delete(s.pending, result.StreamID) + } + s.mu.Unlock() + if !ok { + log.G(ctx).WithField("stream_id", result.StreamID).Warn("socketforward: ConnectResult for unknown stream, ignoring") + continue + } + var dialErr error + if result.Error != "" { + dialErr = errors.New(result.Error) + } + ch <- dialErr + } + }() + + for { + select { + case req := <-s.notify: + if err := srv.Send(req); err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// bind creates a UNIX listener at socketPath. When a connection arrives, +// it generates a stream ID, sends a ConnectRequest on the notify channel +// (for the Accept stream), then waits for the host to report the dial +// result via a ConnectResult before proceeding with the vsock relay. +func (s *Service) bind(ctx context.Context, forwardID, socketPath string) error { + if err := os.MkdirAll(filepath.Dir(socketPath), 0755); err != nil { + return fmt.Errorf("creating parent directory for %s: %w", socketPath, err) + } + + // Remove any stale socket file. 
+ if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("removing stale socket file %s: %w", socketPath, err) + } + + l, err := net.Listen("unix", socketPath) + if err != nil { + return fmt.Errorf("listening on %s: %w", socketPath, err) + } + + s.listeners = append(s.listeners, l) + + log.L.WithFields(log.Fields{ + "forward_id": forwardID, + "socket_path": socketPath, + }).Info("socketforward: listening for forwarded connections") + + go s.acceptLoop(context.Background(), l, forwardID) + return nil +} + +func (s *Service) acceptLoop(ctx context.Context, l net.Listener, forwardID string) { + for { + conn, err := l.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + log.L.WithError(err).Error("socketforward: accept error") + } + return + } + go s.handleConnection(ctx, conn, forwardID) + } +} + +func (s *Service) handleConnection(ctx context.Context, udsConn net.Conn, forwardID string) { + streamID, err := GenerateStreamID("socketfwd") + if err != nil { + log.L.WithError(err).Error("socketforward: generating stream ID") + udsConn.Close() + return + } + + log.L.WithFields(log.Fields{ + "stream_id": streamID, + "forward_id": forwardID, + }).Debug("socketforward: new forwarded connection") + + resultCh := make(chan error, 1) + s.mu.Lock() + s.pending[streamID] = resultCh + s.mu.Unlock() + + req := &socketforward.ConnectRequest{ + StreamID: streamID, + ForwardID: forwardID, + } + select { + case s.notify <- req: + default: + s.mu.Lock() + delete(s.pending, streamID) + s.mu.Unlock() + log.G(ctx).WithFields(log.Fields{ + "stream_id": streamID, + "forward_id": forwardID, + }).Error("socketforward: notify channel full, dropping connection") + udsConn.Close() + return + } + + // Wait for the host to report whether it could dial the target socket. 
+ var dialErr error + select { + case dialErr = <-resultCh: + case <-ctx.Done(): + s.mu.Lock() + delete(s.pending, streamID) + s.mu.Unlock() + udsConn.Close() + return + } + + if dialErr != nil { + log.L.WithError(dialErr).WithFields(log.Fields{ + "stream_id": streamID, + "forward_id": forwardID, + }).Error("socketforward: host failed to dial target socket") + udsConn.Close() + return + } + + // The host opens the vsock stream before sending the success + // ConnectResult, so the stream is already registered by the time we + // reach here. + vsockConn, err := s.streams.Get(streamID) + if err != nil { + log.L.WithError(err).WithField("stream_id", streamID).Error("socketforward: vsock stream not found after successful ConnectResult") + udsConn.Close() + return + } + + relay(ctx, udsConn, vsockConn) +} + +// relay copies data bidirectionally between two connections until one side +// closes or errors. Both connections are closed when the relay finishes. +func relay(ctx context.Context, a, b io.ReadWriteCloser) { + done := make(chan struct{}, 2) + cp := func(dst io.Writer, src io.Reader) { + io.Copy(dst, src) + done <- struct{}{} + } + go cp(a, b) + go cp(b, a) + + // Wait for one direction to finish, then tear down both. + select { + case <-done: + case <-ctx.Done(): + } + a.Close() + b.Close() + <-done // wait for the second goroutine +} + +// Shutdown closes all active listeners. +func (s *Service) Shutdown(_ context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + var errs []error + for _, l := range s.listeners { + if err := l.Close(); err != nil { + errs = append(errs, err) + } + } + s.listeners = nil + return errors.Join(errs...) +} diff --git a/plugins/vminit/socketforward/plugin.go b/plugins/vminit/socketforward/plugin.go new file mode 100644 index 0000000..5767f8a --- /dev/null +++ b/plugins/vminit/socketforward/plugin.go @@ -0,0 +1,56 @@ +//go:build linux + +/* + Copyright The containerd Authors. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package socketforward + +import ( + "github.com/containerd/containerd/v2/pkg/shutdown" + cplugins "github.com/containerd/containerd/v2/plugins" + "github.com/containerd/plugin" + "github.com/containerd/plugin/registry" + + sf "github.com/containerd/nerdbox/internal/vminit/socketforward" + "github.com/containerd/nerdbox/internal/vminit/stream" + "github.com/containerd/nerdbox/plugins" +) + +func init() { + registry.Register(&plugin.Registration{ + Type: cplugins.TTRPCPlugin, + ID: "socketforward", + Requires: []plugin.Type{ + cplugins.InternalPlugin, + plugins.StreamingPlugin, + }, + InitFn: func(ic *plugin.InitContext) (interface{}, error) { + ss, err := ic.GetByID(cplugins.InternalPlugin, "shutdown") + if err != nil { + return nil, err + } + sm, err := ic.GetByID(plugins.StreamingPlugin, "vsock") + if err != nil { + return nil, err + } + + svc := sf.NewService(sm.(stream.Manager)) + ss.(shutdown.Service).RegisterCallback(svc.Shutdown) + + return svc, nil + }, + }) +}