From 2d4cc0350db293038b2f664dbf0bb2f8e2e99b89 Mon Sep 17 00:00:00 2001 From: Samiul Sk Date: Sat, 1 Nov 2025 22:26:07 +0530 Subject: [PATCH] feat: add localtunnel provider implementation --- internal/provider/localtunnel.go | 287 ++++++++++++++++++++++++ internal/provider/localtunnel_test.go | 299 ++++++++++++++++++++++++++ internal/tunnel/provider.go | 22 ++ 3 files changed, 608 insertions(+) create mode 100644 internal/provider/localtunnel.go create mode 100644 internal/provider/localtunnel_test.go create mode 100644 internal/tunnel/provider.go diff --git a/internal/provider/localtunnel.go b/internal/provider/localtunnel.go new file mode 100644 index 0000000..c5b434e --- /dev/null +++ b/internal/provider/localtunnel.go @@ -0,0 +1,287 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/kernelshard/expose/internal/tunnel" +) + +const ( + localTunnelProviderName = "LocalTunnel" + localtunnelAPI = "https://localtunnel.me" + localTunnelTCPHost = "localtunnel.me" + // maximum concurrent connections allowed for us, + // override if tunnel api sends their limit + clientMaxConn = 10 +) + +// localTunnel implements the Provider interface for localtunnel.me +// It manages the lifecycle of a tunnel connection. +// It maintains a pool of TCP connections to handle incoming requests. +// It forwards traffic from the tunnel to the local server running on localPort & vice versa. +type localTunnel struct { + publicURL string + localPort int + tunnelPort int + tunnelHost string + connected bool + mu sync.RWMutex + connections []net.Conn // connection pool + maxConnections int + ctx context.Context + cancel context.CancelFunc + + // HTTP client for API calls, reusable + httpClient *http.Client + // api endpoint string, it's configurable for testing + serverAPIEndpoint string +} + +// TunnelInfo is the response model from localtunnel server when establishing a tunnel. +type TunnelInfo struct { + ID string `json:"id"` + URL string `json:"url"` + Port int `json:"port"` + MaxConn int `json:"max_conn_count"` +} + +// NewLocalTunnel creates a new localTunnel provider instance. +func NewLocalTunnel(httpClient *http.Client) tunnel.Provider { + if httpClient == nil { + httpClient = http.DefaultClient + } + + return &localTunnel{ + connections: make([]net.Conn, 0, clientMaxConn), + httpClient: httpClient, + serverAPIEndpoint: localtunnelAPI, + } +} + +// Connect establishes tunnel to localtunnel.me +func (lt *localTunnel) Connect(ctx context.Context, localPort int) (string, error) { + lt.mu.Lock() + lt.localPort = localPort + lt.ctx, lt.cancel = context.WithCancel(ctx) + lt.mu.Unlock() + + // Step 1: Request tunnel from the localtunnel.me + info, err := lt.requestTunnel(ctx) + if err != nil { + return "", fmt.Errorf("failed to request tunnel: %w", err) + } + + lt.mu.Lock() + lt.publicURL = info.URL + lt.tunnelPort = info.Port + lt.tunnelHost = localTunnelTCPHost + + // set maxConnections allowed to open + if info.MaxConn > 0 { + // Take minimum: respect both server limit and our limit + lt.maxConnections = min(info.MaxConn, clientMaxConn) + } else { + // Server didn't specify, use our default + lt.maxConnections = clientMaxConn + } + + lt.mu.Unlock() + + // Step 2: Open TCP connection pool which + // - connects to localtunnel server + // - handles incoming requests and forwards to local server + // - forwards responses back to tunnel + // We open multiple connections to handle concurrent requests + if err := lt.openConnections(); err != nil { + return "", fmt.Errorf("failed to open connections: %w", err) + } + + lt.mu.Lock() + lt.connected = true + lt.mu.Unlock() + + return info.URL, nil + +} + +// requestTunnel request a tunnel from localtunnel.me API and returns the TunnelInfo. +// we make an HTTP GET request to localtunnel.me/?new +// localtunnel.me opens a tcp port for us and responds with the port +// and url info(to be used for accessing the local server) +func (lt *localTunnel) requestTunnel(ctx context.Context) (*TunnelInfo, error) { + localTunnelReqURL := lt.serverAPIEndpoint + "/?new" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, localTunnelReqURL, nil) + + if err != nil { + return nil, err + } + + // Perform the HTTP request to localtunnel.me + resp, err := lt.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // Check for non-200 status codes + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("status %d:%s", resp.StatusCode, string(body)) + } + + // decode response body to TunnelInfo + var info TunnelInfo + err = json.NewDecoder(resp.Body).Decode(&info) + if err != nil { + return nil, fmt.Errorf("decode error: %w", err) + } + return &info, nil +} + +// openConnections opens a pool of TCP connections to the localtunnel server. +func (lt *localTunnel) openConnections() error { + lt.mu.Lock() + defer lt.mu.Unlock() + + for i := 0; i < lt.maxConnections; i++ { + // create tunnel connection to the upstream server & store in pool + // each connection will handle incoming requests + conn, err := lt.dialTunnel() + if err != nil { + // Close any connections we already opened + // TODO: can do retry here instead of failing immediately + lt.closeAllConnections() + return fmt.Errorf("connection %d failed: %w", i, err) + } + // it used to close connections later + lt.connections = append(lt.connections, conn) + + // Start handling this connection + go lt.handleConnection(conn) + } + + return nil +} + +// dialTunnel creates a single TCP connection to the localtunnel server. +func (lt *localTunnel) dialTunnel() (net.Conn, error) { + // TODO: give IPv6 support here using net.JoinHostPort later + address := fmt.Sprintf("%s:%d", lt.tunnelHost, lt.tunnelPort) + conn, err := net.DialTimeout("tcp", address, 10*time.Second) + if err != nil { + return nil, err + } + return conn, nil +} + +// closeAllConnections closes all existing TCP connections +func (lt *localTunnel) closeAllConnections() { + for _, conn := range lt.connections { + if conn != nil { + _ = conn.Close() + } + } + + lt.connections = lt.connections[:0] +} + +// handleConnection processes traffic from one tunnel connection +func (lt *localTunnel) handleConnection(tunnelConn net.Conn) { + defer tunnelConn.Close() + + for { + select { + // run until context is done means user does Ctrl+C or Close() is called + case <-lt.ctx.Done(): + return + default: + // Read request from tunnel + // Forward to localhost + // Write response back + // TODO: Use connection pool instead of dialing on every request + if err := lt.proxyRequest(tunnelConn); err != nil { + if lt.ctx.Err() != nil { + return // Shutting down + } + // Connection closed or error, exit this handler + fmt.Printf("[localtunnel] connection error: %v\n", err) + return + } + } + } +} + +// proxyRequest forwards data between the tunnel connection and the local server. +func (lt *localTunnel) proxyRequest(tunnelConn net.Conn) error { + // connect to local server + localAddr := fmt.Sprintf("localhost:%d", lt.localPort) + localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second) + if err != nil { + return fmt.Errorf("local dial failed: %w", err) + } + defer localConn.Close() + + // Set deadlines, it helps to avoid hanging connections + // e.g: if either side doesn't respond in time, the copy will end + _ = tunnelConn.SetDeadline(time.Now().Add(30 * time.Second)) + _ = localConn.SetDeadline(time.Now().Add(30 * time.Second)) + + // Start bidirectional copy + // mental model: copy(blocking ops) the data from tunnel to local and + //local to tunnel concurrently when either side closes, the copy ends + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + io.Copy(localConn, tunnelConn) + }() + + go func() { + defer wg.Done() + io.Copy(tunnelConn, localConn) + }() + + wg.Wait() + return nil + +} + +// Close terminates the tunnel +func (lt *localTunnel) Close() error { + lt.mu.Lock() + defer lt.mu.Unlock() + + if lt.cancel != nil { + lt.cancel() + + } + + lt.closeAllConnections() + lt.connected = false + return nil +} + +// IsConnected returns true if tunnel is active +func (lt *localTunnel) IsConnected() bool { + lt.mu.RLock() + defer lt.mu.RUnlock() + return lt.connected +} + +func (lt *localTunnel) PublicURL() string { + lt.mu.RLock() + defer lt.mu.RUnlock() + return lt.publicURL +} + +func (lt *localTunnel) Name() string { + return localTunnelProviderName +} diff --git a/internal/provider/localtunnel_test.go b/internal/provider/localtunnel_test.go new file mode 100644 index 0000000..506f1f4 --- /dev/null +++ b/internal/provider/localtunnel_test.go @@ -0,0 +1,299 @@ +package provider + +import ( + "context" + "encoding/json" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func Test_NewLocalTunnel(t *testing.T) { + t.Run("with nil httpClient should use default", func(t *testing.T) { + provider := NewLocalTunnel(nil) + lt := provider.(*localTunnel) + + if lt.httpClient == nil { + t.Fatal("expected default httpClient, got nil") + } + + if lt.httpClient.Timeout != http.DefaultClient.Timeout { + t.Errorf("expected %v timeout, got %v", http.DefaultClient.Timeout, lt.httpClient.Timeout) + } + + if lt.serverAPIEndpoint != localtunnelAPI { + t.Errorf("expected endpoint %s, got %s", localtunnelAPI, lt.serverAPIEndpoint) + } + + if cap(lt.connections) != clientMaxConn { + t.Errorf("expected connections capacity %d, got %d", clientMaxConn, cap(lt.connections)) + } + }) + + t.Run("with custom httpClient should use it", func(t *testing.T) { + customClient := &http.Client{Timeout: 5 * time.Second} + + provider := NewLocalTunnel(customClient) + lt := provider.(*localTunnel) + + if lt.httpClient != customClient { + t.Error("expected custom client to be used") + } + + if lt.httpClient.Timeout != 5*time.Second { + t.Errorf("expected 5s timeout, got %d", lt.httpClient.Timeout) + } + }, + ) + +} + +// Test_requestTunnel tests the API call +func Test_requestTunnel(t *testing.T) { + t.Run("successful API call", func(t *testing.T) { + dummyRespID := "abc123" + dummyRespURL := "https://abc123.example.com" + dummyRespPort := 666666 + + // mock handler + mockHandler := func(w http.ResponseWriter, r *http.Request) { + _, newQueryPresent := r.URL.Query()["new"] + if r.URL.Path != "/" || !newQueryPresent { + t.Error("expected /?new endpoint") + } + response := TunnelInfo{ + ID: dummyRespID, + URL: dummyRespURL, + Port: dummyRespPort, + MaxConn: clientMaxConn, + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(response) + } + + // mock server + server := httptest.NewServer(http.HandlerFunc(mockHandler)) + + defer server.Close() + + lt := localTunnel{ + httpClient: server.Client(), + serverAPIEndpoint: server.URL, + } + + ctx := context.Background() + info, err := lt.requestTunnel(ctx) + + if err != nil { + t.Fatalf("unexpected error:%v", err) + } + + if info.ID != dummyRespID { + t.Errorf("expected ID %s, got %s", dummyRespID, info.ID) + } + + if info.URL != dummyRespURL { + t.Errorf("expected URL %s, got %s", dummyRespURL, info.URL) + } + + if info.Port != dummyRespPort { + t.Errorf("expected Port %d, got %d", dummyRespPort, info.Port) + } + }) + + t.Run("non-200 status code", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal server error")) + })) + + defer server.Close() + + lt := localTunnel{httpClient: http.DefaultClient, serverAPIEndpoint: server.URL} + + ctx := context.Background() + _, err := lt.requestTunnel(ctx) + if err == nil { + t.Fatalf("expected error for non-200 status") + } + + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected error to maintain status 500, got %v", err) + } + }) + + t.Run("invalid JSON response", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("invalid json")) + })) + + defer server.Close() + + lt := &localTunnel{ + httpClient: http.DefaultClient, + serverAPIEndpoint: server.URL, + } + + ctx := context.Background() + _, err := lt.requestTunnel(ctx) + + if err == nil { + t.Fatalf("expected decode error for invalid JSON") + } + + if !strings.Contains(err.Error(), "decode error") { + t.Errorf("expected decode error, got %v", err) + } + }) + + t.Run("context cancellation", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + + defer server.Close() + + lt := &localTunnel{ + httpClient: http.DefaultClient, + serverAPIEndpoint: server.URL, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + _, err := lt.requestTunnel(ctx) + + if err == nil { + t.Fatal("expected error for cancelled context") + } + }) +} + +// TestLocalTunnel_Name +func TestLocalTunnel_Name(t *testing.T) { + provider := NewLocalTunnel(nil) + + if provider.Name() != localTunnelProviderName { + t.Errorf("expected name %s, got %s", localTunnelProviderName, provider.Name()) + } +} + +// Test_IsConnected verifies connection state tracking +func TestLocalTunnel_IsConnected(t *testing.T) { + lt := localTunnel{} + if lt.IsConnected() { + t.Errorf("new tunnel should not be connected") + } + + lt.mu.Lock() + lt.connected = true + lt.mu.Unlock() + + if !lt.IsConnected() { + t.Error("expected IsConnected to turn true") + } + + lt.mu.Lock() + lt.connected = false + lt.mu.Unlock() + + if lt.IsConnected() { + t.Errorf("expected IsConnected to return false after disconnect") + } +} + +// TestLocalTunnel_PublicURL verifies URL getter +func TestLocalTunnel_PublicURL(t *testing.T) { + + tests := []struct { + name string + url string + }{ + {"empty URL", ""}, + {"valid URL", "https://test.localtunnel.me"}, + {"custom domain", "https://abc123.loca.lt"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lt := &localTunnel{publicURL: tt.url} + got := lt.PublicURL() + + if got != tt.url { + t.Errorf("expected URL %s, got %s", tt.url, got) + } + }) + } +} + +func Test_closeAllConnections(t *testing.T) { + // create mock connection s + conn1Client, conn1Server := net.Pipe() + conn2Client, conn2Server := net.Pipe() + + defer conn1Server.Close() + defer conn2Server.Close() + + lt := &localTunnel{ + connections: []net.Conn{conn1Client, conn2Client}, + } + + lt.closeAllConnections() + // make sure all closed + if len(lt.connections) != 0 { + t.Errorf("expected empty connections slice, got %d connections, ", len(lt.connections)) + } + + // verify connections are actually closed, + _, err := conn1Client.Write([]byte("test")) + if err == nil { + t.Error("expected error writing to closed connection") + } + _, err = conn2Client.Write([]byte("test")) + if err == nil { + t.Error("Expected error writing to closed connection") + } + +} + +func TestLocalTunnel_Close(t *testing.T) { + ctx, canelFunc := context.WithCancel(context.Background()) + + // create mock connection + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + lt := &localTunnel{ + connected: true, + publicURL: "https://test.example.com", + connections: []net.Conn{clientConn}, + cancel: canelFunc, + } + + err := lt.Close() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if lt.IsConnected() { + t.Error("expected connected to be false after Close") + } + if len(lt.connections) != 0 { + t.Errorf("expected connections to be cleared, got %d", len(lt.connections)) + } + + // verify ctx was canceled + select { + case <-ctx.Done(): + // expected + default: + t.Error("expected context to be cancelled") + } + +} diff --git a/internal/tunnel/provider.go b/internal/tunnel/provider.go new file mode 100644 index 0000000..c127bee --- /dev/null +++ b/internal/tunnel/provider.go @@ -0,0 +1,22 @@ +package tunnel + +import "context" + +// Provider is an interface for tunnel service providers. +// It defines the methods required to establish and manage a tunnel. +type Provider interface { + // Connect establishes a tunnel to the specified local port and returns the public URL. + Connect(ctx context.Context, localPort int) (string, error) + + // Close disconnect closes the tunnel connection & cleans up resources. + Close() error + + // IsConnected returns true if the tunnel is currently active. + IsConnected() bool + + // PublicURL returns the public URL of the tunnel. + PublicURL() string + + // Name of the provider (metadata) + Name() string // "localtunnel", "ngrok", etc. +}