From c08eb35abc90efa30d54066ebe52e6993789db6f Mon Sep 17 00:00:00 2001 From: Remy Marronnier Date: Fri, 20 Jun 2025 18:44:18 +0200 Subject: [PATCH] New features --- spec/pg/cursor_spec.cr | 162 +++++++++++++++++++++++ spec/pg/network_spec.cr | 84 ++++++++++++ spec/pg/range_spec.cr | 123 +++++++++++++++++ src/pg.cr | 1 + src/pg/connection.cr | 20 +++ src/pg/cursor.cr | 128 ++++++++++++++++++ src/pg/decoder.cr | 4 +- src/pg/decoders/network_decoder.cr | 203 +++++++++++++++++++++++++++++ src/pg/decoders/range_decoder.cr | 153 ++++++++++++++++++++++ src/pg/network.cr | 99 ++++++++++++++ src/pg/numeric.cr | 11 ++ src/pg/range.cr | 72 ++++++++++ src/pq/connection.cr | 6 + src/pq/param.cr | 22 ++++ 14 files changed, 1087 insertions(+), 1 deletion(-) create mode 100644 spec/pg/cursor_spec.cr create mode 100644 spec/pg/network_spec.cr create mode 100644 spec/pg/range_spec.cr create mode 100644 src/pg/cursor.cr create mode 100644 src/pg/decoders/network_decoder.cr create mode 100644 src/pg/decoders/range_decoder.cr create mode 100644 src/pg/network.cr create mode 100644 src/pg/range.cr diff --git a/spec/pg/cursor_spec.cr b/spec/pg/cursor_spec.cr new file mode 100644 index 00000000..ced7f750 --- /dev/null +++ b/spec/pg/cursor_spec.cr @@ -0,0 +1,162 @@ +require "../spec_helper" + +describe PG::Cursor do + it "works with simple query" do + with_connection do |conn| + conn.exec("create table cursor_test (id serial, name text)") + conn.exec("insert into cursor_test (name) values ('Alice'), ('Bob'), ('Charlie'), ('David'), ('Eve')") + + conn.transaction do + cursor = conn.cursor("select * from cursor_test order by id") + + # Fetch first batch + rows = [] of {Int32, String} + cursor.fetch(2) do |rs| + rows << {rs.read(Int32), rs.read(String)} + while rs.move_next + rows << {rs.read(Int32), rs.read(String)} + end + end + + rows.size.should eq(2) + rows[0].should eq({1, "Alice"}) + rows[1].should eq({2, "Bob"}) + + # Fetch next batch + rows.clear + cursor.fetch(2) do |rs| + rows << {rs.read(Int32), rs.read(String)} + while rs.move_next + rows << {rs.read(Int32), rs.read(String)} + end + end + + rows.size.should eq(2) + rows[0].should eq({3, "Charlie"}) + rows[1].should eq({4, "David"}) + + cursor.close + end + ensure + with_connection &.exec("drop table if exists cursor_test") + end + end + + it "works with block syntax" do + with_connection do |conn| + conn.exec("create table cursor_test2 (value int)") + conn.exec("insert into cursor_test2 select generate_series(1, 10)") + + values = [] of Int32 + conn.transaction do + conn.cursor("select value from cursor_test2 order by value") do |cursor| + cursor.fetch_all do |rs| + values << rs.read(Int32) + end + end + end + + values.should eq((1..10).to_a) + ensure + with_connection &.exec("drop table if exists cursor_test2") + end + end + + it "supports cursor movement" do + with_connection do |conn| + conn.exec("create table cursor_test3 (id int)") + conn.exec("insert into cursor_test3 select generate_series(1, 5)") + + conn.transaction do + cursor = conn.cursor("select id from cursor_test3 order by id") + + # Move forward + cursor.move(2) + cursor.fetch(1) do |rs| + rs.read(Int32).should eq(3) + end + + # Move backward + cursor.move(-1) + cursor.fetch(1) do |rs| + rs.read(Int32).should eq(3) + end + + # Move to first + cursor.move_first + cursor.fetch(1) do |rs| + rs.read(Int32).should eq(2) # After MOVE ABSOLUTE 1, FETCH gets the next row + end + + # Move to last + cursor.move_last + cursor.fetch(1) do |rs| + rs.read(Int32).should eq(5) + end + + cursor.close + end + ensure + with_connection &.exec("drop table if exists cursor_test3") + end + end + + it "works with parameterized queries" do + with_connection do |conn| + conn.exec("create table cursor_test4 (id int, category text)") + conn.exec("insert into cursor_test4 values (1, 'A'), (2, 'B'), (3, 'A'), (4, 'B'), (5, 'A')") + + ids = [] of Int32 + conn.transaction do + conn.cursor("select id from cursor_test4 where category = $1 order by id", "A") do |cursor| + cursor.fetch_all do |rs| + ids << rs.read(Int32) + end + end + end + + ids.should eq([1, 3, 5]) + ensure + with_connection &.exec("drop table if exists cursor_test4") + end + end + + it "automatically starts transaction if needed" do + with_connection do |conn| + conn.exec("create table cursor_test5 (id int)") + conn.exec("insert into cursor_test5 values (42)") + + # Should start a transaction automatically + cursor = conn.cursor("select id from cursor_test5") + cursor.fetch(1) do |rs| + rs.read(Int32).should eq(42) + end + cursor.close + + # Verify we're still in a transaction + conn.in_transaction?.should be_true + + # End the transaction + conn.exec("COMMIT") + ensure + with_connection &.exec("drop table if exists cursor_test5") + end + end + + it "raises on closed cursor operations" do + with_connection do |conn| + conn.transaction do + cursor = conn.cursor("select 1") + cursor.close + + expect_raises(DB::Error, "Cursor is closed") do + cursor.fetch(1) { } + end + + expect_raises(DB::Error, "Cursor is closed") do + cursor.move(1) + end + end + end + end +end diff --git a/spec/pg/network_spec.cr b/spec/pg/network_spec.cr new file mode 100644 index 00000000..51cbeac3 --- /dev/null +++ b/spec/pg/network_spec.cr @@ -0,0 +1,84 @@ +require "../spec_helper" +require "../../src/pg/network" + +private def assert_roundtrip(x, sql_type, as cls) + PG_DB.exec("create table test (a #{sql_type})") + PG_DB.exec("insert into test values ($1)", x) + value = PG_DB.query_one("select a from test", as: cls) + value.should eq(x) +ensure + PG_DB.exec("drop table test") rescue nil +end + +describe "PG::Network" do + describe "Inet" do + it "decodes inet ipv4" do + assert_roundtrip(PG::Network::Inet.new("192.168.1.1"), "inet", PG::Network::Inet) + end + + it "decodes inet ipv4 with netmask" do + assert_roundtrip(PG::Network::Inet.new("192.168.1.0", 24), "inet", PG::Network::Inet) + end + + it "decodes inet ipv6" do + assert_roundtrip(PG::Network::Inet.new("2001:db8::1"), "inet", PG::Network::Inet) + end + + it "decodes inet ipv6 with netmask" do + assert_roundtrip(PG::Network::Inet.new("2001:db8::", 64), "inet", PG::Network::Inet) + end + + it "roundtrips with to_s" do + inet = PG::Network::Inet.new("192.168.1.1", 24) + inet.to_s.should eq("192.168.1.1/24") + + inet2 = PG::Network::Inet.new("192.168.1.1", 32) + inet2.to_s.should eq("192.168.1.1") + end + end + + describe "Cidr" do + it "decodes cidr ipv4" do + assert_roundtrip(PG::Network::Cidr.new("192.168.0.0", 24), "cidr", PG::Network::Cidr) + end + + it "decodes cidr ipv6" do + assert_roundtrip(PG::Network::Cidr.new("2001:db8::", 32), "cidr", PG::Network::Cidr) + end + + it "roundtrips with to_s" do + cidr = PG::Network::Cidr.new("10.0.0.0", 8) + cidr.to_s.should eq("10.0.0.0/8") + end + end + + describe "MacAddr" do + it "decodes macaddr" do + mac = PG::Network::MacAddr.new("08:00:2b:01:02:03") + assert_roundtrip(mac, "macaddr", PG::Network::MacAddr) + end + + it "accepts different formats" do + mac1 = PG::Network::MacAddr.new("08:00:2b:01:02:03") + mac2 = PG::Network::MacAddr.new("08-00-2b-01-02-03") + + mac1.to_s.should eq("08:00:2b:01:02:03") + mac2.to_s.should eq("08:00:2b:01:02:03") + end + end + + describe "MacAddr8" do + it "decodes macaddr8" do + mac = PG::Network::MacAddr8.new("08:00:2b:01:02:03:04:05") + assert_roundtrip(mac, "macaddr8", PG::Network::MacAddr8) + end + + it "accepts different formats" do + mac1 = PG::Network::MacAddr8.new("08:00:2b:01:02:03:04:05") + mac2 = PG::Network::MacAddr8.new("08-00-2b-01-02-03-04-05") + + mac1.to_s.should eq("08:00:2b:01:02:03:04:05") + mac2.to_s.should eq("08:00:2b:01:02:03:04:05") + end + end +end diff --git a/spec/pg/range_spec.cr b/spec/pg/range_spec.cr new file mode 100644 index 00000000..3f2052d0 --- /dev/null +++ b/spec/pg/range_spec.cr @@ -0,0 +1,123 @@ +require "../spec_helper" +require "../../src/pg/range" + +private def assert_roundtrip(x, sql_type, as cls) + PG_DB.exec("create table test (a #{sql_type})") + PG_DB.exec("insert into test values ($1)", x) + value = PG_DB.query_one("select a from test", as: cls) + value.should eq(x) +ensure + PG_DB.exec("drop table test") rescue nil +end + +describe "PG::Range" do + describe "Int4Range" do + it "decodes int4range with inclusive bounds" do + # PostgreSQL canonicalizes [1,10] to [1,11) + range = PG::Range(Int32).new(1, 11, true, false) + assert_roundtrip(range, "int4range", PG::Int4Range) + end + + it "decodes int4range with exclusive bounds" do + # PostgreSQL canonicalizes (1,10) to [2,10) + range = PG::Range(Int32).new(2, 10, true, false) + assert_roundtrip(range, "int4range", PG::Int4Range) + end + + it "decodes empty int4range" do + range = PG::Range(Int32).empty + assert_roundtrip(range, "int4range", PG::Int4Range) + end + + it "decodes int4range with infinite bounds" do + # PostgreSQL canonicalizes [,10) to (,10) + range = PG::Range(Int32).new(nil, 10, false, false) + assert_roundtrip(range, "int4range", PG::Int4Range) + + range2 = PG::Range(Int32).new(1, nil, true, false) + assert_roundtrip(range2, "int4range", PG::Int4Range) + end + end + + describe "Int8Range" do + it "decodes int8range" do + range = PG::Range(Int64).new(1_i64, 1000000_i64, true, false) + assert_roundtrip(range, "int8range", PG::Int8Range) + end + end + + describe "NumRange" do + it "decodes numrange" do + # Test with simpler numerics that PostgreSQL won't change + PG_DB.exec("create table test (a numrange)") + PG_DB.exec("insert into test values (numrange(1.0, 10.0))") + value = PG_DB.query_one("select a from test", as: PG::NumRange) + + value.lower.not_nil!.to_s.should eq("1.0") + value.upper.not_nil!.to_s.should eq("10.0") + value.lower_inclusive.should be_true + value.upper_inclusive.should be_false + value.empty?.should be_false + ensure + PG_DB.exec("drop table test") rescue nil + end + end + + describe "TsRange" do + it "decodes tsrange" do + t1 = Time.utc(2020, 1, 1, 12, 0, 0) + t2 = Time.utc(2020, 12, 31, 23, 59, 59) + range = PG::Range(Time).new(t1, t2, true, true) + assert_roundtrip(range, "tsrange", PG::TsRange) + end + end + + describe "TsTzRange" do + it "decodes tstzrange" do + t1 = Time.utc(2020, 1, 1, 12, 0, 0) + t2 = Time.utc(2020, 12, 31, 23, 59, 59) + range = PG::Range(Time).new(t1, t2, true, false) + assert_roundtrip(range, "tstzrange", PG::TsTzRange) + end + end + + describe "DateRange" do + it "decodes daterange" do + d1 = Time.utc(2020, 1, 1) + # PostgreSQL canonicalizes [2020-01-01,2020-12-31] to [2020-01-01,2021-01-01) + d2 = Time.utc(2021, 1, 1) + range = PG::Range(Time).new(d1, d2, true, false) + assert_roundtrip(range, "daterange", PG::DateRange) + end + end + + describe "to_s" do + it "formats inclusive range" do + range = PG::Range(Int32).new(1, 10, true, true) + range.to_s.should eq("[1,10]") + end + + it "formats exclusive range" do + range = PG::Range(Int32).new(1, 10, false, false) + range.to_s.should eq("(1,10)") + end + + it "formats mixed bounds" do + range = PG::Range(Int32).new(1, 10, true, false) + range.to_s.should eq("[1,10)") + end + + it "formats empty range" do + range = PG::Range(Int32).empty + range.to_s.should eq("empty") + end + + it "formats infinite ranges" do + range = PG::Range(Int32).new(nil, 10, true, false) + range.to_s.should eq("[,10)") + + range2 = PG::Range(Int32).new(1, nil, true, false) + range2.to_s.should eq("[1,)") + end + end +end diff --git a/src/pg.cr b/src/pg.cr index 112b5d59..a60be319 100644 --- a/src/pg.cr +++ b/src/pg.cr @@ -1,5 +1,6 @@ require "db" require "./pg/*" +require "./pg/cursor" module PG # Establish a connection to the database diff --git a/src/pg/connection.cr b/src/pg/connection.cr index 1f22202a..dbb1faf6 100644 --- a/src/pg/connection.cr +++ b/src/pg/connection.cr @@ -100,6 +100,26 @@ module PG {major: vers[0], minor: vers[1], patch: vers[2]? || 0} end + # Create a cursor for the given query + def cursor(query : String, *args, name : String? = nil) : Cursor + Cursor.new(self, query, args.to_a, name) + end + + # Create a cursor with block syntax - cursor is automatically closed + def cursor(query : String, *args, name : String? = nil, &) + cur = cursor(query, *args, name: name) + begin + yield cur + ensure + cur.close + end + end + + # Check if currently in a transaction + def in_transaction? + @connection.transaction_status != PQ::Frame::ReadyForQuery::Status::Idle + end + protected def do_close super diff --git a/src/pg/cursor.cr b/src/pg/cursor.cr new file mode 100644 index 00000000..65867014 --- /dev/null +++ b/src/pg/cursor.cr @@ -0,0 +1,128 @@ +module PG + # Represents a PostgreSQL cursor for fetching query results in batches + class Cursor + getter connection : Connection + getter query : String + getter args : Array(PQ::Param) + getter name : String + + @closed = false + @result_set : ResultSet? + + def initialize(@connection : Connection, @query : String, args = [] of DB::Any, name : String? = nil) + @name = name || "cursor_#{Time.utc.to_unix_ns}" + @args = args.map { |arg| PQ::Param.encode(arg).as(PQ::Param) } + + # Start a transaction if not already in one + unless connection.in_transaction? + connection.exec("BEGIN") + end + + # Declare the cursor + declare_cursor + end + + # Fetch a batch of rows from the cursor + def fetch(count : Int32 = 100, &) + check_closed + + result = connection.query("FETCH #{count} FROM #{escape_identifier(@name)}") + if result.move_next + yield result + end + result + ensure + result.try &.close + end + + # Fetch all remaining rows + def fetch_all(&) + check_closed + + connection.query("FETCH ALL FROM #{escape_identifier(@name)}") do |rs| + while rs.move_next + yield rs + end + end + end + + # Move the cursor position + def move(offset : Int32) + check_closed + + direction = offset >= 0 ? "FORWARD" : "BACKWARD" + connection.exec("MOVE #{direction} #{offset.abs} FROM #{escape_identifier(@name)}") + end + + # Move to absolute position + def move_absolute(position : Int32) + check_closed + + connection.exec("MOVE ABSOLUTE #{position} FROM #{escape_identifier(@name)}") + end + + # Move to first row + def move_first + move_absolute(1) + end + + # Move to last row + def move_last + check_closed + + connection.exec("MOVE LAST FROM #{escape_identifier(@name)}") + end + + # Close the cursor + def close + return if @closed + + begin + connection.exec("CLOSE #{escape_identifier(@name)}") + ensure + @closed = true + end + end + + # Check if cursor is closed + def closed? + @closed + end + + private def declare_cursor + if @args.empty? + connection.exec("DECLARE #{escape_identifier(@name)} CURSOR FOR #{@query}") + else + # For parameterized queries, we need to bind the values directly in the query + # because DECLARE CURSOR doesn't support parameter binding + # This is a limitation of PostgreSQL cursors + formatted_query = @query.dup + @args.each_with_index do |arg, i| + # Replace $1, $2, etc. with escaped literal values + value = if arg.size == -1 + "NULL" + else + connection.escape_literal(String.new(arg.slice)) + end + formatted_query = formatted_query.gsub("$#{i + 1}", value) + end + + connection.exec("DECLARE #{escape_identifier(@name)} CURSOR FOR #{formatted_query}") + end + end + + private def check_closed + raise DB::Error.new("Cursor is closed") if @closed + end + + private def escape_identifier(name : String) + connection.escape_identifier(name) + end + + def finalize + close unless @closed + rescue + # Ignore errors during finalization + end + end +end diff --git a/src/pg/decoder.cr b/src/pg/decoder.cr index 990d8fc7..0274ca2d 100644 --- a/src/pg/decoder.cr +++ b/src/pg/decoder.cr @@ -1,8 +1,10 @@ require "json" require "uuid" +require "./network" +require "./range" module PG - alias PGValue = String | Nil | Bool | Int32 | Float32 | Float64 | Time | JSON::Any | PG::Numeric | UUID + alias PGValue = String | Nil | Bool | Int32 | Float32 | Float64 | Time | JSON::Any | PG::Numeric | UUID | PG::Network::Inet | PG::Network::Cidr | PG::Network::MacAddr | PG::Network::MacAddr8 | PG::Int4Range | PG::Int8Range | PG::NumRange | PG::TsRange | PG::TsTzRange | PG::DateRange # :nodoc: module Decoders diff --git a/src/pg/decoders/network_decoder.cr b/src/pg/decoders/network_decoder.cr new file mode 100644 index 00000000..ab355e96 --- /dev/null +++ b/src/pg/decoders/network_decoder.cr @@ -0,0 +1,203 @@ +require "../network" + +module PG + module Decoders + # Helper method to format IPv6 addresses with proper compression + def self.format_ipv6(words : StaticArray(UInt16, 8)) : String + # Convert to array of hex strings + hex_parts = words.to_a.map { |w| w.to_s(16) } + + # Find longest sequence of zeros for compression + max_start = -1 + max_len = 0 + current_start = -1 + + 8.times do |i| + if words[i] == 0 + current_start = i if current_start == -1 + else + if current_start != -1 + len = i - current_start + if len > max_len && len > 1 + max_start = current_start + max_len = len + end + current_start = -1 + end + end + end + + # Check if zeros extend to the end + if current_start != -1 + len = 8 - current_start + if len > max_len && len > 1 + max_start = current_start + max_len = len + end + end + + # Build the compressed string + if max_start >= 0 + parts = [] of String + + # Add parts before compression + if max_start > 0 + parts.concat(hex_parts[0...max_start]) + end + + # Add empty string for :: + parts << "" + + # Add parts after compression + end_pos = max_start + max_len + if end_pos < 8 + parts << "" if max_start == 0 # Leading :: + parts.concat(hex_parts[end_pos...8]) + else + parts << "" # Trailing :: + end + + parts.join(":") + else + hex_parts.join(":") + end + end + + struct InetDecoder + include Decoder + + def_oids [ + 869, # inet + ] + + def decode(io, bytesize, oid) + family = io.read_byte.not_nil! + netmask = io.read_byte.not_nil! + is_cidr = io.read_byte.not_nil! == 1 + nb = io.read_byte.not_nil! + + if family == 2 # AF_INET (IPv4) + bytes = uninitialized UInt8[4] + io.read_fully(bytes.to_slice[0, nb]) + address = bytes.map(&.to_s).join('.') + elsif family == 3 # AF_INET6 (IPv6) + words = uninitialized UInt16[8] + (nb // 2).times do |i| + words[i] = read_u16(io) + end + ((nb // 2)...8).each do |i| + words[i] = 0_u16 + end + address = Decoders.format_ipv6(words) + else + raise "Unknown inet family: #{family}" + end + + PG::Network::Inet.new(address, netmask) + end + + def type + PG::Network::Inet + end + + private def read_u16(io) + io.read_bytes(UInt16, IO::ByteFormat::NetworkEndian) + end + end + + struct CidrDecoder + include Decoder + + def_oids [ + 650, # cidr + ] + + def decode(io, bytesize, oid) + family = io.read_byte.not_nil! + netmask = io.read_byte.not_nil! + is_cidr = io.read_byte.not_nil! == 1 + nb = io.read_byte.not_nil! + + if family == 2 # AF_INET (IPv4) + bytes = uninitialized UInt8[4] + io.read_fully(bytes.to_slice[0, nb]) + # For CIDR, normalize the address based on netmask + mask_bytes = (netmask / 8).to_i + mask_bits = netmask % 8 + + if mask_bytes < 4 + (mask_bytes...4).each { |i| bytes[i] = 0 } + end + if mask_bits > 0 && mask_bytes < 4 + bytes[mask_bytes] = (bytes[mask_bytes] & (0xFF_u8 << (8 - mask_bits))) + end + + address = bytes.map(&.to_s).join('.') + elsif family == 3 # AF_INET6 (IPv6) + words = uninitialized UInt16[8] + (nb // 2).times do |i| + words[i] = read_u16(io) + end + ((nb // 2)...8).each do |i| + words[i] = 0_u16 + end + address = Decoders.format_ipv6(words) + else + raise "Unknown cidr family: #{family}" + end + + PG::Network::Cidr.new(address, netmask) + end + + def type + PG::Network::Cidr + end + + private def read_u16(io) + io.read_bytes(UInt16, IO::ByteFormat::NetworkEndian) + end + end + + struct MacAddrDecoder + include Decoder + + def_oids [ + 829, # macaddr + ] + + def decode(io, bytesize, oid) + bytes = uninitialized UInt8[6] + io.read_fully(bytes.to_slice) + PG::Network::MacAddr.new(bytes) + end + + def type + PG::Network::MacAddr + end + end + + struct MacAddr8Decoder + include Decoder + + def_oids [ + 774, # macaddr8 + ] + + def decode(io, bytesize, oid) + bytes = uninitialized UInt8[8] + io.read_fully(bytes.to_slice) + PG::Network::MacAddr8.new(bytes) + end + + def type + PG::Network::MacAddr8 + end + end + + # Register the network decoders + register_decoder InetDecoder.new + register_decoder CidrDecoder.new + register_decoder MacAddrDecoder.new + register_decoder MacAddr8Decoder.new + end +end diff --git a/src/pg/decoders/range_decoder.cr b/src/pg/decoders/range_decoder.cr new file mode 100644 index 00000000..fed0b8d7 --- /dev/null +++ b/src/pg/decoders/range_decoder.cr @@ -0,0 +1,153 @@ +require "../range" + +module PG + module Decoders + abstract struct RangeDecoder(T, D) + include Decoder + + abstract def element_decoder : D + + def decode(io, bytesize, oid) + flags = io.read_byte.not_nil! + + # Check for empty range + if flags & 0x01 != 0 + return Range(T).empty + end + + # Read bounds + lower_infinite = (flags & 0x08) != 0 + upper_infinite = (flags & 0x10) != 0 + lower_inclusive = (flags & 0x02) != 0 + upper_inclusive = (flags & 0x04) != 0 + + # Read lower bound + lower = if lower_infinite + nil + else + read_element(io) + end + + # Read upper bound + upper = if upper_infinite + nil + else + read_element(io) + end + + Range(T).new(lower, upper, lower_inclusive, upper_inclusive) + end + + private def read_element(io) + element_size = read_i32(io) + if element_size == -1 + nil + else + element_decoder.decode(io, element_size, 0) + end + end + + def type + Range(T) + end + end + + struct Int4RangeDecoder < RangeDecoder(Int32, Int32Decoder) + def_oids [ + 3904, # int4range + ] + + def element_decoder : Int32Decoder + Int32Decoder.new + end + end + + struct Int8RangeDecoder < RangeDecoder(Int64, Int64Decoder) + def_oids [ + 3926, # int8range + ] + + def element_decoder : Int64Decoder + Int64Decoder.new + end + end + + struct NumRangeDecoder < RangeDecoder(Numeric, NumericDecoder) + def_oids [ + 3906, # numrange + ] + + def element_decoder : NumericDecoder + NumericDecoder.new + end + end + + struct TsRangeDecoder < RangeDecoder(Time, TimeDecoder) + def_oids [ + 3908, # tsrange + ] + + def element_decoder : TimeDecoder + TimeDecoder.new + end + + private def read_element(io) + element_size = read_i32(io) + if element_size == -1 + nil + else + # tsrange uses timestamp (not timestamptz) + element_decoder.decode(io, element_size, 1114) + end + end + end + + struct TsTzRangeDecoder < RangeDecoder(Time, TimeDecoder) + def_oids [ + 3910, # tstzrange + ] + + def element_decoder : TimeDecoder + TimeDecoder.new + end + + private def read_element(io) + element_size = read_i32(io) + if element_size == -1 + nil + else + # tstzrange uses timestamptz + element_decoder.decode(io, element_size, 1184) + end + end + end + + struct DateRangeDecoder < RangeDecoder(Time, TimeDecoder) + def_oids [ + 3912, # daterange + ] + + def element_decoder : TimeDecoder + TimeDecoder.new + end + + private def read_element(io) + element_size = read_i32(io) + if element_size == -1 + nil + else + # daterange uses date + element_decoder.decode(io, element_size, 1082) + end + end + end + + # Register the range decoders + register_decoder Int4RangeDecoder.new + register_decoder Int8RangeDecoder.new + register_decoder NumRangeDecoder.new + register_decoder TsRangeDecoder.new + register_decoder TsTzRangeDecoder.new + register_decoder DateRangeDecoder.new + end +end diff --git a/src/pg/network.cr b/src/pg/network.cr new file mode 100644 index 00000000..2a87c175 --- /dev/null +++ b/src/pg/network.cr @@ -0,0 +1,99 @@ +module PG::Network + # Represents PostgreSQL inet type - IP address with optional netmask + struct Inet + getter address : String + getter netmask : UInt8 + + def initialize(@address : String, @netmask : UInt8 = 32) + if @address.includes?(':') + @netmask = 128 if @netmask == 32 + end + end + + def to_s(io) + io << @address + if (@address.includes?(':') && @netmask != 128) || (!@address.includes?(':') && @netmask != 32) + io << '/' << @netmask + end + end + + def ipv4? + @address.includes?('.') + end + + def ipv6? + @address.includes?(':') + end + end + + # Represents PostgreSQL cidr type - IP network + struct Cidr + getter address : String + getter netmask : UInt8 + + def initialize(@address : String, @netmask : UInt8) + end + + def to_s(io) + io << @address << '/' << @netmask + end + + def ipv4? + @address.includes?('.') + end + + def ipv6? + @address.includes?(':') + end + end + + # Represents PostgreSQL macaddr type - 6-byte MAC address + struct MacAddr + getter bytes : StaticArray(UInt8, 6) + + def initialize(@bytes : StaticArray(UInt8, 6)) + end + + def initialize(str : String) + parts = str.split(/[:-]/) + raise ArgumentError.new("Invalid MAC address format") unless parts.size == 6 + + bytes = StaticArray(UInt8, 6).new do |i| + parts[i].to_u8(16) + end + @bytes = bytes + end + + def to_s(io) + @bytes.each_with_index do |byte, i| + io << ':' if i > 0 + byte.to_s(io, base: 16, precision: 2) + end + end + end + + # Represents PostgreSQL macaddr8 type - 8-byte MAC address + struct MacAddr8 + getter bytes : StaticArray(UInt8, 8) + + def initialize(@bytes : StaticArray(UInt8, 8)) + end + + def initialize(str : String) + parts = str.split(/[:-]/) + raise ArgumentError.new("Invalid MAC address format") unless parts.size == 8 + + bytes = StaticArray(UInt8, 8).new do |i| + parts[i].to_u8(16) + end + @bytes = bytes + end + + def to_s(io) + @bytes.each_with_index do |byte, i| + io << ':' if i > 0 + byte.to_s(io, base: 16, precision: 2) + end + end + end +end diff --git a/src/pg/numeric.cr b/src/pg/numeric.cr index 1e4f4c6b..023e52fc 100644 --- a/src/pg/numeric.cr +++ b/src/pg/numeric.cr @@ -135,5 +135,16 @@ module PG (dscale - count).times { io << '0' } end + + def ==(other : Numeric) + return true if nan? && other.nan? + return false if nan? || other.nan? + + sign == other.sign && + ndigits == other.ndigits && + weight == other.weight && + dscale == other.dscale && + digits == other.digits + end end end diff --git a/src/pg/range.cr b/src/pg/range.cr new file mode 100644 index 00000000..3b8e06b6 --- /dev/null +++ b/src/pg/range.cr @@ -0,0 +1,72 @@ +module PG + # Represents a PostgreSQL range type + # Ranges can be bounded or unbounded, inclusive or exclusive + struct Range(T) + getter lower : T? + getter upper : T? + getter lower_inclusive : Bool + getter upper_inclusive : Bool + getter? empty : Bool + + def initialize(@lower : T?, @upper : T?, @lower_inclusive : Bool = true, @upper_inclusive : Bool = false, @empty : Bool = false) + end + + def self.empty + new(nil, nil, true, false, true) + end + + def lower_bound + if @lower_inclusive + '[' + else + '(' + end + end + + def upper_bound + if @upper_inclusive + ']' + else + ')' + end + end + + def infinite_lower? + @lower.nil? && !@empty + end + + def infinite_upper? + @upper.nil? && !@empty + end + + def to_s(io) + if @empty + io << "empty" + else + io << lower_bound + io << @lower if @lower + io << ',' + io << @upper if @upper + io << upper_bound + end + end + + def ==(other : Range) + return true if @empty && other.empty? + return false if @empty != other.empty? + + @lower == other.lower && + @upper == other.upper && + @lower_inclusive == other.lower_inclusive && + @upper_inclusive == other.upper_inclusive + end + end + + # Type aliases for specific PostgreSQL range types + alias Int4Range = Range(Int32) + alias Int8Range = Range(Int64) + alias NumRange = Range(Numeric) + alias TsRange = Range(Time) + alias TsTzRange = Range(Time) + alias DateRange = Range(Time) +end diff --git a/src/pq/connection.cr b/src/pq/connection.cr index c226ddd4..7556a4c0 100644 --- a/src/pq/connection.cr +++ b/src/pq/connection.cr @@ -19,6 +19,9 @@ module PQ property notification_handler = Proc(Notification, Void).new { } @mutex = Mutex.new @established = false + @transaction_status = Frame::ReadyForQuery::Status::Idle + + getter transaction_status def initialize(@conninfo : ConnInfo) begin @@ -174,6 +177,9 @@ module PQ def read(frame_type) frame = read_one_frame(frame_type) + if frame.is_a?(Frame::ReadyForQuery) + @transaction_status = frame.transaction_status + end handle_async_frames(frame) ? read : frame end diff --git a/src/pq/param.cr b/src/pq/param.cr index 5232d817..f07e7e0a 100644 --- a/src/pq/param.cr +++ b/src/pq/param.cr @@ -1,4 +1,6 @@ require "../pg/geo" +require "../pg/network" +require "../pg/range" module PQ # :nodoc: @@ -77,6 +79,26 @@ module PQ text "#{val.months} months #{val.days} days #{val.microseconds} microseconds" end + def self.encode(val : PG::Network::Inet) + text val.to_s + end + + def self.encode(val : PG::Network::Cidr) + text val.to_s + end + + def self.encode(val : PG::Network::MacAddr) + text val.to_s + end + + def self.encode(val : PG::Network::MacAddr8) + text val.to_s + end + + def self.encode(val : PG::Range) + text val.to_s + end + def self.encode(val) text val.to_s end