From 2d6134d276e69eeedede821c609610369c235fef Mon Sep 17 00:00:00 2001 From: Samuele Cerea Date: Wed, 19 Nov 2025 23:10:54 +0100 Subject: [PATCH] Handle connection string for unix sockets correctly When connecting to a unix socket the path should be interpreted as the path to the directory that contains the socket rather that the path to the socket itself. --- CHANGELOG | 1 + spec/pq/conninfo_spec.cr | 2 +- src/pq/connection.cr | 6 ++++-- src/pq/conninfo.cr | 28 +++++++++++++++------------- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index 742bd017..2566c12b 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -5,6 +5,7 @@ v? upcoming * Add support for COPY (thanks @17dec) * Add pg log source to log entries (thanks @hugopl) * Update crystal-db to 0.14 (thanks @bcardiff) +* Handle connections strings with unix domain sockets correctly (thanks @samu698) v0.29.0 2024-11-05 ===================== diff --git a/spec/pq/conninfo_spec.cr b/spec/pq/conninfo_spec.cr index 5c063da8..07207a9c 100644 --- a/spec/pq/conninfo_spec.cr +++ b/spec/pq/conninfo_spec.cr @@ -140,7 +140,7 @@ describe PQ::ConnInfo, ".from_conninfo_string" do env_var_bubble do ENV["PGHOST"] = "/path" ci = PQ::ConnInfo.from_conninfo_string("postgres://") - ci.host.should eq("/path/.s.PGSQL.5432") + ci.host.should eq("/path") end end end diff --git a/src/pq/connection.cr b/src/pq/connection.cr index c226ddd4..581ca44d 100644 --- a/src/pq/connection.cr +++ b/src/pq/connection.cr @@ -22,8 +22,10 @@ module PQ def initialize(@conninfo : ConnInfo) begin - if @conninfo.host[0] == '/' - soc = UNIXSocket.new(@conninfo.host) + if Path.new(@conninfo.host).absolute? + socket_name = ".s.PGSQL.#{@conninfo.port}" + socket_path = File.join(@conninfo.host, socket_name) + soc = UNIXSocket.new(socket_path) else soc = TCPSocket.new(@conninfo.host, @conninfo.port) end diff --git a/src/pq/conninfo.cr b/src/pq/conninfo.cr index 6e1b9223..9dfcd2a0 100644 --- a/src/pq/conninfo.cr +++ b/src/pq/conninfo.cr @@ -4,7 +4,7 @@ require "system/user" module PQ struct ConnInfo - SOCKET_SEARCH = %w(/run/postgresql/.s.PGSQL.5432 /tmp/.s.PGSQL.5432 /var/run/postgresql/.s.PGSQL.5432) + SOCKET_SEARCH = %w(/run/postgresql /tmp /var/run/postgresql) SUPPORTED_AUTH_METHODS = %w[cleartext md5 scram-sha-256 scram-sha-256-plus] @@ -42,11 +42,11 @@ module PQ # Create a new ConnInfo from all parts def initialize(host : String? = nil, database : String? = nil, user : String? = nil, password : String? = nil, port : Int | String? = nil, sslmode : String | Symbol? = nil, application_name : String? = nil) - @host = default_host host + @port = (port || ENV.fetch("PGPORT", "5432")).to_i + @host = default_host(host, @port) db = default_database database @database = db.lchop('/') @user = default_user user - @port = (port || ENV.fetch("PGPORT", "5432")).to_i @sslmode = default_sslmode sslmode @password = password || ENV.fetch("PGPASSWORD", PgPass.locate(@host, @port, @database, @user)) @application_name = default_application_name application_name @@ -77,7 +77,8 @@ module PQ def initialize(uri : URI) params = URI::Params.parse(uri.query.to_s) hostname = uri.hostname.presence || params.fetch("host", "") - initialize(hostname, uri.path, uri.user, uri.password, uri.port, :prefer, params.fetch("application_name", nil)) + port = uri.port || params["port"]? + initialize(hostname, uri.path, uri.user, uri.password, port, :prefer, params.fetch("application_name", nil)) if q = uri.query HTTP::Params.parse(q) do |key, value| handle_sslparam(key, value) @@ -121,18 +122,19 @@ module PQ end end - private def default_host(h) - return h if h && !h.blank? - - if pghost = ENV["PGHOST"]? - return pghost[0] == '/' ? "#{pghost}/.s.PGSQL.5432" : pghost - end + private def default_host(h, port) + socket_name = ".s.PGSQL.#{port}" - SOCKET_SEARCH.each do |s| - return s if File.exists?(s) + if host = h.presence || ENV["PGHOST"]? + host = Path.new(host) + # For backwards compatibility: + # Check if the path is pointing to the socket file and replace it with + # the directory that contains the socket instead + host = host.basename == socket_name ? host.parent : host + return host.to_s end - "localhost" + SOCKET_SEARCH.find { |s| File.exists?(File.join(s, socket_name)) } || "localhost" end private def default_database(db)