From a4aec98577fd18fe4ad63e32b0232009e7794f43 Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Wed, 4 Mar 2026 20:14:03 -0600 Subject: [PATCH 1/3] feat: bind to 127.0.0.1 by default instead of 0.0.0.0 Port bindings now default to localhost-only, preventing prediction endpoints from being exposed to the entire network during development. - Add HostIP field to command.Port struct (defaults to 127.0.0.1) - Add --host flag to cog serve (default 127.0.0.1, use 0.0.0.0 to expose) - Support host:port syntax in cog run -p (e.g. -p 0.0.0.0:8888) - Bind cog predict/train to 127.0.0.1 - Update GetHostPortForContainer to match configured host IP --- pkg/cli/run.go | 18 ++++++-- pkg/cli/serve.go | 14 +++++- pkg/docker/command/command.go | 1 + pkg/docker/docker.go | 6 ++- pkg/docker/run.go | 11 +++-- pkg/docker/run_test.go | 81 +++++++++++++++++++++++++++++------ pkg/predict/predictor.go | 5 ++- 7 files changed, 111 insertions(+), 25 deletions(-) diff --git a/pkg/cli/run.go b/pkg/cli/run.go index 5a32d9aa94..0b3ab2f122 100644 --- a/pkg/cli/run.go +++ b/pkg/cli/run.go @@ -1,6 +1,7 @@ package cli import ( + "fmt" "os" "strconv" "strings" @@ -58,7 +59,7 @@ exploring the environment your model will run in.`, // Flags after first argument are considered args and passed to command // This is called `publish` for consistency with `docker run` - cmd.Flags().StringArrayVarP(&runPorts, "publish", "p", []string{}, "Publish a container's port to the host, e.g. -p 8000") + cmd.Flags().StringArrayVarP(&runPorts, "publish", "p", []string{}, "Publish a container's port to the host, e.g. -p 8000 or -p 0.0.0.0:8000") cmd.Flags().StringArrayVarP(&envFlags, "env", "e", []string{}, "Environment variables, in the form name=value") flags.SetInterspersed(false) @@ -113,12 +114,21 @@ func run(cmd *cobra.Command, args []string) error { } for _, portString := range runPorts { - port, err := strconv.Atoi(portString) + hostIP := "127.0.0.1" + portStr := portString + + // Support host:port syntax (e.g. "0.0.0.0:8000") + if idx := strings.LastIndex(portString, ":"); idx != -1 { + hostIP = portString[:idx] + portStr = portString[idx+1:] + } + + port, err := strconv.Atoi(portStr) if err != nil { - return err + return fmt.Errorf("invalid port %q: %w", portString, err) } - runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: port}) + runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: port, HostIP: hostIP}) } console.Info("") diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index e497469c56..7922fdad10 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -15,6 +15,7 @@ import ( ) var ( + host = "127.0.0.1" port = 8393 uploadURL = "" ) @@ -33,6 +34,9 @@ and outputs as a REST API. Compatible with the Cog HTTP protocol.`, # Start on a custom port cog serve -p 5000 + # Listen on all interfaces (e.g. to expose to the network) + cog serve --host 0.0.0.0 + # Test the server curl http://localhost:8393/predictions \ -X POST \ @@ -49,6 +53,7 @@ and outputs as a REST API. Compatible with the Cog HTTP protocol.`, addGpusFlag(cmd) addConfigFlag(cmd) + cmd.Flags().StringVar(&host, "host", host, "Host to bind on (use 0.0.0.0 for all interfaces)") cmd.Flags().IntVarP(&port, "port", "p", port, "Port on which to listen") cmd.Flags().StringVar(&uploadURL, "upload-url", "", "Upload URL for file outputs (e.g. https://example.com/upload/)") @@ -130,12 +135,17 @@ func cmdServe(cmd *cobra.Command, arg []string) error { runOptions.ExtraHosts = []string{"host.docker.internal:host-gateway"} } - runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000}) + runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000, HostIP: host}) + + displayHost := host + if displayHost == "0.0.0.0" { + displayHost = "localhost" + } console.Info("") console.Infof("Running %[1]s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(args, " "))) console.Info("") - console.Infof("Serving at %s", console.Bold(fmt.Sprintf("http://127.0.0.1:%v", port))) + console.Infof("Serving at %s", console.Bold(fmt.Sprintf("http://%s:%v", displayHost, port))) console.Info("") err = docker.Run(ctx, dockerClient, runOptions) diff --git a/pkg/docker/command/command.go b/pkg/docker/command/command.go index 1d3539a042..71d525354a 100644 --- a/pkg/docker/command/command.go +++ b/pkg/docker/command/command.go @@ -69,6 +69,7 @@ type RunOptions struct { type Port struct { HostPort int ContainerPort int + HostIP string // Host IP to bind to. Defaults to "127.0.0.1" if empty. } type Volume struct { diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index cfc9a578b6..77cce56b39 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -434,10 +434,14 @@ func (c *apiClient) containerRun(ctx context.Context, options command.RunOptions if len(options.Ports) > 0 { hostCfg.PortBindings = make(nat.PortMap) for _, port := range options.Ports { + hostIP := port.HostIP + if hostIP == "" { + hostIP = "127.0.0.1" + } containerPort := nat.Port(fmt.Sprintf("%d/tcp", port.ContainerPort)) hostCfg.PortBindings[containerPort] = []nat.PortBinding{ { - HostIP: "", // use empty string to bind to all interfaces + HostIP: hostIP, HostPort: strconv.Itoa(port.HostPort), }, } diff --git a/pkg/docker/run.go b/pkg/docker/run.go index 44cc2a1373..07cd718ff6 100644 --- a/pkg/docker/run.go +++ b/pkg/docker/run.go @@ -34,9 +34,13 @@ func RunDaemon(ctx context.Context, dockerClient command.Command, options comman return dockerClient.ContainerStart(ctx, options) } -func GetHostPortForContainer(ctx context.Context, dockerCommand command.Command, containerID string, containerPort int) (int, error) { +func GetHostPortForContainer(ctx context.Context, dockerCommand command.Command, containerID string, containerPort int, hostIP string) (int, error) { console.Debugf("=== DockerCommand.GetPort %s/%d", containerID, containerPort) + if hostIP == "" { + hostIP = "127.0.0.1" + } + inspect, err := dockerCommand.ContainerInspect(ctx, containerID) if err != nil { return 0, fmt.Errorf("failed to inspect container %q: %w", containerID, err) @@ -56,8 +60,7 @@ func GetHostPortForContainer(ctx context.Context, dockerCommand command.Command, } for _, portBinding := range inspect.NetworkSettings.Ports[targetPort] { - // TODO[md]: this should not be hardcoded since docker may be bound to a different address - if portBinding.HostIP != "0.0.0.0" { + if portBinding.HostIP != hostIP { continue } hostPort, err := nat.ParsePort(portBinding.HostPort) @@ -67,5 +70,5 @@ func GetHostPortForContainer(ctx context.Context, dockerCommand command.Command, return hostPort, nil } - return 0, fmt.Errorf("container %s does not have a port bound to 0.0.0.0", containerID) + return 0, fmt.Errorf("container %s does not have a port bound to %s", containerID, hostIP) } diff --git a/pkg/docker/run_test.go b/pkg/docker/run_test.go index ca829fee6d..4f95dee33f 100644 --- a/pkg/docker/run_test.go +++ b/pkg/docker/run_test.go @@ -13,6 +13,63 @@ import ( func TestGetHostPortForContainer(t *testing.T) { t.Run("WithExposedPort", func(t *testing.T) { + testClient := dockertest.NewMockCommand2(t) + testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ + ContainerJSONBase: &container.ContainerJSONBase{ + State: &container.State{ + Status: "running", + Running: true, + }, + }, + NetworkSettings: &container.NetworkSettings{ + NetworkSettingsBase: container.NetworkSettingsBase{ + Ports: nat.PortMap{ + nat.Port("5678/tcp"): []nat.PortBinding{ + { + HostIP: "127.0.0.1", + HostPort: "12345", + }, + }, + }, + }, + }, + }, nil) + + hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") + require.NoError(t, err) + require.Equal(t, 12345, hostPort) + }) + + t.Run("WithExposedPortDefaultHostIP", func(t *testing.T) { + testClient := dockertest.NewMockCommand2(t) + testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ + ContainerJSONBase: &container.ContainerJSONBase{ + State: &container.State{ + Status: "running", + Running: true, + }, + }, + NetworkSettings: &container.NetworkSettings{ + NetworkSettingsBase: container.NetworkSettingsBase{ + Ports: nat.PortMap{ + nat.Port("5678/tcp"): []nat.PortBinding{ + { + HostIP: "127.0.0.1", + HostPort: "12345", + }, + }, + }, + }, + }, + }, nil) + + // Empty hostIP should default to 127.0.0.1 + hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "") + require.NoError(t, err) + require.Equal(t, 12345, hostPort) + }) + + t.Run("WithExposedPortAllInterfaces", func(t *testing.T) { testClient := dockertest.NewMockCommand2(t) testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ ContainerJSONBase: &container.ContainerJSONBase{ @@ -35,7 +92,7 @@ func TestGetHostPortForContainer(t *testing.T) { }, }, nil) - hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678) + hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "0.0.0.0") require.NoError(t, err) require.Equal(t, 12345, hostPort) }) @@ -54,11 +111,11 @@ func TestGetHostPortForContainer(t *testing.T) { Ports: nat.PortMap{ nat.Port("5678/tcp"): []nat.PortBinding{ { - HostIP: "0.0.0.0", + HostIP: "127.0.0.1", HostPort: "12345", }, { - HostIP: "0.0.0.0", + HostIP: "127.0.0.1", HostPort: "54321", }, }, @@ -67,7 +124,7 @@ func TestGetHostPortForContainer(t *testing.T) { }, }, nil) - hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678) + hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") require.NoError(t, err) require.Equal(t, 12345, hostPort) }) @@ -86,7 +143,7 @@ func TestGetHostPortForContainer(t *testing.T) { Ports: nat.PortMap{ nat.Port("5678/tcp"): []nat.PortBinding{ { - HostIP: "127.0.0.1", + HostIP: "0.0.0.0", HostPort: "12345", }, }, @@ -95,8 +152,8 @@ func TestGetHostPortForContainer(t *testing.T) { }, }, nil) - _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678) - require.ErrorContains(t, err, "does not have a port bound to 0.0.0.0") + _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") + require.ErrorContains(t, err, "does not have a port bound to 127.0.0.1") }) t.Run("WithDifferentPortExposed", func(t *testing.T) { @@ -113,7 +170,7 @@ func TestGetHostPortForContainer(t *testing.T) { Ports: nat.PortMap{ nat.Port("1234/tcp"): []nat.PortBinding{ { - HostIP: "0.0.0.0", + HostIP: "127.0.0.1", HostPort: "12345", }, }, @@ -122,8 +179,8 @@ func TestGetHostPortForContainer(t *testing.T) { }, }, nil) - _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678) - require.ErrorContains(t, err, "does not have a port bound to 0.0.0.0") + _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") + require.ErrorContains(t, err, "does not have a port bound to 127.0.0.1") }) t.Run("WithNoExposedPort", func(t *testing.T) { @@ -137,7 +194,7 @@ func TestGetHostPortForContainer(t *testing.T) { }, }, nil) - _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678) + _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") require.ErrorContains(t, err, "does not have expected network configuration") }) @@ -152,7 +209,7 @@ func TestGetHostPortForContainer(t *testing.T) { }, }, nil) - _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678) + _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") require.ErrorContains(t, err, "is not running") }) } diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index 7f72d751e5..8c52441b63 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -75,15 +75,16 @@ func NewPredictor(ctx context.Context, runOptions command.RunOptions, isTrain bo func (p *Predictor) Start(ctx context.Context, logsWriter io.Writer, timeout time.Duration) error { var err error containerPort := 5000 + hostIP := "127.0.0.1" - p.runOptions.Ports = append(p.runOptions.Ports, command.Port{HostPort: 0, ContainerPort: containerPort}) + p.runOptions.Ports = append(p.runOptions.Ports, command.Port{HostPort: 0, ContainerPort: containerPort, HostIP: hostIP}) p.containerID, err = docker.RunDaemon(ctx, p.dockerClient, p.runOptions, logsWriter) if err != nil { return fmt.Errorf("Failed to start container: %w", err) } - p.port, err = docker.GetHostPortForContainer(ctx, p.dockerClient, p.containerID, containerPort) + p.port, err = docker.GetHostPortForContainer(ctx, p.dockerClient, p.containerID, containerPort, hostIP) if err != nil { return fmt.Errorf("Failed to determine container port: %w", err) } From f5ff4c2723c6ef003641f786ee4e2aa8ae40d448 Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Wed, 4 Mar 2026 20:21:17 -0600 Subject: [PATCH 2/3] docs: regenerate CLI and LLM docs for new --host flag --- docs/cli.md | 6 +++++- docs/llms.txt | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index 5467661b0a..ca7fdcac45 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -234,7 +234,7 @@ cog run [arg...] [flags] --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for run --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") - -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 + -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 or -p 0.0.0.0:8000 --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` @@ -258,6 +258,9 @@ cog serve [flags] # Start on a custom port cog serve -p 5000 + # Listen on all interfaces (e.g. to expose to the network) + cog serve --host 0.0.0.0 + # Test the server curl http://localhost:8393/predictions \ -X POST \ @@ -271,6 +274,7 @@ cog serve [flags] -f, --file string The name of the config file. (default "cog.yaml") --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for serve + --host string Host to bind on (use 0.0.0.0 for all interfaces) (default "127.0.0.1") -p, --port int Port on which to listen (default 8393) --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --upload-url string Upload URL for file outputs (e.g. https://example.com/upload/) diff --git a/docs/llms.txt b/docs/llms.txt index 8655392cd7..18b25f59f9 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -432,7 +432,7 @@ cog run [arg...] [flags] --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for run --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") - -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 + -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 or -p 0.0.0.0:8000 --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` @@ -456,6 +456,9 @@ cog serve [flags] # Start on a custom port cog serve -p 5000 + # Listen on all interfaces (e.g. to expose to the network) + cog serve --host 0.0.0.0 + # Test the server curl http://localhost:8393/predictions \ -X POST \ @@ -469,6 +472,7 @@ cog serve [flags] -f, --file string The name of the config file. (default "cog.yaml") --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for serve + --host string Host to bind on (use 0.0.0.0 for all interfaces) (default "127.0.0.1") -p, --port int Port on which to listen (default 8393) --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --upload-url string Upload URL for file outputs (e.g. https://example.com/upload/) From f0f1a08ea79b5b07731bd026c94b4100b6ae1a82 Mon Sep 17 00:00:00 2001 From: asahoo Date: Fri, 26 Jun 2026 14:47:10 -0500 Subject: [PATCH 3/3] fix: address review comments on localhost-by-default binding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Restore host:port parsing for `cog exec -p` lost in the cog run→exec rename; support plain ports, IPv4, bare and bracketed IPv6, with port range and empty-host validation - Extract shared command.DefaultHostIP constant (was duplicated) - Narrow GetHostPortForContainer fallback to single wildcard/empty bindings - Warn when serving via a remote Docker daemon; show a navigable localhost URL alongside 0.0.0.0; bracket IPv6 display hosts - Rename generic serve `host` var to `serveHost` - Add unit tests for port binding helpers, host parsing, remote-host detection, and serve URL formatting; make run_test table-driven - Regenerate docs/cli.md and docs/llms.txt --- docs/cli.md | 2 +- docs/llms.txt | 2 +- pkg/cli/exec.go | 82 ++++++++- pkg/cli/exec_test.go | 114 +++++++++++++ pkg/cli/serve.go | 39 ++++- pkg/cli/serve_test.go | 49 ++++++ pkg/docker/command/command.go | 7 +- pkg/docker/docker.go | 52 +++--- pkg/docker/host.go | 17 ++ pkg/docker/host_test.go | 48 ++++++ pkg/docker/ports_test.go | 118 +++++++++++++ pkg/docker/run.go | 16 +- pkg/docker/run_test.go | 312 ++++++++++++++-------------------- pkg/predict/predictor.go | 2 +- 14 files changed, 634 insertions(+), 226 deletions(-) create mode 100644 pkg/cli/exec_test.go create mode 100644 pkg/cli/serve_test.go create mode 100644 pkg/docker/host_test.go create mode 100644 pkg/docker/ports_test.go diff --git a/docs/cli.md b/docs/cli.md index 31852d6cdf..eacd27e5d3 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -125,7 +125,7 @@ cog exec [arg...] [flags] --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for exec --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") - -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 + -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 or -p 0.0.0.0:8000 --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` diff --git a/docs/llms.txt b/docs/llms.txt index f4ad21f7e5..820c38320f 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -371,7 +371,7 @@ cog exec [arg...] [flags] --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for exec --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") - -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 + -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 or -p 0.0.0.0:8000 --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` diff --git a/pkg/cli/exec.go b/pkg/cli/exec.go index 4d14d5b320..42d5016b59 100644 --- a/pkg/cli/exec.go +++ b/pkg/cli/exec.go @@ -2,6 +2,7 @@ package cli import ( "errors" + "fmt" "os" "strconv" "strings" @@ -59,7 +60,7 @@ exploring the environment your model will run in.`, // Flags after first argument are considered args and passed to command // This is called `publish` for consistency with `docker run` - cmd.Flags().StringArrayVarP(&execPorts, "publish", "p", []string{}, "Publish a container's port to the host, e.g. -p 8000") + cmd.Flags().StringArrayVarP(&execPorts, "publish", "p", []string{}, "Publish a container's port to the host, e.g. -p 8000 or -p 0.0.0.0:8000") cmd.Flags().StringArrayVarP(&envFlags, "env", "e", []string{}, "Environment variables, in the form name=value") flags.SetInterspersed(false) @@ -67,6 +68,74 @@ exploring the environment your model will run in.`, return cmd } +// parsePublishFlags parses the values passed to `cog exec -p`. Each value may +// be either a port number ("8000") or a host:port pair ("0.0.0.0:8000" or +// "[::1]:8000"). When no host is given, the port is bound to +// command.DefaultHostIP. +func parsePublishFlags(values []string) ([]command.Port, error) { + ports := make([]command.Port, 0, len(values)) + for _, portString := range values { + hostIP, portStr, err := splitPublishFlag(portString) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port %q: %w", portString, err) + } + if port < 1 || port > 65535 { + return nil, fmt.Errorf("invalid port %q: port must be between 1 and 65535", portString) + } + + ports = append(ports, command.Port{HostPort: port, ContainerPort: port, HostIP: hostIP}) + } + return ports, nil +} + +// splitPublishFlag splits a publish flag value into host and port parts. It +// supports plain ports ("8000"), IPv4 host:port ("0.0.0.0:8000"), bare IPv6 +// ("::1:8000"), and bracketed IPv6 ("[::1]:8000"). +func splitPublishFlag(value string) (host, port string, err error) { + host = command.DefaultHostIP + port = value + + if value == "" { + return "", "", fmt.Errorf("invalid port %q: value cannot be empty", value) + } + + // Bracketed IPv6 form: [::1]:8000 + if strings.HasPrefix(value, "[") { + end := strings.Index(value, "]") + if end == -1 { + return "", "", fmt.Errorf("invalid port %q: missing closing bracket for IPv6 address", value) + } + if end == len(value)-1 { + return "", "", fmt.Errorf("invalid port %q: port is required after IPv6 address", value) + } + if value[end+1] != ':' { + return "", "", fmt.Errorf("invalid port %q: expected ':' after ']'", value) + } + host = value[1:end] + port = value[end+2:] + if host == "" { + return "", "", fmt.Errorf("invalid port %q: host cannot be empty", value) + } + return host, port, nil + } + + // Standard host:port form, splitting on the last colon to tolerate IPv6. + if idx := strings.LastIndex(value, ":"); idx != -1 { + host = value[:idx] + port = value[idx+1:] + if host == "" { + return "", "", fmt.Errorf("invalid port %q: host cannot be empty", value) + } + } + + return host, port, nil +} + func execCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -114,14 +183,11 @@ func execCmd(cmd *cobra.Command, args []string) error { Workdir: "/src", } - for _, portString := range execPorts { - port, err := strconv.Atoi(portString) - if err != nil { - return err - } - - runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: port}) + ports, err := parsePublishFlags(execPorts) + if err != nil { + return err } + runOptions.Ports = ports console.Info("") console.Infof("Running %s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(args, " "))) diff --git a/pkg/cli/exec_test.go b/pkg/cli/exec_test.go new file mode 100644 index 0000000000..bc70e2728c --- /dev/null +++ b/pkg/cli/exec_test.go @@ -0,0 +1,114 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/docker/command" +) + +func TestParsePublishFlags(t *testing.T) { + tests := []struct { + name string + values []string + wantPorts []command.Port + wantErr string + }{ + { + name: "empty", + values: []string{}, + wantPorts: []command.Port{}, + }, + { + name: "port only", + values: []string{"8000"}, + wantPorts: []command.Port{ + {HostPort: 8000, ContainerPort: 8000, HostIP: command.DefaultHostIP}, + }, + }, + { + name: "host:port", + values: []string{"0.0.0.0:8000"}, + wantPorts: []command.Port{ + {HostPort: 8000, ContainerPort: 8000, HostIP: "0.0.0.0"}, + }, + }, + { + name: "IPv6 host:port", + values: []string{"::1:8000"}, + wantPorts: []command.Port{ + {HostPort: 8000, ContainerPort: 8000, HostIP: "::1"}, + }, + }, + { + name: "multiple ports", + values: []string{"8000", "0.0.0.0:8888"}, + wantPorts: []command.Port{ + {HostPort: 8000, ContainerPort: 8000, HostIP: command.DefaultHostIP}, + {HostPort: 8888, ContainerPort: 8888, HostIP: "0.0.0.0"}, + }, + }, + { + name: "bracketed IPv6 host:port", + values: []string{"[::1]:8000"}, + wantPorts: []command.Port{ + {HostPort: 8000, ContainerPort: 8000, HostIP: "::1"}, + }, + }, + { + name: "bracketed IPv6 with zone", + values: []string{"[::1%lo0]:8000"}, + wantPorts: []command.Port{ + {HostPort: 8000, ContainerPort: 8000, HostIP: "::1%lo0"}, + }, + }, + { + name: "invalid port", + values: []string{"not-a-port"}, + wantErr: "invalid port", + }, + { + name: "empty value", + values: []string{""}, + wantErr: "cannot be empty", + }, + { + name: "empty host", + values: []string{":8000"}, + wantErr: "host cannot be empty", + }, + { + name: "empty port", + values: []string{"0.0.0.0:"}, + wantErr: "invalid port", + }, + { + name: "port out of range", + values: []string{"99999"}, + wantErr: "between 1 and 65535", + }, + { + name: "missing closing bracket", + values: []string{"[::1:8000"}, + wantErr: "missing closing bracket", + }, + { + name: "port after bracket without colon", + values: []string{"[::1]8000"}, + wantErr: "expected ':' after ']'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ports, err := parsePublishFlags(tt.values) + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantPorts, ports) + }) + } +} diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index 8dbff7e6d3..b848520658 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -3,7 +3,9 @@ package cli import ( "errors" "fmt" + "net" "os" + "strconv" "strings" "github.com/spf13/cobra" @@ -17,7 +19,7 @@ import ( ) var ( - host = "127.0.0.1" + serveHost = command.DefaultHostIP port = 8393 uploadURL = "" ) @@ -60,7 +62,7 @@ the Docker port mapping is published on.`, addGpusFlag(cmd) addConfigFlag(cmd) - cmd.Flags().StringVar(&host, "host", host, "Host IP to publish the container port on. Use 0.0.0.0 to allow connections from other machines.") + cmd.Flags().StringVar(&serveHost, "host", serveHost, "Host IP to publish the container port on. Use 0.0.0.0 to allow connections from other machines.") cmd.Flags().IntVarP(&port, "port", "p", port, "Port on which to listen") cmd.Flags().StringVar(&uploadURL, "upload-url", "", "Upload URL for file outputs (e.g. https://example.com/upload/)") @@ -81,6 +83,28 @@ func serveBuildOptions(cmd *cobra.Command) model.BuildOptions { } } +// displayHostForServe returns the host string to show in the "Serving at" URL. +// Loopback bindings are displayed as "localhost" for clarity; any other +// address is returned as-is so the URL reflects the actual binding. +func displayHostForServe(host string) string { + if host == command.DefaultHostIP || host == "::1" { + return "localhost" + } + return host +} + +// formatServeURL builds the "Serving at" URL for the given bind host and port. +// When bound to all interfaces (0.0.0.0), it also shows the usable localhost +// URL since 0.0.0.0 is not a navigable address. +func formatServeURL(host string, port int) string { + url := fmt.Sprintf("http://%s", net.JoinHostPort(displayHostForServe(host), strconv.Itoa(port))) + if host == "0.0.0.0" { + localhostURL := fmt.Sprintf("http://%s", net.JoinHostPort("localhost", strconv.Itoa(port))) + url = fmt.Sprintf("%s (%s)", url, localhostURL) + } + return url +} + func cmdServe(cmd *cobra.Command, arg []string) error { ctx := cmd.Context() @@ -168,17 +192,18 @@ func cmdServe(cmd *cobra.Command, arg []string) error { runOptions.ExtraHosts = []string{"host.docker.internal:host-gateway"} } - runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000, HostIP: host}) + runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000, HostIP: serveHost}) + + serveURL := formatServeURL(serveHost, port) - displayHost := host - if displayHost == "0.0.0.0" { - displayHost = "localhost" + if isRemote, dockerHost, err := docker.IsRemoteDockerHost(); err == nil && isRemote { + console.Warnf("Using Docker daemon at %s; the server will bind to %s on that host, not this machine.", dockerHost, serveHost) } console.Info("") console.Infof("Running %[1]s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(args, " "))) console.Info("") - console.Infof("Serving at %s", console.Bold(fmt.Sprintf("http://%s:%v", displayHost, port))) + console.Infof("Serving at %s", console.Bold(serveURL)) console.Info("") err = docker.Run(ctx, dockerClient, runOptions) diff --git a/pkg/cli/serve_test.go b/pkg/cli/serve_test.go new file mode 100644 index 0000000000..e665522e3b --- /dev/null +++ b/pkg/cli/serve_test.go @@ -0,0 +1,49 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/docker/command" +) + +func TestDisplayHostForServe(t *testing.T) { + tests := []struct { + name string + host string + want string + }{ + {"default localhost", command.DefaultHostIP, "localhost"}, + {"all interfaces", "0.0.0.0", "0.0.0.0"}, + {"custom IP", "192.168.1.1", "192.168.1.1"}, + {"IPv6 localhost", "::1", "localhost"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, displayHostForServe(tt.host)) + }) + } +} + +func TestFormatServeURL(t *testing.T) { + tests := []struct { + name string + host string + port int + want string + }{ + {"default localhost", command.DefaultHostIP, 8393, "http://localhost:8393"}, + {"IPv6 localhost", "::1", 8393, "http://localhost:8393"}, + {"all interfaces shows localhost too", "0.0.0.0", 8393, "http://0.0.0.0:8393 (http://localhost:8393)"}, + {"custom IPv4", "192.168.1.1", 5000, "http://192.168.1.1:5000"}, + {"custom IPv6 is bracketed", "fe80::1", 5000, "http://[fe80::1]:5000"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, formatServeURL(tt.host, tt.port)) + }) + } +} diff --git a/pkg/docker/command/command.go b/pkg/docker/command/command.go index 3770419eea..4751a843cf 100644 --- a/pkg/docker/command/command.go +++ b/pkg/docker/command/command.go @@ -76,6 +76,11 @@ type ImageBuildOptions struct { BuildArgs map[string]*string } +// DefaultHostIP is the host interface that container ports are published on +// when no explicit HostIP is provided. Binding to localhost-only prevents +// prediction endpoints from being accidentally exposed to the network. +const DefaultHostIP = "127.0.0.1" + type RunOptions struct { Detach bool Args []string @@ -94,7 +99,7 @@ type RunOptions struct { type Port struct { HostPort int ContainerPort int - HostIP string // Host IP to bind to. Defaults to "127.0.0.1" if empty. + HostIP string // Host IP to bind to. Defaults to DefaultHostIP if empty. } type Volume struct { diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index e127052086..29803173e5 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -424,11 +424,7 @@ func (c *apiClient) containerRun(ctx context.Context, options command.RunOptions } if len(options.Ports) > 0 { - containerCfg.ExposedPorts = make(nat.PortSet) - for _, port := range options.Ports { - containerPort := nat.Port(fmt.Sprintf("%d/tcp", port.ContainerPort)) - containerCfg.ExposedPorts[containerPort] = struct{}{} - } + containerCfg.ExposedPorts = exposedPortsFromRunOptions(options.Ports) } hostCfg := &container.HostConfig{ @@ -450,20 +446,7 @@ func (c *apiClient) containerRun(ctx context.Context, options command.RunOptions // Configure port bindings if len(options.Ports) > 0 { - hostCfg.PortBindings = make(nat.PortMap) - for _, port := range options.Ports { - hostIP := port.HostIP - if hostIP == "" { - hostIP = "127.0.0.1" - } - containerPort := nat.Port(fmt.Sprintf("%d/tcp", port.ContainerPort)) - hostCfg.PortBindings[containerPort] = []nat.PortBinding{ - { - HostIP: hostIP, - HostPort: strconv.Itoa(port.HostPort), - }, - } - } + hostCfg.PortBindings = portBindingsFromRunOptions(options.Ports) } // Configure volume bindings @@ -626,6 +609,37 @@ func (c *apiClient) ContainerStart(ctx context.Context, options command.RunOptio return id, err } +// exposedPortsFromRunOptions returns the set of container ports that must be +// exposed for the given port mappings. +func exposedPortsFromRunOptions(ports []command.Port) nat.PortSet { + exposed := make(nat.PortSet, len(ports)) + for _, port := range ports { + containerPort := nat.Port(fmt.Sprintf("%d/tcp", port.ContainerPort)) + exposed[containerPort] = struct{}{} + } + return exposed +} + +// portBindingsFromRunOptions returns Docker port bindings for the given port +// mappings. Empty HostIP values default to command.DefaultHostIP. +func portBindingsFromRunOptions(ports []command.Port) nat.PortMap { + bindings := make(nat.PortMap, len(ports)) + for _, port := range ports { + hostIP := port.HostIP + if hostIP == "" { + hostIP = command.DefaultHostIP + } + containerPort := nat.Port(fmt.Sprintf("%d/tcp", port.ContainerPort)) + bindings[containerPort] = []nat.PortBinding{ + { + HostIP: hostIP, + HostPort: strconv.Itoa(port.HostPort), + }, + } + } + return bindings +} + // parseGPURequest converts a Docker CLI --gpus string into a DeviceRequest slice func parseGPURequest(opts command.RunOptions) (container.DeviceRequest, error) { if opts.GPUs == "" { diff --git a/pkg/docker/host.go b/pkg/docker/host.go index 47c9ee1d7a..d0ebd7b17a 100644 --- a/pkg/docker/host.go +++ b/pkg/docker/host.go @@ -3,6 +3,7 @@ package docker import ( "fmt" "os" + "strings" dconfig "github.com/docker/cli/cli/config" dctxdocker "github.com/docker/cli/cli/context/docker" @@ -41,6 +42,22 @@ func determineDockerHost() (string, error) { return defaultDockerHost, nil } +// IsRemoteDockerHost reports whether the configured Docker daemon is not a +// local Unix socket or named pipe. When true, port bindings are applied on the +// remote daemon's interfaces, so localhost bindings may not be reachable from +// the local machine. +func IsRemoteDockerHost() (bool, string, error) { + host, err := determineDockerHost() + if err != nil { + return false, "", err + } + return !isLocalDockerHost(host), host, nil +} + +func isLocalDockerHost(host string) bool { + return host == "" || strings.HasPrefix(host, "unix://") || strings.HasPrefix(host, "npipe://") +} + func dockerHostFromContext(contextName string) (string, string, error) { if contextName == "" { cf, err := dconfig.Load(dconfig.Dir()) diff --git a/pkg/docker/host_test.go b/pkg/docker/host_test.go new file mode 100644 index 0000000000..e8262bb6c7 --- /dev/null +++ b/pkg/docker/host_test.go @@ -0,0 +1,48 @@ +package docker + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsLocalDockerHost(t *testing.T) { + tests := []struct { + name string + host string + want bool + }{ + {"empty", "", true}, + {"unix socket", "unix:///var/run/docker.sock", true}, + {"named pipe", "npipe:////./pipe/docker_engine", true}, + {"tcp localhost", "tcp://localhost:2375", false}, + {"tcp loopback IP", "tcp://127.0.0.1:2375", false}, + {"tcp IPv6 loopback", "tcp://[::1]:2375", false}, + {"tcp remote", "tcp://192.168.1.1:2375", false}, + {"ssh", "ssh://user@host", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, isLocalDockerHost(tt.host)) + }) + } +} + +func TestIsRemoteDockerHost(t *testing.T) { + t.Setenv("DOCKER_HOST", "tcp://192.168.1.1:2375") + + isRemote, host, err := IsRemoteDockerHost() + require.NoError(t, err) + require.True(t, isRemote) + require.Equal(t, "tcp://192.168.1.1:2375", host) +} + +func TestIsRemoteDockerHost_Local(t *testing.T) { + t.Setenv("DOCKER_HOST", "unix:///var/run/docker.sock") + + isRemote, host, err := IsRemoteDockerHost() + require.NoError(t, err) + require.False(t, isRemote) + require.Equal(t, "unix:///var/run/docker.sock", host) +} diff --git a/pkg/docker/ports_test.go b/pkg/docker/ports_test.go new file mode 100644 index 0000000000..98b7743431 --- /dev/null +++ b/pkg/docker/ports_test.go @@ -0,0 +1,118 @@ +package docker + +import ( + "testing" + + "github.com/docker/go-connections/nat" + "github.com/stretchr/testify/assert" + + "github.com/replicate/cog/pkg/docker/command" +) + +func TestExposedPortsFromRunOptions(t *testing.T) { + tests := []struct { + name string + ports []command.Port + want nat.PortSet + }{ + { + name: "empty", + ports: nil, + want: nat.PortSet{}, + }, + { + name: "single port", + ports: []command.Port{ + {HostPort: 8080, ContainerPort: 5000, HostIP: command.DefaultHostIP}, + }, + want: nat.PortSet{ + nat.Port("5000/tcp"): {}, + }, + }, + { + name: "multiple ports", + ports: []command.Port{ + {HostPort: 8080, ContainerPort: 5000}, + {HostPort: 8888, ContainerPort: 8888, HostIP: "0.0.0.0"}, + }, + want: nat.PortSet{ + nat.Port("5000/tcp"): {}, + nat.Port("8888/tcp"): {}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, exposedPortsFromRunOptions(tt.ports)) + }) + } +} + +func TestPortBindingsFromRunOptions(t *testing.T) { + tests := []struct { + name string + ports []command.Port + want nat.PortMap + }{ + { + name: "empty", + ports: nil, + want: nat.PortMap{}, + }, + { + name: "default host IP", + ports: []command.Port{ + {HostPort: 8080, ContainerPort: 5000}, + }, + want: nat.PortMap{ + nat.Port("5000/tcp"): { + {HostIP: command.DefaultHostIP, HostPort: "8080"}, + }, + }, + }, + { + name: "explicit host IP", + ports: []command.Port{ + {HostPort: 8080, ContainerPort: 5000, HostIP: "0.0.0.0"}, + }, + want: nat.PortMap{ + nat.Port("5000/tcp"): { + {HostIP: "0.0.0.0", HostPort: "8080"}, + }, + }, + }, + { + name: "IPv6 host IP", + ports: []command.Port{ + {HostPort: 8080, ContainerPort: 5000, HostIP: "::1"}, + }, + want: nat.PortMap{ + nat.Port("5000/tcp"): { + {HostIP: "::1", HostPort: "8080"}, + }, + }, + }, + { + name: "multiple ports", + ports: []command.Port{ + {HostPort: 8080, ContainerPort: 5000}, + {HostPort: 8888, ContainerPort: 8888, HostIP: "0.0.0.0"}, + }, + want: nat.PortMap{ + nat.Port("5000/tcp"): { + {HostIP: command.DefaultHostIP, HostPort: "8080"}, + }, + nat.Port("8888/tcp"): { + {HostIP: "0.0.0.0", HostPort: "8888"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, portBindingsFromRunOptions(tt.ports)) + }) + } +} diff --git a/pkg/docker/run.go b/pkg/docker/run.go index 07cd718ff6..6cf90a8a01 100644 --- a/pkg/docker/run.go +++ b/pkg/docker/run.go @@ -38,7 +38,7 @@ func GetHostPortForContainer(ctx context.Context, dockerCommand command.Command, console.Debugf("=== DockerCommand.GetPort %s/%d", containerID, containerPort) if hostIP == "" { - hostIP = "127.0.0.1" + hostIP = command.DefaultHostIP } inspect, err := dockerCommand.ContainerInspect(ctx, containerID) @@ -59,7 +59,8 @@ func GetHostPortForContainer(ctx context.Context, dockerCommand command.Command, return 0, fmt.Errorf("container %s does not have expected network configuration", containerID) } - for _, portBinding := range inspect.NetworkSettings.Ports[targetPort] { + bindings := inspect.NetworkSettings.Ports[targetPort] + for _, portBinding := range bindings { if portBinding.HostIP != hostIP { continue } @@ -70,5 +71,16 @@ func GetHostPortForContainer(ctx context.Context, dockerCommand command.Command, return hostPort, nil } + // Fall back to a single wildcard or unspecified binding. Docker may report + // 0.0.0.0 or an empty HostIP even when we requested a specific localhost + // address, and such a binding is reachable on localhost. + if len(bindings) == 1 && (bindings[0].HostIP == "" || bindings[0].HostIP == "0.0.0.0") { + hostPort, err := nat.ParsePort(bindings[0].HostPort) + if err != nil { + return 0, fmt.Errorf("failed to parse host port: %w", err) + } + return hostPort, nil + } + return 0, fmt.Errorf("container %s does not have a port bound to %s", containerID, hostIP) } diff --git a/pkg/docker/run_test.go b/pkg/docker/run_test.go index 4f95dee33f..40b88e925e 100644 --- a/pkg/docker/run_test.go +++ b/pkg/docker/run_test.go @@ -8,208 +8,148 @@ import ( "github.com/docker/go-connections/nat" "github.com/stretchr/testify/require" + "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/docker/dockertest" ) func TestGetHostPortForContainer(t *testing.T) { - t.Run("WithExposedPort", func(t *testing.T) { - testClient := dockertest.NewMockCommand2(t) - testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ - ContainerJSONBase: &container.ContainerJSONBase{ - State: &container.State{ - Status: "running", - Running: true, - }, - }, - NetworkSettings: &container.NetworkSettings{ - NetworkSettingsBase: container.NetworkSettingsBase{ - Ports: nat.PortMap{ - nat.Port("5678/tcp"): []nat.PortBinding{ - { - HostIP: "127.0.0.1", - HostPort: "12345", - }, - }, - }, - }, - }, - }, nil) + runningState := &container.State{Status: "running", Running: true} - hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") - require.NoError(t, err) - require.Equal(t, 12345, hostPort) - }) - - t.Run("WithExposedPortDefaultHostIP", func(t *testing.T) { - testClient := dockertest.NewMockCommand2(t) - testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ - ContainerJSONBase: &container.ContainerJSONBase{ - State: &container.State{ - Status: "running", - Running: true, - }, - }, + inspect := func(bindings []nat.PortBinding) *container.InspectResponse { + return &container.InspectResponse{ + ContainerJSONBase: &container.ContainerJSONBase{State: runningState}, NetworkSettings: &container.NetworkSettings{ NetworkSettingsBase: container.NetworkSettingsBase{ - Ports: nat.PortMap{ - nat.Port("5678/tcp"): []nat.PortBinding{ - { - HostIP: "127.0.0.1", - HostPort: "12345", - }, - }, - }, + Ports: nat.PortMap{nat.Port("5678/tcp"): bindings}, }, }, - }, nil) - - // Empty hostIP should default to 127.0.0.1 - hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "") - require.NoError(t, err) - require.Equal(t, 12345, hostPort) - }) + } + } - t.Run("WithExposedPortAllInterfaces", func(t *testing.T) { - testClient := dockertest.NewMockCommand2(t) - testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ - ContainerJSONBase: &container.ContainerJSONBase{ - State: &container.State{ - Status: "running", - Running: true, - }, - }, - NetworkSettings: &container.NetworkSettings{ - NetworkSettingsBase: container.NetworkSettingsBase{ - Ports: nat.PortMap{ - nat.Port("5678/tcp"): []nat.PortBinding{ - { - HostIP: "0.0.0.0", - HostPort: "12345", - }, - }, - }, - }, - }, - }, nil) - - hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "0.0.0.0") - require.NoError(t, err) - require.Equal(t, 12345, hostPort) - }) - - t.Run("WithMultipleExposedPorts", func(t *testing.T) { - testClient := dockertest.NewMockCommand2(t) - testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ - ContainerJSONBase: &container.ContainerJSONBase{ - State: &container.State{ - Status: "running", - Running: true, - }, - }, + inspectDifferentPort := func(bindings []nat.PortBinding) *container.InspectResponse { + return &container.InspectResponse{ + ContainerJSONBase: &container.ContainerJSONBase{State: runningState}, NetworkSettings: &container.NetworkSettings{ NetworkSettingsBase: container.NetworkSettingsBase{ - Ports: nat.PortMap{ - nat.Port("5678/tcp"): []nat.PortBinding{ - { - HostIP: "127.0.0.1", - HostPort: "12345", - }, - { - HostIP: "127.0.0.1", - HostPort: "54321", - }, - }, - }, + Ports: nat.PortMap{nat.Port("1234/tcp"): bindings}, }, }, - }, nil) + } + } - hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") - require.NoError(t, err) - require.Equal(t, 12345, hostPort) - }) + inspectNoNetwork := func() *container.InspectResponse { + return &container.InspectResponse{ + ContainerJSONBase: &container.ContainerJSONBase{State: runningState}, + } + } - t.Run("WithExposedPortOnDifferentAddress", func(t *testing.T) { - testClient := dockertest.NewMockCommand2(t) - testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ + inspectNotRunning := func() *container.InspectResponse { + return &container.InspectResponse{ ContainerJSONBase: &container.ContainerJSONBase{ - State: &container.State{ - Status: "running", - Running: true, - }, + State: &container.State{Status: "dead", Dead: true}, }, - NetworkSettings: &container.NetworkSettings{ - NetworkSettingsBase: container.NetworkSettingsBase{ - Ports: nat.PortMap{ - nat.Port("5678/tcp"): []nat.PortBinding{ - { - HostIP: "0.0.0.0", - HostPort: "12345", - }, - }, - }, - }, - }, - }, nil) - - _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") - require.ErrorContains(t, err, "does not have a port bound to 127.0.0.1") - }) - - t.Run("WithDifferentPortExposed", func(t *testing.T) { - testClient := dockertest.NewMockCommand2(t) - testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ - ContainerJSONBase: &container.ContainerJSONBase{ - State: &container.State{ - Status: "running", - Running: true, - }, - }, - NetworkSettings: &container.NetworkSettings{ - NetworkSettingsBase: container.NetworkSettingsBase{ - Ports: nat.PortMap{ - nat.Port("1234/tcp"): []nat.PortBinding{ - { - HostIP: "127.0.0.1", - HostPort: "12345", - }, - }, - }, - }, - }, - }, nil) - - _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") - require.ErrorContains(t, err, "does not have a port bound to 127.0.0.1") - }) - - t.Run("WithNoExposedPort", func(t *testing.T) { - testClient := dockertest.NewMockCommand2(t) - testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ - ContainerJSONBase: &container.ContainerJSONBase{ - State: &container.State{ - Status: "running", - Running: true, - }, - }, - }, nil) - - _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") - require.ErrorContains(t, err, "does not have expected network configuration") - }) - - t.Run("ContainerNotRunning", func(t *testing.T) { - testClient := dockertest.NewMockCommand2(t) - testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{ - ContainerJSONBase: &container.ContainerJSONBase{ - State: &container.State{ - Status: "dead", - Dead: true, - }, - }, - }, nil) - - _, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, "127.0.0.1") - require.ErrorContains(t, err, "is not running") - }) + } + } + + tests := []struct { + name string + hostIP string + inspect *container.InspectResponse + wantPort int + wantErrString string + }{ + { + name: "matching localhost binding", + hostIP: command.DefaultHostIP, + inspect: inspect([]nat.PortBinding{{HostIP: command.DefaultHostIP, HostPort: "12345"}}), + wantPort: 12345, + }, + { + name: "empty hostIP defaults to localhost", + hostIP: "", + inspect: inspect([]nat.PortBinding{{HostIP: command.DefaultHostIP, HostPort: "12345"}}), + wantPort: 12345, + }, + { + name: "all interfaces", + hostIP: "0.0.0.0", + inspect: inspect([]nat.PortBinding{{HostIP: "0.0.0.0", HostPort: "12345"}}), + wantPort: 12345, + }, + { + name: "custom IP", + hostIP: "192.168.1.1", + inspect: inspect([]nat.PortBinding{{HostIP: "192.168.1.1", HostPort: "12345"}}), + wantPort: 12345, + }, + { + name: "IPv6 localhost", + hostIP: "::1", + inspect: inspect([]nat.PortBinding{{HostIP: "::1", HostPort: "12345"}}), + wantPort: 12345, + }, + { + name: "fallback to single binding when HostIP differs", + hostIP: command.DefaultHostIP, + inspect: inspect([]nat.PortBinding{{HostIP: "0.0.0.0", HostPort: "12345"}}), + wantPort: 12345, + }, + { + name: "fallback to single binding when HostIP is empty", + hostIP: command.DefaultHostIP, + inspect: inspect([]nat.PortBinding{{HostIP: "", HostPort: "12345"}}), + wantPort: 12345, + }, + { + name: "error when single binding has non-matching specific IP", + hostIP: command.DefaultHostIP, + inspect: inspect([]nat.PortBinding{{HostIP: "192.168.1.1", HostPort: "12345"}}), + wantErrString: "does not have a port bound to " + command.DefaultHostIP, + }, + { + name: "select matching binding from multiple", + hostIP: command.DefaultHostIP, + inspect: inspect([]nat.PortBinding{{HostIP: command.DefaultHostIP, HostPort: "12345"}, {HostIP: command.DefaultHostIP, HostPort: "54321"}}), + wantPort: 12345, + }, + { + name: "error when no matching binding and multiple bindings", + hostIP: command.DefaultHostIP, + inspect: inspect([]nat.PortBinding{{HostIP: "0.0.0.0", HostPort: "12345"}, {HostIP: "192.168.1.1", HostPort: "54321"}}), + wantErrString: "does not have a port bound to " + command.DefaultHostIP, + }, + { + name: "error when target port not exposed", + hostIP: command.DefaultHostIP, + inspect: inspectDifferentPort([]nat.PortBinding{{HostIP: command.DefaultHostIP, HostPort: "12345"}}), + wantErrString: "does not have a port bound to " + command.DefaultHostIP, + }, + { + name: "error when network settings missing", + hostIP: command.DefaultHostIP, + inspect: inspectNoNetwork(), + wantErrString: "does not have expected network configuration", + }, + { + name: "error when container not running", + hostIP: command.DefaultHostIP, + inspect: inspectNotRunning(), + wantErrString: "is not running", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testClient := dockertest.NewMockCommand2(t) + testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(tt.inspect, nil) + + hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678, tt.hostIP) + if tt.wantErrString != "" { + require.ErrorContains(t, err, tt.wantErrString) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantPort, hostPort) + }) + } } diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index ee39b91b79..4f23f2b27a 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -98,7 +98,7 @@ func NewPredictor(_ context.Context, opts PredictorOptions) (*Predictor, error) func (p *Predictor) Start(ctx context.Context, logsWriter io.Writer, timeout time.Duration) (retErr error) { containerPort := 5000 - hostIP := "127.0.0.1" + hostIP := command.DefaultHostIP if p.weightManager != nil { mounts, err := p.weightManager.Prepare(ctx)