diff --git a/internal/app/app.go b/internal/app/app.go index 62e09d6..102338c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -102,9 +102,9 @@ type App struct { currentTab int // 0=Data, 1=Columns, 2=Constraints, 3=Indexes // Code editor for viewing/editing database object definitions - codeEditor *components.CodeEditor - showCodeEditor bool - isLoadingObjectDetails bool // Loading indicator for function/sequence/etc details + codeEditor *components.CodeEditor + showCodeEditor bool + isLoadingObjectDetails bool // Loading indicator for function/sequence/etc details // Favorites showFavorites bool @@ -875,8 +875,8 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { "j": true, "k": true, "up": true, "down": true, // Scrolling "g": true, "G": true, // Scroll to top/bottom "ctrl+d": true, "ctrl+u": true, // Page scroll - "y": true, // Copy - "e": true, // Enter edit mode + "y": true, // Copy + "e": true, // Enter edit mode "esc": true, // Close (q is reserved for quitting app) } key := msg.String() @@ -2118,7 +2118,7 @@ func (a *App) renderNormalView() string { // Width must account for border: lipgloss Width() sets content area, // border chars are added outside, so subtract border width (2) to avoid overflow topBar := lipgloss.NewStyle(). - Width(a.state.Width - 2). + Width(a.state.Width-2). Background(lipgloss.Color("#313244")). Foreground(lipgloss.Color("#cdd6f4")). Border(lipgloss.RoundedBorder()). @@ -2214,7 +2214,7 @@ func (a *App) renderNormalView() string { // Create modern bottom bar // Width must account for border: subtract border width (2) to avoid overflow bottomBar := lipgloss.NewStyle(). - Width(a.state.Width - 2). + Width(a.state.Width-2). Background(lipgloss.Color("#313244")). Foreground(lipgloss.Color("#cdd6f4")). Border(lipgloss.RoundedBorder()). @@ -3160,17 +3160,7 @@ func (a *App) connectToHistoryEntry(entry models.ConnectionHistoryEntry) (tea.Mo // connectToDiscoveredInstance connects using a discovered instance func (a *App) connectToDiscoveredInstance(instance models.DiscoveredInstance) (tea.Model, tea.Cmd) { - // Create connection config from discovered instance - config := models.ConnectionConfig{ - Host: instance.Host, - Port: instance.Port, - Database: "postgres", // Default database - User: os.Getenv("USER"), // Current user - Password: "", // No password for now - SSLMode: "prefer", - } - - return a.performConnection(config) + return a.performConnection(discovery.BuildConnectionConfig(instance)) } // performConnection starts an async connection attempt @@ -3362,15 +3352,7 @@ func (a *App) handleConnectionDialog(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return a, nil } - // Create connection config from discovered instance - config = models.ConnectionConfig{ - Host: instance.Host, - Port: instance.Port, - Database: "postgres", - User: os.Getenv("USER"), - Password: "", - SSLMode: "prefer", - } + config = discovery.BuildConnectionConfig(*instance) } return a.performConnection(config) @@ -4501,9 +4483,9 @@ func (a *App) overlayLine(background, foreground string, startX int) string { // SearchTableResultMsg is sent when table search completes type SearchTableResultMsg struct { - Query string - Data *metadata.TableData - Err error + Query string + Data *metadata.TableData + Err error } // searchTable executes a table-wide search diff --git a/internal/app/delegates/connection.go b/internal/app/delegates/connection.go index b1de6d1..07b0673 100644 --- a/internal/app/delegates/connection.go +++ b/internal/app/delegates/connection.go @@ -71,8 +71,8 @@ func (d *ConnectionDelegate) handleConnectionResult(msg messages.ConnectionResul if msg.Err != nil { // Connection failed - clear pending password (don't save wrong password) app.ClearPendingPasswordSave() - app.ShowError("Connection Failed", fmt.Sprintf("Could not connect to %s:%d\n\nError: %v", - msg.Config.Host, msg.Config.Port, msg.Err)) + app.ShowError("Connection Failed", fmt.Sprintf("Could not connect to %s\n\nError: %v", + msg.Config.DisplayTarget(), msg.Err)) return true, nil } diff --git a/internal/db/discovery/config.go b/internal/db/discovery/config.go new file mode 100644 index 0000000..b3f4375 --- /dev/null +++ b/internal/db/discovery/config.go @@ -0,0 +1,88 @@ +package discovery + +import ( + "os" + "strings" + + "github.com/rebelice/lazypg/internal/models" +) + +// BuildConnectionConfig turns a discovered instance into a connection config. +func BuildConnectionConfig(instance models.DiscoveredInstance) models.ConnectionConfig { + switch instance.Source { + case models.SourceEnvironment: + if envConfig := GetEnvironmentConfig(); envConfig != nil && envConfig.Host == instance.Host && envConfig.Port == instance.Port { + return *envConfig + } + case models.SourcePgPass: + if pgpassConfig := buildPgPassConfig(instance.Host, instance.Port); pgpassConfig != nil { + return *pgpassConfig + } + } + + return buildDefaultConfig(instance) +} + +func buildPgPassConfig(host string, port int) *models.ConnectionConfig { + entries, err := ParsePgPass() + if err != nil { + return nil + } + + for _, entry := range entries { + if entry.Host != host || entry.Port != port { + continue + } + + user := entry.User + if user == "" || user == "*" { + user = defaultUser() + } + + database := entry.Database + if database == "" || database == "*" { + database = defaultDatabase(user) + } + + return &models.ConnectionConfig{ + Host: host, + Port: port, + Database: database, + User: user, + Password: entry.Password, + SSLMode: "prefer", + } + } + + return nil +} + +func buildDefaultConfig(instance models.DiscoveredInstance) models.ConnectionConfig { + user := defaultUser() + + return models.ConnectionConfig{ + Host: instance.Host, + Port: instance.Port, + Database: defaultDatabase(user), + User: user, + SSLMode: "prefer", + } +} + +func defaultUser() string { + for _, key := range []string{"PGUSER", "USER", "USERNAME"} { + if value := strings.TrimSpace(os.Getenv(key)); value != "" { + return value + } + } + + return "postgres" +} + +func defaultDatabase(user string) string { + if user != "" { + return user + } + + return "postgres" +} diff --git a/internal/db/discovery/discovery.go b/internal/db/discovery/discovery.go index c448848..2844fb0 100644 --- a/internal/db/discovery/discovery.go +++ b/internal/db/discovery/discovery.go @@ -29,11 +29,15 @@ func (d *Discoverer) DiscoverAll(ctx context.Context) []models.DiscoveredInstanc instances = append(instances, *envInstance) } - // 2. Scan localhost ports + // 2. Scan common Unix socket directories + unixSocketInstances := d.scanner.ScanUnixSockets(ctx) + instances = append(instances, unixSocketInstances...) + + // 3. Scan localhost ports localInstances := d.scanner.ScanLocalhost(ctx) instances = append(instances, localInstances...) - // 3. Parse .pgpass + // 4. Parse .pgpass pgpassInstances := GetDiscoveredInstances() instances = append(instances, pgpassInstances...) @@ -42,7 +46,15 @@ func (d *Discoverer) DiscoverAll(ctx context.Context) []models.DiscoveredInstanc // Sort by source priority sort.Slice(instances, func(i, j int) bool { - return instances[i].Source < instances[j].Source + if instances[i].Source != instances[j].Source { + return instances[i].Source < instances[j].Source + } + + if instances[i].Host != instances[j].Host { + return instances[i].Host < instances[j].Host + } + + return instances[i].Port < instances[j].Port }) return instances diff --git a/internal/db/discovery/unix_socket.go b/internal/db/discovery/unix_socket.go new file mode 100644 index 0000000..d6792c6 --- /dev/null +++ b/internal/db/discovery/unix_socket.go @@ -0,0 +1,161 @@ +package discovery + +import ( + "context" + "fmt" + "net" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + "github.com/rebelice/lazypg/internal/models" +) + +var defaultUnixSocketDirs = []string{ + "/var/run/postgresql", + "/run/postgresql", + "/tmp", + "/private/tmp", + "/var/pgsql_socket", + "/private/var/run/postgresql", + "/opt/homebrew/var/run/postgresql", + "/usr/local/var/run/postgresql", +} + +// ScanUnixSockets scans common PostgreSQL socket directories. +func (s *Scanner) ScanUnixSockets(ctx context.Context) []models.DiscoveredInstance { + if runtime.GOOS == "windows" { + return nil + } + + return s.ScanUnixSocketDirs(ctx, candidateUnixSocketDirs()) +} + +// ScanUnixSocketDirs scans the provided directories for PostgreSQL socket files. +func (s *Scanner) ScanUnixSocketDirs(ctx context.Context, dirs []string) []models.DiscoveredInstance { + if runtime.GOOS == "windows" { + return nil + } + + instances := make([]models.DiscoveredInstance, 0) + seen := make(map[string]struct{}) + + for _, dir := range uniqueSocketDirs(dirs) { + if ctx.Err() != nil { + break + } + + for _, instance := range s.scanUnixSocketDir(ctx, dir) { + key := instance.DisplayTarget() + if _, exists := seen[key]; exists { + continue + } + + seen[key] = struct{}{} + instances = append(instances, instance) + } + } + + return instances +} + +func (s *Scanner) scanUnixSocketDir(ctx context.Context, dir string) []models.DiscoveredInstance { + entries, err := os.ReadDir(dir) + if err != nil { + return nil + } + + instances := make([]models.DiscoveredInstance, 0) + for _, entry := range entries { + if ctx.Err() != nil { + break + } + + port, ok := postgresSocketPort(entry.Name()) + if !ok { + continue + } + + instance := s.scanUnixSocket(ctx, dir, port) + if instance.Available { + instances = append(instances, instance) + } + } + + return instances +} + +func (s *Scanner) scanUnixSocket(ctx context.Context, dir string, port int) models.DiscoveredInstance { + instance := models.DiscoveredInstance{ + Host: dir, + Port: port, + Source: models.SourceUnixSocket, + } + + start := time.Now() + socketPath := filepath.Join(dir, fmt.Sprintf(".s.PGSQL.%d", port)) + + dialer := &net.Dialer{Timeout: s.timeout} + conn, err := dialer.DialContext(ctx, "unix", socketPath) + instance.ResponseTime = time.Since(start) + if err != nil { + return instance + } + + _ = conn.Close() + instance.Available = true + return instance +} + +func candidateUnixSocketDirs() []string { + dirs := make([]string, 0, len(defaultUnixSocketDirs)+1) + + if host := strings.TrimSpace(os.Getenv("PGHOST")); strings.HasPrefix(host, "/") { + dirs = append(dirs, host) + } + + dirs = append(dirs, defaultUnixSocketDirs...) + return dirs +} + +func uniqueSocketDirs(dirs []string) []string { + unique := make([]string, 0, len(dirs)) + seen := make(map[string]struct{}) + + for _, dir := range dirs { + dir = strings.TrimSpace(dir) + if dir == "" { + continue + } + + if _, exists := seen[dir]; exists { + continue + } + + seen[dir] = struct{}{} + unique = append(unique, dir) + } + + return unique +} + +func postgresSocketPort(name string) (int, bool) { + if !strings.HasPrefix(name, ".s.PGSQL.") { + return 0, false + } + + portStr := strings.TrimPrefix(name, ".s.PGSQL.") + if portStr == "" || strings.Contains(portStr, ".") { + return 0, false + } + + port, err := strconv.Atoi(portStr) + if err != nil || port < 1 || port > 65535 { + return 0, false + } + + return port, true +} diff --git a/internal/db/discovery/unix_socket_test.go b/internal/db/discovery/unix_socket_test.go new file mode 100644 index 0000000..6416aab --- /dev/null +++ b/internal/db/discovery/unix_socket_test.go @@ -0,0 +1,94 @@ +package discovery + +import ( + "context" + "net" + "path/filepath" + "runtime" + "testing" + + "github.com/rebelice/lazypg/internal/models" +) + +func TestScanUnixSocketDirs(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix sockets are not supported on windows") + } + + tempDir := t.TempDir() + socketPath := filepath.Join(tempDir, ".s.PGSQL.6543") + + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("listen on unix socket: %v", err) + } + defer func() { + _ = listener.Close() + }() + + acceptDone := make(chan struct{}) + go func() { + defer close(acceptDone) + conn, err := listener.Accept() + if err == nil { + _ = conn.Close() + } + }() + + scanner := NewScanner() + instances := scanner.ScanUnixSocketDirs(context.Background(), []string{tempDir}) + + if len(instances) != 1 { + t.Fatalf("expected 1 discovered socket, got %d", len(instances)) + } + + instance := instances[0] + if instance.Host != tempDir { + t.Fatalf("expected host %q, got %q", tempDir, instance.Host) + } + if instance.Port != 6543 { + t.Fatalf("expected port 6543, got %d", instance.Port) + } + if instance.Source != models.SourceUnixSocket { + t.Fatalf("expected source %v, got %v", models.SourceUnixSocket, instance.Source) + } + if !instance.Available { + t.Fatal("expected discovered socket to be available") + } + + <-acceptDone +} + +func TestBuildConnectionConfigForSocketDefaults(t *testing.T) { + t.Setenv("PGUSER", "") + t.Setenv("USER", "socket-user") + t.Setenv("USERNAME", "") + + config := BuildConnectionConfig(models.DiscoveredInstance{ + Host: "/tmp", + Port: 5432, + Source: models.SourceUnixSocket, + }) + + if config.Host != "/tmp" { + t.Fatalf("expected host /tmp, got %q", config.Host) + } + if config.Port != 5432 { + t.Fatalf("expected port 5432, got %d", config.Port) + } + if config.User != "socket-user" { + t.Fatalf("expected user socket-user, got %q", config.User) + } + if config.Database != "socket-user" { + t.Fatalf("expected database socket-user, got %q", config.Database) + } + if config.SSLMode != "prefer" { + t.Fatalf("expected sslmode prefer, got %q", config.SSLMode) + } + if !config.UsesUnixSocket() { + t.Fatal("expected config to use a unix socket") + } + if got := config.DisplayTarget(); got != filepath.Join("/tmp", ".s.PGSQL.5432") { + t.Fatalf("unexpected display target %q", got) + } +} diff --git a/internal/models/connection.go b/internal/models/connection.go index 2fb5e29..f257164 100644 --- a/internal/models/connection.go +++ b/internal/models/connection.go @@ -1,6 +1,9 @@ package models import ( + "fmt" + "path/filepath" + "strings" "time" ) @@ -15,6 +18,16 @@ type ConnectionConfig struct { SSLMode string `yaml:"ssl_mode"` } +// UsesUnixSocket returns true when the host represents a Unix socket directory. +func (c ConnectionConfig) UsesUnixSocket() bool { + return isUnixSocketHost(c.Host) +} + +// DisplayTarget returns a user-friendly connection target. +func (c ConnectionConfig) DisplayTarget() string { + return formatConnectionTarget(c.Host, c.Port) +} + // Connection represents an active database connection type Connection struct { ID string @@ -44,6 +57,16 @@ type DiscoveredInstance struct { ResponseTime time.Duration } +// UsesUnixSocket returns true when the discovered host represents a Unix socket directory. +func (d DiscoveredInstance) UsesUnixSocket() bool { + return isUnixSocketHost(d.Host) +} + +// DisplayTarget returns a user-friendly discovery target. +func (d DiscoveredInstance) DisplayTarget() string { + return formatConnectionTarget(d.Host, d.Port) +} + // DiscoverySource indicates how an instance was discovered type DiscoverySource int @@ -77,17 +100,17 @@ func (s DiscoverySource) String() string { // ConnectionHistoryEntry represents a saved connection from history type ConnectionHistoryEntry struct { - ID string `yaml:"id"` - Name string `yaml:"name"` // User-friendly name (auto-generated or custom) - Host string `yaml:"host"` - Port int `yaml:"port"` - Database string `yaml:"database"` - User string `yaml:"user"` + ID string `yaml:"id"` + Name string `yaml:"name"` // User-friendly name (auto-generated or custom) + Host string `yaml:"host"` + Port int `yaml:"port"` + Database string `yaml:"database"` + User string `yaml:"user"` // Note: Password is NOT stored for security reasons - SSLMode string `yaml:"ssl_mode"` - LastUsed time.Time `yaml:"last_used"` - UsageCount int `yaml:"usage_count"` - CreatedAt time.Time `yaml:"created_at"` + SSLMode string `yaml:"ssl_mode"` + LastUsed time.Time `yaml:"last_used"` + UsageCount int `yaml:"usage_count"` + CreatedAt time.Time `yaml:"created_at"` } // ToConnectionConfig converts a history entry to a ConnectionConfig (without password) @@ -102,3 +125,15 @@ func (e *ConnectionHistoryEntry) ToConnectionConfig() ConnectionConfig { SSLMode: e.SSLMode, } } + +func isUnixSocketHost(host string) bool { + return strings.HasPrefix(host, "/") +} + +func formatConnectionTarget(host string, port int) string { + if isUnixSocketHost(host) { + return filepath.Join(host, fmt.Sprintf(".s.PGSQL.%d", port)) + } + + return fmt.Sprintf("%s:%d", host, port) +} diff --git a/internal/ui/components/connection_dialog.go b/internal/ui/components/connection_dialog.go index ba546ed..7b86288 100644 --- a/internal/ui/components/connection_dialog.go +++ b/internal/ui/components/connection_dialog.go @@ -29,9 +29,9 @@ type ConnectionDialog struct { searchInput textinput.Model // Text input fields for manual mode - inputs []textinput.Model - focusIndex int - cursorMode cursor.Mode + inputs []textinput.Model + focusIndex int + cursorMode cursor.Mode } const ( @@ -308,9 +308,8 @@ func (c *ConnectionDialog) renderDiscoveryMode(contentWidth int) string { sourceStyle := lipgloss.NewStyle(). Foreground(lipgloss.Color("#6c7086")) - line := fmt.Sprintf("%s:%d %s", - instance.Host, - instance.Port, + line := fmt.Sprintf("%s %s", + instance.DisplayTarget(), sourceStyle.Render(fmt.Sprintf("(%s)", instance.Source.String())), ) // Wrap with zone for click detection