From fc45d09177d4d12f95ffbe800fa5202143e6dcc0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 18 Apr 2026 19:34:59 +0000 Subject: [PATCH] build(deps): bump github.com/jackc/pgx/v5 Bumps the go_modules group with 1 update in the / directory: [github.com/jackc/pgx/v5](https://github.com/jackc/pgx). Updates `github.com/jackc/pgx/v5` from 5.7.6 to 5.9.0 - [Changelog](https://github.com/jackc/pgx/blob/master/CHANGELOG.md) - [Commits](https://github.com/jackc/pgx/compare/v5.7.6...v5.9.0) --- updated-dependencies: - dependency-name: github.com/jackc/pgx/v5 dependency-version: 5.9.0 dependency-type: direct:production dependency-group: go_modules ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 4 +- vendor/github.com/jackc/pgx/v5/.golangci.yml | 9 +- vendor/github.com/jackc/pgx/v5/CHANGELOG.md | 68 ++ vendor/github.com/jackc/pgx/v5/CLAUDE.md | 73 ++ .../github.com/jackc/pgx/v5/CONTRIBUTING.md | 26 +- vendor/github.com/jackc/pgx/v5/README.md | 5 +- vendor/github.com/jackc/pgx/v5/batch.go | 71 +- vendor/github.com/jackc/pgx/v5/conn.go | 54 +- .../github.com/jackc/pgx/v5/derived_types.go | 6 +- .../pgx/v5/internal/iobufpool/iobufpool.go | 38 +- .../jackc/pgx/v5/internal/pgio/write.go | 22 +- .../sanitize/{benchmmark.sh => benchmark.sh} | 2 +- .../pgx/v5/internal/stmtcache/lru_cache.go | 150 +++- .../v5/internal/stmtcache/unlimited_cache.go | 77 -- .../jackc/pgx/v5/pgconn/auth_oauth.go | 67 ++ .../jackc/pgx/v5/pgconn/auth_scram.go | 169 +++- .../github.com/jackc/pgx/v5/pgconn/config.go | 96 ++- .../pgx/v5/pgconn/ctxwatch/context_watcher.go | 46 +- .../github.com/jackc/pgx/v5/pgconn/errors.go | 17 + .../github.com/jackc/pgx/v5/pgconn/pgconn.go | 748 ++++++++++++++---- .../pgx/v5/pgproto3/authentication_sasl.go | 1 + .../jackc/pgx/v5/pgproto3/backend.go | 10 +- .../jackc/pgx/v5/pgproto3/backend_key_data.go | 33 +- .../github.com/jackc/pgx/v5/pgproto3/bind.go | 8 +- .../jackc/pgx/v5/pgproto3/cancel_request.go | 45 +- .../pgx/v5/pgproto3/copy_both_response.go | 2 +- .../jackc/pgx/v5/pgproto3/copy_fail.go | 4 + .../jackc/pgx/v5/pgproto3/copy_in_response.go | 2 +- .../pgx/v5/pgproto3/copy_out_response.go | 2 +- .../jackc/pgx/v5/pgproto3/data_row.go | 7 +- .../jackc/pgx/v5/pgproto3/frontend.go | 7 +- .../jackc/pgx/v5/pgproto3/function_call.go | 27 +- .../pgx/v5/pgproto3/function_call_response.go | 4 +- .../v5/pgproto3/negotiate_protocol_version.go | 93 +++ .../pgx/v5/pgproto3/parameter_description.go | 2 +- .../github.com/jackc/pgx/v5/pgproto3/parse.go | 2 +- .../github.com/jackc/pgx/v5/pgproto3/query.go | 4 + .../jackc/pgx/v5/pgproto3/row_description.go | 2 +- .../pgx/v5/pgproto3/sasl_initial_response.go | 3 + .../jackc/pgx/v5/pgproto3/startup_message.go | 10 +- .../github.com/jackc/pgx/v5/pgproto3/trace.go | 4 +- .../github.com/jackc/pgx/v5/pgtype/array.go | 12 +- .../jackc/pgx/v5/pgtype/array_codec.go | 12 +- .../jackc/pgx/v5/pgtype/builtin_wrappers.go | 2 +- .../jackc/pgx/v5/pgtype/composite.go | 6 +- .../github.com/jackc/pgx/v5/pgtype/convert.go | 28 +- vendor/github.com/jackc/pgx/v5/pgtype/date.go | 118 ++- .../github.com/jackc/pgx/v5/pgtype/hstore.go | 9 +- vendor/github.com/jackc/pgx/v5/pgtype/int.go | 4 +- .../jackc/pgx/v5/pgtype/interval.go | 4 +- vendor/github.com/jackc/pgx/v5/pgtype/json.go | 2 +- .../jackc/pgx/v5/pgtype/multirange.go | 11 +- .../github.com/jackc/pgx/v5/pgtype/numeric.go | 49 +- vendor/github.com/jackc/pgx/v5/pgtype/path.go | 2 +- .../github.com/jackc/pgx/v5/pgtype/pgtype.go | 84 +- .../jackc/pgx/v5/pgtype/pgtype_default.go | 3 + .../github.com/jackc/pgx/v5/pgtype/polygon.go | 2 +- .../jackc/pgx/v5/pgtype/timestamp.go | 10 +- .../jackc/pgx/v5/pgtype/timestamptz.go | 6 +- .../jackc/pgx/v5/pgtype/tsvector.go | 507 ++++++++++++ .../github.com/jackc/pgx/v5/pgxpool/pool.go | 56 +- vendor/github.com/jackc/pgx/v5/rows.go | 10 +- vendor/github.com/jackc/pgx/v5/test.sh | 170 ++++ vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go | 77 -- vendor/modules.txt | 5 +- 66 files changed, 2561 insertions(+), 650 deletions(-) create mode 100644 vendor/github.com/jackc/pgx/v5/CLAUDE.md rename vendor/github.com/jackc/pgx/v5/internal/sanitize/{benchmmark.sh => benchmark.sh} (97%) delete mode 100644 vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go create mode 100644 vendor/github.com/jackc/pgx/v5/test.sh delete mode 100644 vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go diff --git a/go.mod b/go.mod index 010ac20..af2e02f 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/go-playground/validator/v10 v10.28.0 github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 - github.com/jackc/pgx/v5 v5.7.6 + github.com/jackc/pgx/v5 v5.9.0 github.com/oklog/ulid/v2 v2.1.1 github.com/testcontainers/testcontainers-go v0.40.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 diff --git a/go.sum b/go.sum index 5ff6fcc..cf8f55e 100644 --- a/go.sum +++ b/go.sum @@ -81,8 +81,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= -github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/pgx/v5 v5.9.0 h1:T/dI+2TvmI2H8s/KH1/lXIbz1CUFk3gn5oTjr0/mBsE= +github.com/jackc/pgx/v5 v5.9.0/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= diff --git a/vendor/github.com/jackc/pgx/v5/.golangci.yml b/vendor/github.com/jackc/pgx/v5/.golangci.yml index ca74c70..d0903ab 100644 --- a/vendor/github.com/jackc/pgx/v5/.golangci.yml +++ b/vendor/github.com/jackc/pgx/v5/.golangci.yml @@ -1,9 +1,14 @@ # See for configurations: https://golangci-lint.run/usage/configuration/ -version: 2 +version: "2" + +linters: + default: none + enable: + - govet + - ineffassign # See: https://golangci-lint.run/usage/formatters/ formatters: - default: none enable: - gofmt # https://pkg.go.dev/cmd/gofmt - gofumpt # https://github.com/mvdan/gofumpt diff --git a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md index 6c9c99b..49efb67 100644 --- a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md +++ b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md @@ -1,3 +1,71 @@ +# 5.9.0 (March 21, 2026) + +This release includes a number of new features such as SCRAM-SHA-256-PLUS support, OAuth authentication support, and +PostgreSQL protocol 3.2 support. + +It significantly reduces the amount of network traffic when using prepared statements (which are used automatically by +default) by avoiding unnecessary Describe Portal messages. This also reduces local memory usage. + +It also includes multiple fixes for potential DoS due to panic or OOM if connected to a malicious server that sends +deliberately malformed messages. + +* Require Go 1.25+ +* Add SCRAM-SHA-256-PLUS support (Adam Brightwell) +* Add OAuth authentication support for PostgreSQL 18 (David Schneider) +* Add PostgreSQL protocol 3.2 support (Dirkjan Bussink) +* Add tsvector type support (Adam Brightwell) +* Skip Describe Portal for cached prepared statements reducing network round trips +* Make LoadTypes query easier to support on "postgres-like" servers (Jelte Fennema-Nio) +* Default empty user to current OS user matching libpq behavior (ShivangSrivastava) +* Optimize LRU statement cache with custom linked list and node pooling (Mathias Bogaert) +* Optimize date scanning by replacing regex with manual parsing (Mathias Bogaert) +* Optimize pgio append/set functions with direct byte shifts (Mathias Bogaert) +* Make RowsAffected faster (Abhishek Chanda) +* Fix: Pipeline.Close panic when server sends multiple FATAL errors (Varun Chawla) +* Fix: ContextWatcher goroutine leak (Hank Donnay) +* Fix: stdlib discard connections with open transactions in ResetSession (Jeremy Schneider) +* Fix: pipelineBatchResults.Exec silently swallowing lastRows error +* Fix: ColumnTypeLength using BPCharArrayOID instead of BPCharOID +* Fix: TSVector text encoding returning nil for valid empty tsvector +* Fix: wrong error messages for Int2 and Int4 underflow +* Fix: Numeric nil Int pointer dereference with Valid: true +* Fix: reversed strings.ContainsAny arguments in Numeric.ScanScientific +* Fix: message length parsing on 32-bit platforms +* Fix: FunctionCallResponse.Decode mishandling of signed result size +* Fix: returning wrong error in configTLS when DecryptPEMBlock fails (Maxim Motyshen) +* Fix: misleading ParseConfig error when default_query_exec_mode is invalid (Skarm) +* Fix: missed Unwatch in Pipeline error paths +* Clarify too many failed acquire attempts error message +* Better error wrapping with context and SQL statement (Aneesh Makala) +* Enable govet and ineffassign linters (Federico Guerinoni) +* Guard against various malformed binary messages (arrays, hstore, multirange, protocol messages) +* Fix various godoc comments (ferhat elmas) +* Fix typos in comments (Oleksandr Redko) + +# 5.8.0 (December 26, 2025) + +* Require Go 1.24+ +* Remove golang.org/x/crypto dependency +* Add OptionShouldPing to control ResetSession ping behavior (ilyam8) +* Fix: Avoid overflow when MaxConns is set to MaxInt32 +* Fix: Close batch pipeline after a query error (Anthonin Bonnefoy) +* Faster shutdown of pgxpool.Pool background goroutines (Blake Gentry) +* Add pgxpool ping timeout (Amirsalar Safaei) +* Fix: Rows.FieldDescriptions for empty query +* Scan unknown types into *any as string or []byte based on format code +* Optimize pgtype.Numeric (Philip Dubé) +* Add AfterNetConnect hook to pgconn.Config +* Fix: Handle for preparing statements that fail during the Describe phase +* Fix overflow in numeric scanning (Ilia Demianenko) +* Fix: json/jsonb sql.Scanner source type is []byte +* Migrate from math/rand to math/rand/v2 (Mathias Bogaert) +* Optimize internal iobufpool (Mathias Bogaert) +* Optimize stmtcache invalidation (Mathias Bogaert) +* Fix: missing error case in interval parsing (Maxime Soulé) +* Fix: invalidate statement/description cache in Exec (James Hartig) +* ColumnTypeLength method return the type length for varbit type (DengChan) +* Array and Composite codecs handle typed nils + # 5.7.6 (September 8, 2025) * Use ParseConfigError in pgx.ParseConfig and pgxpool.ParseConfig (Yurasov Ilia) diff --git a/vendor/github.com/jackc/pgx/v5/CLAUDE.md b/vendor/github.com/jackc/pgx/v5/CLAUDE.md new file mode 100644 index 0000000..e3ed1a2 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/CLAUDE.md @@ -0,0 +1,73 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +pgx is a PostgreSQL driver and toolkit for Go (`github.com/jackc/pgx/v5`). It provides both a native PostgreSQL interface and a `database/sql` compatible driver. Requires Go 1.25+ and supports PostgreSQL 14+ and CockroachDB. + +## Build & Test Commands + +```bash +# Run all tests (requires PGX_TEST_DATABASE to be set) +go test ./... + +# Run a specific test +go test -run TestFunctionName ./... + +# Run tests for a specific package +go test ./pgconn/... + +# Run tests with race detector +go test -race ./... + +# DevContainer: run tests against specific PostgreSQL versions +./test.sh pg18 # Default: PostgreSQL 18 +./test.sh pg16 -run TestConnect # Specific test against PG16 +./test.sh crdb # CockroachDB +./test.sh all # All targets (pg14-18 + crdb) + +# Format (always run after making changes) +goimports -w . + +# Lint +golangci-lint run ./... +``` + +## Test Database Setup + +Tests require `PGX_TEST_DATABASE` environment variable. In the devcontainer, `test.sh` handles this. For local development: + +```bash +export PGX_TEST_DATABASE="host=localhost user=postgres password=postgres dbname=pgx_test" +``` + +The test database needs extensions: `hstore`, `ltree`, and a `uint64` domain. See `testsetup/postgresql_setup.sql` for full setup. Many tests are skipped unless additional `PGX_TEST_*` env vars are set (for TLS, SCRAM, MD5, unix socket, PgBouncer testing). + +## Architecture + +The codebase is a layered architecture, bottom-up: + +- **pgproto3/** — PostgreSQL wire protocol v3 encoder/decoder. Defines `FrontendMessage` and `BackendMessage` types for every protocol message. +- **pgconn/** — Low-level connection layer (roughly libpq-equivalent). Handles authentication, TLS, query execution, COPY protocol, and notifications. `PgConn` is the core type. +- **pgx** (root package) — High-level query interface built on `pgconn`. Provides `Conn`, `Rows`, `Tx`, `Batch`, `CopyFrom`, and generic helpers like `CollectRows`/`ForEachRow`. Includes automatic statement caching (LRU). +- **pgtype/** — Type system mapping between Go and PostgreSQL types (70+ types). Key interfaces: `Codec`, `Type`, `TypeMap`. Custom types (enums, composites, domains) are registered through `TypeMap`. +- **pgxpool/** — Concurrency-safe connection pool built on `puddle/v2`. `Pool` is the main type; wraps `pgx.Conn`. +- **stdlib/** — `database/sql` compatibility adapter. + +Supporting packages: +- **internal/stmtcache/** — Prepared statement cache with LRU eviction +- **internal/sanitize/** — SQL query sanitization +- **tracelog/** — Logging adapter that implements tracer interfaces +- **multitracer/** — Composes multiple tracers into one +- **pgxtest/** — Test helpers for running tests across connection types + +## Key Design Conventions + +- **Semantic versioning** — strictly followed. Do not break the public API (no removing or renaming exported types, functions, methods, or fields; no changing function signatures). +- **Minimal dependencies** — adding new dependencies is strongly discouraged (see CONTRIBUTING.md). +- **Context-based** — all blocking operations take `context.Context`. +- **Tracer interfaces** — observability via `QueryTracer`, `BatchTracer`, `CopyFromTracer`, `PrepareTracer` on `ConnConfig.Tracer`. +- **Formatting** — always run `goimports -w .` after making changes to ensure code is properly formatted. CI checks formatting via `gofmt -l -s -w . && git diff --exit-code`. `gofumpt` with extra rules is also enforced via `golangci-lint`. +- **Linters** — `govet` and `ineffassign` only (configured in `.golangci.yml`). +- **CI matrix** — tests run against Go 1.25/1.26 × PostgreSQL 14-18 + CockroachDB, on Linux and Windows. Race detector enabled on Linux only. diff --git a/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md b/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md index c975a93..2283ae6 100644 --- a/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md +++ b/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md @@ -10,6 +10,18 @@ proposal. This will help to ensure your proposed change has a reasonable chance Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change that adds a dependency is no. +## AI + +Using AI is acceptable (not that it can really be stopped) under one the following conditions. + +* AI was used, but you deeply understand the code and you can answer questions regarding your change. You are not going + to answer questions with "I don't know", AI did it. You are not going to "answer" questions by relaying them to your + agent. This is wasteful of the code reviewer's time. +* AI was used to solve a problem without your deep understanding. This can still be a good starting point for a fix or + feature. But you need to clearly state that this is an AI proposal. You should include additional information such as + the AI used and what prompts were used. You should also be aware that large, complicated, or subtle changes may be + rejected simply because the reviewer is not confident in a change that no human understands. + ## Development Environment Setup pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` @@ -17,7 +29,12 @@ environment variable. The `PGX_TEST_DATABASE` environment variable can either be the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable handling. -### Using an Existing PostgreSQL Cluster +### Devcontainer + +The easiest way to start development is with the included devcontainer. It includes containers for each supported +PostgreSQL version as well as CockroachDB. `./test.sh all` will run the tests against all database types. + +### Using an Existing PostgreSQL Cluster Outside of a Devcontainer If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different @@ -49,7 +66,7 @@ go test ./... This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods). -### Creating a New PostgreSQL Cluster Exclusively for Testing +### Creating a New PostgreSQL Cluster Exclusively for Testing Outside of a Devcontainer The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`. @@ -63,10 +80,11 @@ export POSTGRESQL_DATA_DIR=postgresql export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test" export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" -export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test" +export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test channel_binding=disable" +export PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test channel_binding=require" export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret" -export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem" +export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem channel_binding=disable" export PGX_SSL_PASSWORD=certpw export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key" ``` diff --git a/vendor/github.com/jackc/pgx/v5/README.md b/vendor/github.com/jackc/pgx/v5/README.md index cb709e2..aa35e4a 100644 --- a/vendor/github.com/jackc/pgx/v5/README.md +++ b/vendor/github.com/jackc/pgx/v5/README.md @@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube. ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.25 and higher and PostgreSQL 14 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy @@ -120,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes. * [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) * [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) +* [github.com/ColeBurch/pgx-govalues-decimal](https://github.com/ColeBurch/pgx-govalues-decimal) * [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos)) * [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) @@ -186,6 +187,6 @@ Simple Golang implementation for transactional outbox pattern for PostgreSQL usi Simplifies working with the pgx library, providing convenient scanning of nested structures. -## [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter) +### [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter) Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters. diff --git a/vendor/github.com/jackc/pgx/v5/batch.go b/vendor/github.com/jackc/pgx/v5/batch.go index 1b1cbd8..702fcff 100644 --- a/vendor/github.com/jackc/pgx/v5/batch.go +++ b/vendor/github.com/jackc/pgx/v5/batch.go @@ -272,7 +272,7 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { ok = true br.qqIdx++ } - return + return query, args, ok } type pipelineBatchResults struct { @@ -296,6 +296,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err return pgconn.CommandTag{}, br.err } @@ -404,7 +405,6 @@ func (br *pipelineBatchResults) Close() error { if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { br.err = br.lastRows.err - return br.err } if br.closed { @@ -451,6 +451,45 @@ func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, er return bi.SQL, bi.Arguments, nil } +type emptyBatchResults struct { + conn *Conn + closed bool +} + +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *emptyBatchResults) Exec() (pgconn.CommandTag, error) { + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") + } + return pgconn.CommandTag{}, errors.New("no more results in batch") +} + +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *emptyBatchResults) Query() (Rows, error) { + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + } + + rows := br.conn.getRows(context.Background(), "", nil) + rows.err = errors.New("no more results in batch") + rows.closed = true + return rows, rows.err +} + +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *emptyBatchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) +} + +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *emptyBatchResults) Close() error { + br.closed = true + return nil +} + // invalidates statement and description caches on batch results error func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) { if err != nil && conn != nil && b != nil { @@ -467,3 +506,31 @@ func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) { } } } + +// ErrPreprocessingBatch occurs when an error is encountered while preprocessing a batch. +// The two preprocessing steps are "prepare" (server-side SQL parse/plan) and +// "build" (client-side argument encoding). +type ErrPreprocessingBatch struct { + step string // "prepare" or "build" + sql string + err error +} + +func newErrPreprocessingBatch(step, sql string, err error) ErrPreprocessingBatch { + return ErrPreprocessingBatch{step: step, sql: sql, err: err} +} + +func (e ErrPreprocessingBatch) Error() string { + // intentionally not including the SQL query in the error message + // to avoid leaking potentially sensitive information into logs. + // If the user wants the SQL, they can call SQL(). + return fmt.Sprintf("error preprocessing batch (%s): %v", e.step, e.err) +} + +func (e ErrPreprocessingBatch) Unwrap() error { + return e.err +} + +func (e ErrPreprocessingBatch) SQL() string { + return e.sql +} diff --git a/vendor/github.com/jackc/pgx/v5/conn.go b/vendor/github.com/jackc/pgx/v5/conn.go index 67b2252..c52039b 100644 --- a/vendor/github.com/jackc/pgx/v5/conn.go +++ b/vendor/github.com/jackc/pgx/v5/conn.go @@ -65,11 +65,12 @@ func (cc *ConnConfig) ConnString() string { return cc.connString } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access // to multiple database connections from multiple goroutines. type Conn struct { - pgConn *pgconn.PgConn - config *ConnConfig // config used when establishing this connection - preparedStatements map[string]*pgconn.StatementDescription - statementCache stmtcache.Cache - descriptionCache stmtcache.Cache + pgConn *pgconn.PgConn + config *ConnConfig // config used when establishing this connection + preparedStatements map[string]*pgconn.StatementDescription + failedDescribeStatement string + statementCache stmtcache.Cache + descriptionCache stmtcache.Cache queryTracer QueryTracer batchTracer BatchTracer @@ -202,7 +203,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con case "simple_protocol": defaultQueryExecMode = QueryExecModeSimpleProtocol default: - return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err) + return nil, pgconn.NewParseConfigError( + connString, "invalid default_query_exec_mode", fmt.Errorf("unknown value %q", s), + ) } } @@ -314,6 +317,14 @@ func (c *Conn) Close(ctx context.Context) error { // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This // allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared. func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.failedDescribeStatement != "" { + err = c.Deallocate(ctx, c.failedDescribeStatement) + if err != nil { + return nil, fmt.Errorf("failed to deallocate previously failed statement %q: %w", c.failedDescribeStatement, err) + } + c.failedDescribeStatement = "" + } + if c.prepareTracer != nil { ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) } @@ -346,6 +357,10 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem sd, err = c.pgConn.Prepare(ctx, psName, sql, nil) if err != nil { + var pErr *pgconn.PrepareError + if errors.As(err, &pErr) { + c.failedDescribeStatement = psKey + } return nil, err } @@ -502,6 +517,18 @@ optionLoop: mode = QueryExecModeSimpleProtocol } + defer func() { + if err != nil { + if sc := c.statementCache; sc != nil { + sc.Invalidate(sql) + } + + if sc := c.descriptionCache; sc != nil { + sc.Invalidate(sql) + } + } + }() + if sd, ok := c.preparedStatements[sql]; ok { return c.execPrepared(ctx, sd, arguments) } @@ -583,7 +610,7 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription return pgconn.CommandTag{}, err } - result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + result := c.pgConn.ExecStatement(ctx, sd, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } @@ -817,7 +844,7 @@ optionLoop: if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe { rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats) } else { - rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) + rows.resultReader = c.pgConn.ExecStatement(ctx, sd, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) } } else if mode == QueryExecModeExec { err := c.eqb.Build(c.typeMap, nil, args) @@ -912,6 +939,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { // Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table // and using it in a subsequent query in the same batch can fail. func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { + if len(b.QueuedQueries) == 0 { + return &emptyBatchResults{conn: c} + } + if c.batchTracer != nil { ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) defer func() { @@ -1163,7 +1194,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return err + return newErrPreprocessingBatch("prepare", sd.SQL, err) } resultSD, ok := results.(*pgconn.StatementDescription) @@ -1197,15 +1228,14 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d for _, bi := range b.QueuedQueries { err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments) if err != nil { - // we wrap the error so we the user can understand which query failed inside the batch - err = fmt.Errorf("error building query %s: %w", bi.SQL, err) + err = newErrPreprocessingBatch("build", bi.SQL, err) return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } if bi.sd.Name == "" { pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) } else { - pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + pipeline.SendQueryStatement(bi.sd, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) } } diff --git a/vendor/github.com/jackc/pgx/v5/derived_types.go b/vendor/github.com/jackc/pgx/v5/derived_types.go index 72c0a24..89b9a77 100644 --- a/vendor/github.com/jackc/pgx/v5/derived_types.go +++ b/vendor/github.com/jackc/pgx/v5/derived_types.go @@ -24,7 +24,7 @@ func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string { // This should not occur; this will not return any types typeNamesClause = "= ''" } else { - typeNamesClause = "= ANY($1)" + typeNamesClause = "= ANY($1::text[])" } parts := make([]string, 0, 10) @@ -169,7 +169,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ // the SQL not support recent structures such as multirange serverVersion, _ := serverVersion(c) sql := buildLoadDerivedTypesSQL(serverVersion, typeNames) - rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) + rows, err := c.Query(ctx, sql, QueryResultFormats{TextFormatCode}, typeNames) if err != nil { return nil, fmt.Errorf("While generating load types query: %w", err) } @@ -227,7 +227,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName) } - // the type_ is imposible to be null + // the type_ is impossible to be null m.RegisterType(type_) if ti.NspName != "" { nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec} diff --git a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go index 89e0c22..abc41f6 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go +++ b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go @@ -4,7 +4,10 @@ // an allocation is purposely not documented. https://github.com/golang/go/issues/16323 package iobufpool -import "sync" +import ( + "math/bits" + "sync" +) const minPoolExpOf2 = 8 @@ -37,15 +40,14 @@ func Get(size int) *[]byte { } func getPoolIdx(size int) int { - size-- - size >>= minPoolExpOf2 - i := 0 - for size > 0 { - size >>= 1 - i++ + if size < 2 { + return 0 } - - return i + idx := bits.Len(uint(size-1)) - minPoolExpOf2 + if idx < 0 { + return 0 + } + return idx } // Put returns buf to the pool. @@ -59,12 +61,18 @@ func Put(buf *[]byte) { } func putPoolIdx(size int) int { - minPoolSize := 1 << minPoolExpOf2 - for i := range pools { - if size == minPoolSize<= len(pools) { + return -1 } - return -1 + return idx } diff --git a/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go b/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go index 96aedf9..3a6700d 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go +++ b/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go @@ -1,26 +1,18 @@ package pgio -import "encoding/binary" - func AppendUint16(buf []byte, n uint16) []byte { - wp := len(buf) - buf = append(buf, 0, 0) - binary.BigEndian.PutUint16(buf[wp:], n) - return buf + return append(buf, byte(n>>8), byte(n)) } func AppendUint32(buf []byte, n uint32) []byte { - wp := len(buf) - buf = append(buf, 0, 0, 0, 0) - binary.BigEndian.PutUint32(buf[wp:], n) - return buf + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) } func AppendUint64(buf []byte, n uint64) []byte { - wp := len(buf) - buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) - binary.BigEndian.PutUint64(buf[wp:], n) - return buf + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n), + ) } func AppendInt16(buf []byte, n int16) []byte { @@ -36,5 +28,5 @@ func AppendInt64(buf []byte, n int64) []byte { } func SetInt32(buf []byte, n int32) { - binary.BigEndian.PutUint32(buf, uint32(n)) + *(*[4]byte)(buf) = [4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)} } diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh similarity index 97% rename from vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh rename to vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh index ec0f7b0..b4ee3fe 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh +++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh @@ -42,7 +42,7 @@ for i in "${!commits[@]}"; do exit 1 } - # Sanitized commmit message + # Sanitized commit message commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') # Benchmark data will go there diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go index 17fec93..b677d29 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go +++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go @@ -1,36 +1,54 @@ package stmtcache import ( - "container/list" - "github.com/jackc/pgx/v5/pgconn" ) +// lruNode is a typed doubly-linked list node with freelist support. +type lruNode struct { + sd *pgconn.StatementDescription + prev *lruNode + next *lruNode +} + // LRUCache implements Cache with a Least Recently Used (LRU) cache. type LRUCache struct { - cap int - m map[string]*list.Element - l *list.List + m map[string]*lruNode + head *lruNode + + tail *lruNode + len int + cap int + freelist *lruNode + invalidStmts []*pgconn.StatementDescription + invalidSet map[string]struct{} } // NewLRUCache creates a new LRUCache. cap is the maximum size of the cache. func NewLRUCache(cap int) *LRUCache { + head := &lruNode{} + tail := &lruNode{} + head.next = tail + tail.prev = head + return &LRUCache{ - cap: cap, - m: make(map[string]*list.Element), - l: list.New(), + cap: cap, + m: make(map[string]*lruNode, cap), + head: head, + tail: tail, + invalidSet: make(map[string]struct{}), } } // Get returns the statement description for sql. Returns nil if not found. func (c *LRUCache) Get(key string) *pgconn.StatementDescription { - if el, ok := c.m[key]; ok { - c.l.MoveToFront(el) - return el.Value.(*pgconn.StatementDescription) + node, ok := c.m[key] + if !ok { + return nil } - - return nil + c.moveToFront(node) + return node.sd } // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or @@ -45,39 +63,49 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) { } // The statement may have been invalidated but not yet handled. Do not readd it to the cache. - for _, invalidSD := range c.invalidStmts { - if invalidSD.SQL == sd.SQL { - return - } + if _, invalidated := c.invalidSet[sd.SQL]; invalidated { + return } - if c.l.Len() == c.cap { + if c.len == c.cap { c.invalidateOldest() } - el := c.l.PushFront(sd) - c.m[sd.SQL] = el + node := c.allocNode() + node.sd = sd + c.insertAfter(c.head, node) + c.m[sd.SQL] = node + c.len++ } // Invalidate invalidates statement description identified by sql. Does nothing if not found. func (c *LRUCache) Invalidate(sql string) { - if el, ok := c.m[sql]; ok { - delete(c.m, sql) - c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) - c.l.Remove(el) + node, ok := c.m[sql] + if !ok { + return } + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, node.sd) + c.invalidSet[sql] = struct{}{} + c.unlink(node) + c.len-- + c.freeNode(node) } // InvalidateAll invalidates all statement descriptions. func (c *LRUCache) InvalidateAll() { - el := c.l.Front() - for el != nil { - c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) - el = el.Next() + for node := c.head.next; node != c.tail; { + next := node.next + c.invalidStmts = append(c.invalidStmts, node.sd) + c.invalidSet[node.sd.SQL] = struct{}{} + c.freeNode(node) + node = next } - c.m = make(map[string]*list.Element) - c.l = list.New() + clear(c.m) + c.head.next = c.tail + c.tail.prev = c.head + c.len = 0 } // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. @@ -89,12 +117,13 @@ func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were // never seen by the call to GetInvalidated. func (c *LRUCache) RemoveInvalidated() { - c.invalidStmts = nil + c.invalidStmts = c.invalidStmts[:0] + clear(c.invalidSet) } // Len returns the number of cached prepared statement descriptions. func (c *LRUCache) Len() int { - return c.l.Len() + return c.len } // Cap returns the maximum number of cached prepared statement descriptions. @@ -103,9 +132,56 @@ func (c *LRUCache) Cap() int { } func (c *LRUCache) invalidateOldest() { - oldest := c.l.Back() - sd := oldest.Value.(*pgconn.StatementDescription) - c.invalidStmts = append(c.invalidStmts, sd) - delete(c.m, sd.SQL) - c.l.Remove(oldest) + node := c.tail.prev + if node == c.head { + return + } + c.invalidStmts = append(c.invalidStmts, node.sd) + c.invalidSet[node.sd.SQL] = struct{}{} + delete(c.m, node.sd.SQL) + c.unlink(node) + c.len-- + c.freeNode(node) +} + +// List operations - sentinel nodes eliminate nil checks + +func (c *LRUCache) insertAfter(at, node *lruNode) { + node.prev = at + node.next = at.next + at.next.prev = node + at.next = node +} + +func (c *LRUCache) unlink(node *lruNode) { + node.prev.next = node.next + node.next.prev = node.prev +} + +func (c *LRUCache) moveToFront(node *lruNode) { + if node.prev == c.head { + return + } + c.unlink(node) + c.insertAfter(c.head, node) +} + +// Node pool operations - reuse evicted nodes to avoid allocations + +func (c *LRUCache) allocNode() *lruNode { + if c.freelist != nil { + node := c.freelist + c.freelist = node.next + node.next = nil + node.prev = nil + return node + } + return &lruNode{} +} + +func (c *LRUCache) freeNode(node *lruNode) { + node.sd = nil + node.prev = nil + node.next = c.freelist + c.freelist = node } diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go deleted file mode 100644 index 6964132..0000000 --- a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go +++ /dev/null @@ -1,77 +0,0 @@ -package stmtcache - -import ( - "math" - - "github.com/jackc/pgx/v5/pgconn" -) - -// UnlimitedCache implements Cache with no capacity limit. -type UnlimitedCache struct { - m map[string]*pgconn.StatementDescription - invalidStmts []*pgconn.StatementDescription -} - -// NewUnlimitedCache creates a new UnlimitedCache. -func NewUnlimitedCache() *UnlimitedCache { - return &UnlimitedCache{ - m: make(map[string]*pgconn.StatementDescription), - } -} - -// Get returns the statement description for sql. Returns nil if not found. -func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { - return c.m[sql] -} - -// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. -func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { - if sd.SQL == "" { - panic("cannot store statement description with empty SQL") - } - - if _, present := c.m[sd.SQL]; present { - return - } - - c.m[sd.SQL] = sd -} - -// Invalidate invalidates statement description identified by sql. Does nothing if not found. -func (c *UnlimitedCache) Invalidate(sql string) { - if sd, ok := c.m[sql]; ok { - delete(c.m, sql) - c.invalidStmts = append(c.invalidStmts, sd) - } -} - -// InvalidateAll invalidates all statement descriptions. -func (c *UnlimitedCache) InvalidateAll() { - for _, sd := range c.m { - c.invalidStmts = append(c.invalidStmts, sd) - } - - c.m = make(map[string]*pgconn.StatementDescription) -} - -// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. -func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { - return c.invalidStmts -} - -// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a -// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were -// never seen by the call to GetInvalidated. -func (c *UnlimitedCache) RemoveInvalidated() { - c.invalidStmts = nil -} - -// Len returns the number of cached prepared statement descriptions. -func (c *UnlimitedCache) Len() int { - return len(c.m) -} - -// Cap returns the maximum number of cached prepared statement descriptions. -func (c *UnlimitedCache) Cap() int { - return math.MaxInt -} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go new file mode 100644 index 0000000..991f658 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go @@ -0,0 +1,67 @@ +package pgconn + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/jackc/pgx/v5/pgproto3" +) + +func (c *PgConn) oauthAuth(ctx context.Context) error { + if c.config.OAuthTokenProvider == nil { + return errors.New("OAuth authentication required but no token provider configured") + } + + token, err := c.config.OAuthTokenProvider(ctx) + if err != nil { + return fmt.Errorf("failed to obtain OAuth token: %w", err) + } + + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1 + initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01") + + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "OAUTHBEARER", + Data: initialResponse, + } + c.frontend.Send(saslInitialResponse) + err = c.flushWithPotentialWriteReadDeadlock() + if err != nil { + return err + } + + msg, err := c.receiveMessage() + if err != nil { + return err + } + + switch m := msg.(type) { + case *pgproto3.AuthenticationOk: + return nil + case *pgproto3.AuthenticationSASLContinue: + // Server sent error response in SASL continue + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2 + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3 + errResponse := struct { + Status string `json:"status"` + Scope string `json:"scope"` + OpenIDConfiguration string `json:"openid-configuration"` + }{} + err := json.Unmarshal(m.Data, &errResponse) + if err != nil { + return fmt.Errorf("invalid OAuth error response from server: %w", err) + } + + // Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01. + // However, since the connection will be closed anyway, we can skip this + return fmt.Errorf("OAuth authentication failed: %s", errResponse.Status) + + case *pgproto3.ErrorResponse: + return ErrorResponseToPgError(m) + + default: + return fmt.Errorf("unexpected message type during OAuth auth: %T", msg) + } +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go index f846ba8..f59d39c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go @@ -1,7 +1,8 @@ -// SCRAM-SHA-256 authentication +// SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication // // Resources: // https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc5929 // https://tools.ietf.org/html/rfc8265 // https://www.postgresql.org/docs/current/sasl-authentication.html // @@ -15,19 +16,28 @@ package pgconn import ( "bytes" "crypto/hmac" + "crypto/pbkdf2" "crypto/rand" "crypto/sha256" + "crypto/sha512" + "crypto/tls" + "crypto/x509" "encoding/base64" "errors" "fmt" + "hash" + "slices" "strconv" "github.com/jackc/pgx/v5/pgproto3" - "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" ) -const clientNonceLen = 18 +const ( + clientNonceLen = 18 + scramSHA256Name = "SCRAM-SHA-256" + scramSHA256PlusName = "SCRAM-SHA-256-PLUS" +) // Perform SCRAM authentication. func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { @@ -36,9 +46,35 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { return err } + serverHasPlus := slices.Contains(sc.serverAuthMechanisms, scramSHA256PlusName) + if c.config.ChannelBinding == "require" && !serverHasPlus { + return errors.New("channel binding required but server does not support SCRAM-SHA-256-PLUS") + } + + // If we have a TLS connection and channel binding is not disabled, attempt to + // extract the server certificate hash for tls-server-end-point channel binding. + if tlsConn, ok := c.conn.(*tls.Conn); ok && c.config.ChannelBinding != "disable" { + certHash, err := getTLSCertificateHash(tlsConn) + if err != nil && c.config.ChannelBinding == "require" { + return fmt.Errorf("channel binding required but failed to get server certificate hash: %w", err) + } + + // Upgrade to SCRAM-SHA-256-PLUS if we have binding data and the server supports it. + if certHash != nil && serverHasPlus { + sc.authMechanism = scramSHA256PlusName + } + + sc.channelBindingData = certHash + sc.hasTLS = true + } + + if c.config.ChannelBinding == "require" && sc.channelBindingData == nil { + return errors.New("channel binding required but channel binding data is not available") + } + // Send client-first-message in a SASLInitialResponse saslInitialResponse := &pgproto3.SASLInitialResponse{ - AuthMechanism: "SCRAM-SHA-256", + AuthMechanism: sc.authMechanism, Data: sc.clientFirstMessage(), } c.frontend.Send(saslInitialResponse) @@ -107,10 +143,31 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { type scramClient struct { serverAuthMechanisms []string - password []byte + password string clientNonce []byte + // authMechanism is the selected SASL mechanism for the client. Must be + // either SCRAM-SHA-256 (default) or SCRAM-SHA-256-PLUS. + // + // Upgraded to SCRAM-SHA-256-PLUS during authentication when channel binding + // is not disabled, channel binding data is available (TLS connection with + // an obtainable server certificate hash) and the server advertises + // SCRAM-SHA-256-PLUS. + authMechanism string + + // hasTLS indicates whether the connection is using TLS. This is + // needed because the GS2 header must distinguish between a client that + // supports channel binding but the server does not ("y,,") versus one + // that does not support it at all ("n,,"). + hasTLS bool + + // channelBindingData is the hash of the server's TLS certificate, computed + // per the tls-server-end-point channel binding type (RFC 5929). Used as + // the binding input in SCRAM-SHA-256-PLUS. nil when not in use. + channelBindingData []byte + clientFirstMessageBare []byte + clientGS2Header []byte serverFirstMessage []byte clientAndServerNonce []byte @@ -124,26 +181,23 @@ type scramClient struct { func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { sc := &scramClient{ serverAuthMechanisms: serverAuthMechanisms, + authMechanism: scramSHA256Name, } - // Ensure server supports SCRAM-SHA-256 - hasScramSHA256 := false - for _, mech := range sc.serverAuthMechanisms { - if mech == "SCRAM-SHA-256" { - hasScramSHA256 = true - break - } - } - if !hasScramSHA256 { + // Ensure the server supports SCRAM-SHA-256. SCRAM-SHA-256-PLUS is the + // channel binding variant and is only advertised when the server supports + // SSL. PostgreSQL always advertises the base SCRAM-SHA-256 mechanism + // regardless of SSL. + if !slices.Contains(sc.serverAuthMechanisms, scramSHA256Name) { return nil, errors.New("server does not support SCRAM-SHA-256") } // precis.OpaqueString is equivalent to SASLprep for password. var err error - sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + sc.password, err = precis.OpaqueString.String(password) if err != nil { // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. - sc.password = []byte(password) + sc.password = password } buf := make([]byte, clientNonceLen) @@ -158,8 +212,32 @@ func newScramClient(serverAuthMechanisms []string, password string) (*scramClien } func (sc *scramClient) clientFirstMessage() []byte { - sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) - return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) + // The client-first-message is the GS2 header concatenated with the bare + // message (username + client nonce). The GS2 header communicates the + // client's channel binding capability to the server: + // + // "n,," - client is not using TLS (channel binding not possible) + // "y,," - client is using TLS but channel binding is not + // in use (e.g., server did not advertise SCRAM-SHA-256-PLUS + // or the server certificate hash was not obtainable) + // "p=tls-server-end-point,," - channel binding is active via SCRAM-SHA-256-PLUS + // + // See: + // https://www.rfc-editor.org/rfc/rfc5802#section-6 + // https://www.rfc-editor.org/rfc/rfc5929#section-4 + // https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256 + + sc.clientFirstMessageBare = fmt.Appendf(nil, "n=,r=%s", sc.clientNonce) + + if sc.authMechanism == scramSHA256PlusName { + sc.clientGS2Header = []byte("p=tls-server-end-point,,") + } else if sc.hasTLS { + sc.clientGS2Header = []byte("y,,") + } else { + sc.clientGS2Header = []byte("n,,") + } + + return append(sc.clientGS2Header, sc.clientFirstMessageBare...) } func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { @@ -218,9 +296,25 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { } func (sc *scramClient) clientFinalMessage() string { - clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + // The c= attribute carries the base64-encoded channel binding input. + // + // Without channel binding this is just the GS2 header alone ("biws" for + // "n,," or "eSws" for "y,,"). + // + // With channel binding, this is the GS2 header with the channel binding data + // (certificate hash) appended. + channelBindInput := sc.clientGS2Header + if sc.authMechanism == scramSHA256PlusName { + channelBindInput = slices.Concat(sc.clientGS2Header, sc.channelBindingData) + } + channelBindingEncoded := base64.StdEncoding.EncodeToString(channelBindInput) + clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=%s,r=%s", channelBindingEncoded, sc.clientAndServerNonce) - sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + var err error + sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32) + if err != nil { + panic(err) // This should never happen. + } sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) @@ -254,7 +348,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { clientSignature := computeHMAC(storedKey[:], authMessage) clientProof := make([]byte, len(clientSignature)) - for i := 0; i < len(clientSignature); i++ { + for i := range clientSignature { clientProof[i] = clientKey[i] ^ clientSignature[i] } @@ -270,3 +364,36 @@ func computeServerSignature(saltedPassword, authMessage []byte) []byte { base64.StdEncoding.Encode(buf, serverSignature) return buf } + +// Get the server certificate hash for SCRAM channel binding type +// tls-server-end-point. +func getTLSCertificateHash(conn *tls.Conn) ([]byte, error) { + state := conn.ConnectionState() + if len(state.PeerCertificates) == 0 { + return nil, errors.New("no peer certificates for channel binding") + } + + cert := state.PeerCertificates[0] + + // Per RFC 5929 section 4.1: If the certificate's signatureAlgorithm uses + // MD5 or SHA-1, use SHA-256. Otherwise use the hash from the signature + // algorithm. + // + // See: https://www.rfc-editor.org/rfc/rfc5929.html#section-4.1 + var h hash.Hash + switch cert.SignatureAlgorithm { + case x509.MD5WithRSA, x509.SHA1WithRSA, x509.ECDSAWithSHA1: + h = sha256.New() + case x509.SHA256WithRSA, x509.SHA256WithRSAPSS, x509.ECDSAWithSHA256: + h = sha256.New() + case x509.SHA384WithRSA, x509.SHA384WithRSAPSS, x509.ECDSAWithSHA384: + h = sha512.New384() + case x509.SHA512WithRSA, x509.SHA512WithRSAPSS, x509.ECDSAWithSHA512: + h = sha512.New() + default: + return nil, fmt.Errorf("tls-server-end-point channel binding is undefined for certificate signature algorithm %v", cert.SignatureAlgorithm) + } + + h.Write(cert.Raw) + return h.Sum(nil), nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/config.go b/vendor/github.com/jackc/pgx/v5/pgconn/config.go index 3937dc4..dff5509 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/config.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/config.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "maps" "math" "net" "net/url" @@ -55,6 +56,13 @@ type Config struct { SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct + // AfterNetConnect is called after the network connection, including TLS if applicable, is established but before any + // PostgreSQL protocol communication. It takes the established net.Conn and returns a net.Conn that will be used in + // its place. It can be used to wrap the net.Conn (e.g. for logging, diagnostics, or testing). Its functionality has + // some overlap with DialFunc. However, DialFunc takes place before TLS is established and cannot be used to control + // the final net.Conn used for PostgreSQL protocol communication while AfterNetConnect can. + AfterNetConnect func(ctx context.Context, config *Config, conn net.Conn) (net.Conn, error) + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. @@ -75,6 +83,23 @@ type Config struct { // that you close on FATAL errors by returning false. OnPgError PgErrorHandler + // OAuthTokenProvider is a function that returns an OAuth token for authentication. If set, it will be used for + // OAUTHBEARER SASL authentication when the server requests it. + OAuthTokenProvider func(context.Context) (string, error) + + // MinProtocolVersion is the minimum acceptable PostgreSQL protocol version. + // If the server does not support at least this version, the connection will fail. + // Valid values: "3.0", "3.2", "latest". Defaults to "3.0". + MinProtocolVersion string + + // MaxProtocolVersion is the maximum PostgreSQL protocol version to request from the server. + // Valid values: "3.0", "3.2", "latest". Defaults to "3.0" for compatibility. + MaxProtocolVersion string + + // ChannelBinding is the channel_binding parameter for SCRAM-SHA-256-PLUS authentication. + // Valid values: "disable", "prefer", "require". Defaults to "prefer". + ChannelBinding string + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -96,9 +121,7 @@ func (c *Config) Copy() *Config { } if newConf.RuntimeParams != nil { newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) - for k, v := range c.RuntimeParams { - newConf.RuntimeParams[k] = v - } + maps.Copy(newConf.RuntimeParams, c.RuntimeParams) } if newConf.Fallbacks != nil { newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) @@ -207,6 +230,8 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGCONNECT_TIMEOUT // PGTARGETSESSIONATTRS // PGTZ +// PGMINPROTOCOLVERSION +// PGMAXPROTOCOLVERSION // // See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables. // @@ -332,6 +357,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "target_session_attrs": {}, "service": {}, "servicefile": {}, + "min_protocol_version": {}, + "max_protocol_version": {}, + "channel_binding": {}, } // Adding kerberos configuration @@ -424,6 +452,38 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} } + minProto, err := parseProtocolVersion(settings["min_protocol_version"]) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "invalid min_protocol_version", err: err} + } + maxProto, err := parseProtocolVersion(settings["max_protocol_version"]) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "invalid max_protocol_version", err: err} + } + if minProto > maxProto { + return nil, &ParseConfigError{ConnString: connString, msg: "min_protocol_version cannot be greater than max_protocol_version"} + } + + config.MinProtocolVersion = settings["min_protocol_version"] + config.MaxProtocolVersion = settings["max_protocol_version"] + if config.MinProtocolVersion == "" { + config.MinProtocolVersion = "3.0" + } + if config.MaxProtocolVersion == "" { + config.MaxProtocolVersion = "3.0" + } + + switch channelBinding := settings["channel_binding"]; channelBinding { + case "", "prefer": + config.ChannelBinding = "prefer" + case "disable": + config.ChannelBinding = "disable" + case "require": + config.ChannelBinding = "require" + default: + return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown channel_binding value: %v", channelBinding)} + } + return config, nil } @@ -431,9 +491,7 @@ func mergeSettings(settingSets ...map[string]string) map[string]string { settings := make(map[string]string) for _, s2 := range settingSets { - for k, v := range s2 { - settings[k] = v - } + maps.Copy(settings, s2) } return settings @@ -463,6 +521,8 @@ func parseEnvSettings() map[string]string { "PGSERVICEFILE": "servicefile", "PGTZ": "timezone", "PGOPTIONS": "options", + "PGMINPROTOCOLVERSION": "min_protocol_version", + "PGMAXPROTOCOLVERSION": "max_protocol_version", } for envname, realname := range nameMap { @@ -487,7 +547,9 @@ func parseURLSettings(connString string) (map[string]string, error) { } if parsedURL.User != nil { - settings["user"] = parsedURL.User.Username() + if u := parsedURL.User.Username(); u != "" { + settings["user"] = u + } if password, present := parsedURL.User.Password(); present { settings["password"] = password } @@ -496,7 +558,7 @@ func parseURLSettings(connString string) (map[string]string, error) { // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. var hosts []string var ports []string - for _, host := range strings.Split(parsedURL.Host, ",") { + for host := range strings.SplitSeq(parsedURL.Host, ",") { if host == "" { continue } @@ -614,6 +676,9 @@ func parseKeywordValueSettings(s string) (map[string]string, error) { return nil, errors.New("invalid keyword/value") } + if key == "user" && val == "" { + continue + } settings[key] = val } @@ -784,7 +849,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P // Attempt decryption with pass phrase // NOTE: only supports RSA (PKCS#1) if sslpassword != "" { - decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) //nolint:ineffassign } // if sslpassword not provided or has decryption error when use it // try to find sslpassword with callback function @@ -799,7 +864,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) // Should we also provide warning for PKCS#1 needed? if decryptedError != nil { - return nil, fmt.Errorf("unable to decrypt key: %w", err) + return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError) } pemBytes := pem.Block{ @@ -951,3 +1016,14 @@ func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn return nil } + +func parseProtocolVersion(s string) (uint32, error) { + switch s { + case "", "3.0": + return pgproto3.ProtocolVersion30, nil + case "3.2", "latest": + return pgproto3.ProtocolVersion32, nil + default: + return 0, fmt.Errorf("invalid protocol version: %q", s) + } +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go b/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go index db8884e..b8892e6 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go @@ -8,12 +8,13 @@ import ( // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a // time. type ContextWatcher struct { - handler Handler - unwatchChan chan struct{} + handler Handler - lock sync.Mutex - watchInProgress bool - onCancelWasCalled bool + // Lock protects the members below. + lock sync.Mutex + // Stop is the handle for an "after func". See [context.AfterFunc]. + stop func() bool + done chan struct{} } // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. @@ -21,8 +22,7 @@ type ContextWatcher struct { // onCancel called. func NewContextWatcher(handler Handler) *ContextWatcher { cw := &ContextWatcher{ - handler: handler, - unwatchChan: make(chan struct{}), + handler: handler, } return cw @@ -33,25 +33,16 @@ func (cw *ContextWatcher) Watch(ctx context.Context) { cw.lock.Lock() defer cw.lock.Unlock() - if cw.watchInProgress { - panic("Watch already in progress") + if cw.stop != nil { + panic("watch already in progress") } - cw.onCancelWasCalled = false - if ctx.Done() != nil { - cw.watchInProgress = true - go func() { - select { - case <-ctx.Done(): - cw.handler.HandleCancel(ctx) - cw.onCancelWasCalled = true - <-cw.unwatchChan - case <-cw.unwatchChan: - } - }() - } else { - cw.watchInProgress = false + cw.done = make(chan struct{}) + cw.stop = context.AfterFunc(ctx, func() { + cw.handler.HandleCancel(ctx) + close(cw.done) + }) } } @@ -61,12 +52,13 @@ func (cw *ContextWatcher) Unwatch() { cw.lock.Lock() defer cw.lock.Unlock() - if cw.watchInProgress { - cw.unwatchChan <- struct{}{} - if cw.onCancelWasCalled { + if cw.stop != nil { + if !cw.stop() { + <-cw.done cw.handler.HandleUnwatchAfterCancel() } - cw.watchInProgress = false + cw.stop = nil + cw.done = nil } } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go index d968d3f..bc1e31e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go @@ -254,3 +254,20 @@ func (e *NotPreferredError) SafeToRetry() bool { func (e *NotPreferredError) Unwrap() error { return e.err } + +type PrepareError struct { + err error + + ParseComplete bool // Indicates whether the error occurred after a ParseComplete message was received. +} + +func (e *PrepareError) Error() string { + if e.ParseComplete { + return fmt.Sprintf("prepare failed after ParseComplete: %s", e.err.Error()) + } + return e.err.Error() +} + +func (e *PrepareError) Unwrap() error { + return e.err +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 97141c6..ca9a48c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "maps" "math" "net" "strconv" @@ -22,6 +23,7 @@ import ( "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" ) const ( @@ -75,7 +77,7 @@ type NotificationHandler func(*PgConn, *Notification) type PgConn struct { conn net.Conn pid uint32 // backend pid - secretKey uint32 // key to use to send a cancel query message to the server + secretKey []byte // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend @@ -317,6 +319,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo return e } + maxProtocolVersion, err := parseProtocolVersion(config.MaxProtocolVersion) + if err != nil { + return nil, newPerDialConnectError("invalid max_protocol_version", err) + } + minProtocolVersion, err := parseProtocolVersion(config.MinProtocolVersion) + if err != nil { + return nil, newPerDialConnectError("invalid min_protocol_version", err) + } + pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address) if err != nil { return nil, newPerDialConnectError("dial error", err) @@ -343,6 +354,14 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo pgConn.conn = tlsConn } + if config.AfterNetConnect != nil { + pgConn.conn, err = config.AfterNetConnect(ctx, config, pgConn.conn) + if err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("AfterNetConnect failed", err) + } + } + pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() @@ -361,14 +380,12 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ - ProtocolVersion: pgproto3.ProtocolVersionNumber, + ProtocolVersion: maxProtocolVersion, Parameters: make(map[string]string), } // Copy default run-time params - for k, v := range config.RuntimeParams { - startupMsg.Parameters[k] = v - } + maps.Copy(startupMsg.Parameters, config.RuntimeParams) startupMsg.Parameters["user"] = config.User if config.Database != "" { @@ -411,7 +428,20 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo return nil, newPerDialConnectError("failed to write password message", err) } case *pgproto3.AuthenticationSASL: - err = pgConn.scramAuth(msg.AuthMechanisms) + // Check if OAUTHBEARER is supported + serverSupportsOAuthBearer := false + for _, mech := range msg.AuthMechanisms { + if mech == "OAUTHBEARER" { + serverSupportsOAuthBearer = true + break + } + } + + if serverSupportsOAuthBearer && pgConn.config.OAuthTokenProvider != nil { + err = pgConn.oauthAuth(ctx) + } else { + err = pgConn.scramAuth(msg.AuthMechanisms) + } if err != nil { pgConn.conn.Close() return nil, newPerDialConnectError("failed SASL auth", err) @@ -444,6 +474,12 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo return pgConn, nil case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: // handled by ReceiveMessage + case *pgproto3.NegotiateProtocolVersion: + serverVersion := pgproto3.ProtocolVersion30&0xFFFF0000 | uint32(msg.NewestMinorProtocol) + if serverVersion < minProtocolVersion { + pgConn.conn.Close() + return nil, newPerDialConnectError("server protocol version too low", nil) + } case *pgproto3.ErrorResponse: pgConn.conn.Close() return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg)) @@ -576,6 +612,10 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { // receiveMessage receives a message without setting up context cancellation func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + if pgConn.status == connStatusClosed { + return nil, &connLockError{status: "conn closed"} + } + msg, err := pgConn.peekMessage() if err != nil { return nil, err @@ -633,7 +673,7 @@ func (pgConn *PgConn) TxStatus() byte { } // SecretKey returns the backend secret key used to send a cancel query message to the server. -func (pgConn *PgConn) SecretKey() uint32 { +func (pgConn *PgConn) SecretKey() []byte { return pgConn.secretKey } @@ -770,25 +810,20 @@ func NewCommandTag(s string) CommandTag { // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { - // Find last non-digit - idx := -1 + // Parse the number from the end in a single pass. + var n int64 + var mult int64 = 1 + for i := len(ct.s) - 1; i >= 0; i-- { - if ct.s[i] >= '0' && ct.s[i] <= '9' { - idx = i + c := ct.s[i] + if c >= '0' && c <= '9' { + n += int64(c-'0') * mult + mult *= 10 } else { break } } - if idx == -1 { - return 0 - } - - var n int64 - for _, b := range ct.s[idx:] { - n = n*10 + int64(b-'0') - } - return n } @@ -826,13 +861,15 @@ type FieldDescription struct { Format int16 } -func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription { - if cap(dst) >= len(rd.Fields) { - dst = dst[:len(rd.Fields):len(rd.Fields)] +func (pgConn *PgConn) getFieldDescriptionSlice(n int) []FieldDescription { + if cap(pgConn.fieldDescriptions) >= n { + return pgConn.fieldDescriptions[:n:n] } else { - dst = make([]FieldDescription, len(rd.Fields)) + return make([]FieldDescription, n) } +} +func convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) { for i := range rd.Fields { dst[i].Name = string(rd.Fields[i].Name) dst[i].TableOID = rd.Fields[i].TableOID @@ -842,8 +879,6 @@ func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3 dst[i].TypeModifier = rd.Fields[i].TypeModifier dst[i].Format = rd.Fields[i].Format } - - return dst } type StatementDescription struct { @@ -858,6 +893,10 @@ type StatementDescription struct { // // Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages // directly. +// +// In extremely rare cases, Prepare may fail after the Parse is successful, but before the Describe is complete. In this +// case, the returned error will be an error where errors.As with a *PrepareError succeeds and the *PrepareError has +// ParseComplete set to true. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, err @@ -885,7 +924,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd := &StatementDescription{Name: name, SQL: sql} - var parseErr error + var ParseComplete bool + var pgErr *PgError readloop: for { @@ -896,20 +936,23 @@ readloop: } switch msg := msg.(type) { + case *pgproto3.ParseComplete: + ParseComplete = true case *pgproto3.ParameterDescription: psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - psd.Fields = pgConn.convertRowDescription(nil, msg) + psd.Fields = make([]FieldDescription, len(msg.Fields)) + convertRowDescription(psd.Fields, msg) case *pgproto3.ErrorResponse: - parseErr = ErrorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } } - if parseErr != nil { - return nil, parseErr + if pgErr != nil { + return nil, &PrepareError{err: pgErr, ParseComplete: ParseComplete} } return psd, nil } @@ -1029,11 +1072,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { defer contextWatcher.Unwatch() } - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) + buf := make([]byte, 12+len(pgConn.secretKey)) + binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf))) binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], pgConn.pid) - binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey) + copy(buf[12:], pgConn.secretKey) if _, err := cancelConn.Write(buf); err != nil { return fmt.Errorf("write to connection for cancellation: %w", err) @@ -1150,7 +1193,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(result) + pgConn.execExtendedSuffix(result, nil, nil) return result } @@ -1175,7 +1218,36 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(result) + pgConn.execExtendedSuffix(result, nil, nil) + + return result +} + +// ExecStatement enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// This differs from ExecPrepared in that it takes a *StatementDescription instead of the prepared statement name. +// Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get +// the result column descriptions. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if len(paramFormats) is not +// 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or binary +// format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecStatement(ctx context.Context, statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + + pgConn.execExtendedSuffix(result, statementDescription, resultFormats) return result } @@ -1215,8 +1287,10 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { - pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) +func (pgConn *PgConn) execExtendedSuffix(result *ResultReader, statementDescription *StatementDescription, resultFormats []int16) { + if statementDescription == nil { + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + } pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) @@ -1230,7 +1304,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { return } - result.readUntilRowDescription() + result.readUntilRowDescription(statementDescription, resultFormats) } // CopyTo executes the copy command sql and copies the results to w. @@ -1322,10 +1396,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co copyErrChan := make(chan error, 1) signalMessageChan := pgConn.signalMessage() var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() + wg.Go(func() { buf := iobufpool.Get(65536) defer iobufpool.Put(buf) (*buf)[0] = 'd' @@ -1357,7 +1428,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co default: } } - }() + }) var pgErr error var copyErr error @@ -1433,6 +1504,10 @@ type MultiResultReader struct { rr *ResultReader + // Data from when the batch was queued. + statementDescriptions []*StatementDescription + resultFormats [][]int16 + closed bool err error } @@ -1474,6 +1549,39 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) // NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. func (mrr *MultiResultReader) NextResult() bool { for !mrr.closed && mrr.err == nil { + msg, _ := mrr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + if len(mrr.statementDescriptions) > 0 { + rr := ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + } + + // This result corresponds to a prepared statement description that was provided when queuing the batch. + sd := mrr.statementDescriptions[0] + mrr.statementDescriptions = mrr.statementDescriptions[1:] + + resultFormats := mrr.resultFormats[0] + mrr.resultFormats = mrr.resultFormats[1:] + + sdFields := sd.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) + + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) + } + + mrr.pgConn.resultReader = rr + mrr.rr = &mrr.pgConn.resultReader + return true + } + + mrr.err = fmt.Errorf("unexpected DataRow message without preceding RowDescription") + return false + } + msg, err := mrr.receiveMessage() if err != nil { return false @@ -1485,8 +1593,9 @@ func (mrr *MultiResultReader) NextResult() bool { pgConn: mrr.pgConn, multiResultReader: mrr, ctx: mrr.ctx, - fieldDescriptions: mrr.pgConn.convertRowDescription(mrr.pgConn.fieldDescriptions[:], msg), + fieldDescriptions: mrr.pgConn.getFieldDescriptionSlice(len(msg.Fields)), } + convertRowDescription(mrr.pgConn.resultReader.fieldDescriptions, msg) mrr.rr = &mrr.pgConn.resultReader return true @@ -1499,7 +1608,12 @@ func (mrr *MultiResultReader) NextResult() bool { mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.EmptyQueryResponse: - return false + mrr.pgConn.resultReader = ResultReader{ + commandConcluded: true, + closed: true, + } + mrr.rr = &mrr.pgConn.resultReader + return true } } @@ -1533,6 +1647,7 @@ type ResultReader struct { fieldDescriptions []FieldDescription rowValues [][]byte commandTag CommandTag + preloaded bool commandConcluded bool closed bool err error @@ -1574,6 +1689,11 @@ func (rr *ResultReader) Read() *Result { // NextRow advances the ResultReader to the next row and returns true if a row is available. func (rr *ResultReader) NextRow() bool { + if rr.preloaded { + rr.preloaded = false + return true + } + for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { @@ -1590,6 +1710,11 @@ func (rr *ResultReader) NextRow() bool { return false } +func (rr *ResultReader) preloadRowValues(values [][]byte) { + rr.rowValues = values + rr.preloaded = true +} + // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until // the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was // encountered.) @@ -1642,19 +1767,34 @@ func (rr *ResultReader) Close() (CommandTag, error) { // readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any // error will be stored in the ResultReader. -func (rr *ResultReader) readUntilRowDescription() { +func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementDescription, resultFormats []int16) { for !rr.commandConcluded { - // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. - // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are - // manually used to construct a query that does not issue a describe statement. - msg, _ := rr.pgConn.peekMessage() - if _, ok := msg.(*pgproto3.DataRow); ok { + msg, _ := rr.receiveMessage() + switch msg := msg.(type) { + case *pgproto3.RowDescription: return - } + case *pgproto3.DataRow: + rr.preloadRowValues(msg.Values) + if statementDescription != nil { + sdFields := statementDescription.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) - // Consume the message - msg, _ = rr.receiveMessage() - if _, ok := msg.(*pgproto3.RowDescription); ok { + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) + } + } + return + case *pgproto3.CommandComplete: + if statementDescription != nil { + sdFields := statementDescription.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) + + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) + } + } return } } @@ -1681,7 +1821,8 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error switch msg := msg.(type) { case *pgproto3.RowDescription: - rr.fieldDescriptions = rr.pgConn.convertRowDescription(rr.pgConn.fieldDescriptions[:], msg) + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(msg.Fields)) + convertRowDescription(rr.fieldDescriptions, msg) case *pgproto3.CommandComplete: rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: @@ -1715,8 +1856,10 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { - buf []byte - err error + buf []byte + statementDescriptions []*StatementDescription + resultFormats [][]int16 + err error } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. @@ -1754,6 +1897,30 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor } } +// ExecStatement appends an ExecStatement command to the batch. See PgConn.ExecPrepared for parameter descriptions. +// +// This differs from ExecPrepared in that it takes a *StatementDescription instead of just the prepared statement name. +// Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get +// the result column descriptions. +func (batch *Batch) ExecStatement(statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) { + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.statementDescriptions = append(batch.statementDescriptions, statementDescription) + batch.resultFormats = append(batch.resultFormats, resultFormats) + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } +} + // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing // multiple queries in a single round trip than using pipeline mode. @@ -1773,8 +1940,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR } pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, + pgConn: pgConn, + ctx: ctx, + statementDescriptions: batch.statementDescriptions, + resultFormats: batch.resultFormats, } multiResult := &pgConn.multiResultReader @@ -1799,9 +1968,11 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult } - pgConn.enterPotentialWriteReadDeadlock() - defer pgConn.exitPotentialWriteReadDeadlock() - _, err := pgConn.conn.Write(batch.buf) + _, err := func(buf []byte) (int, error) { + pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() + return pgConn.conn.Write(buf) + }(batch.buf) if err != nil { pgConn.contextWatcher.Unwatch() multiResult.err = normalizeTimeoutError(multiResult.ctx, err) @@ -1907,7 +2078,7 @@ func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { // // This should not be confused with the PostgreSQL protocol Sync message. func (pgConn *PgConn) SyncConn(ctx context.Context) error { - for i := 0; i < 10; i++ { + for range 10 { if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 { return nil } @@ -1935,7 +2106,7 @@ func (pgConn *PgConn) CustomData() map[string]any { type HijackedConn struct { Conn net.Conn PID uint32 // backend pid - SecretKey uint32 // key to use to send a cancel query message to the server + SecretKey []byte // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte Frontend *pgproto3.Frontend @@ -2007,9 +2178,10 @@ func Construct(hc *HijackedConn) (*PgConn, error) { // Pipeline represents a connection in pipeline mode. // -// SendPrepare, SendQueryParams, and SendQueryPrepared queue requests to the server. These requests are not written until -// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between -// synchronization points are implicitly transactional unless explicit transaction control statements have been issued. +// SendPrepare, SendQueryParams, SendQueryPrepared, and SendQueryStatement queue requests to the server. These requests +// are not written until pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. +// Requests between synchronization points are implicitly transactional unless explicit transaction control statements +// have been issued. // // The context the pipeline was started with is in effect for the entire life of the Pipeline. // @@ -2038,6 +2210,7 @@ const ( pipelinePrepare pipelineQueryParams pipelineQueryPrepared + pipelineQueryStatement pipelineDeallocate pipelineSyncRequest pipelineFlushRequest @@ -2051,6 +2224,8 @@ type pipelineRequestEvent struct { type pipelineState struct { requestEventQueue list.List + statementDescriptionsQueue list.List + resultFormatsQueue list.List lastRequestType pipelineRequestType pgErr *PgError expectedReadyForQueryCount int @@ -2058,6 +2233,8 @@ type pipelineState struct { func (s *pipelineState) Init() { s.requestEventQueue.Init() + s.statementDescriptionsQueue.Init() + s.resultFormatsQueue.Init() s.lastRequestType = pipelineNil } @@ -2122,6 +2299,29 @@ func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { } } +func (s *pipelineState) PushBackStatementData(sd *StatementDescription, resultFormats []int16) { + s.statementDescriptionsQueue.PushBack(sd) + s.resultFormatsQueue.PushBack(resultFormats) +} + +func (s *pipelineState) ExtractFrontStatementData() (*StatementDescription, []int16) { + sdElem := s.statementDescriptionsQueue.Front() + var sd *StatementDescription + if sdElem != nil { + s.statementDescriptionsQueue.Remove(sdElem) + sd = sdElem.Value.(*StatementDescription) + } + + rfElem := s.resultFormatsQueue.Front() + var resultFormats []int16 + if rfElem != nil { + s.resultFormatsQueue.Remove(rfElem) + resultFormats = rfElem.Value.([]int16) + } + + return sd, resultFormats +} + func (s *pipelineState) HandleError(err *PgError) { s.pgErr = err } @@ -2164,6 +2364,8 @@ func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { return pipeline } + pgConn.resultReader = ResultReader{closed: true} + pgConn.pipeline = Pipeline{ conn: pgConn, ctx: ctx, @@ -2208,7 +2410,7 @@ func (p *Pipeline) SendDeallocate(name string) { p.state.PushBackRequestType(pipelineDeallocate) } -// SendQueryParams is the pipeline version of *PgConn.QueryParams. +// SendQueryParams is the pipeline version of *PgConn.ExecParams. func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { if p.closed { return @@ -2221,7 +2423,7 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [ p.state.PushBackRequestType(pipelineQueryParams) } -// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. +// SendQueryPrepared is the pipeline version of *PgConn.ExecPrepared. func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { if p.closed { return @@ -2233,6 +2435,18 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para p.state.PushBackRequestType(pipelineQueryPrepared) } +// SendQueryStatement is the pipeline version of *PgConn.ExecStatement. +func (p *Pipeline) SendQueryStatement(statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) { + if p.closed { + return + } + + p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryStatement) + p.state.PushBackStatementData(statementDescription, resultFormats) +} + // SendFlushRequest sends a request for the server to flush its output buffer. // // The server flushes its output buffer automatically as a result of Sync being called, @@ -2307,99 +2521,315 @@ func (p *Pipeline) GetResults() (results any, err error) { return nil, errors.New("pipeline closed") } - if p.state.ExtractFrontRequestType() == pipelineNil { - return nil, nil - } - return p.getResults() } func (p *Pipeline) getResults() (results any, err error) { - for { - msg, err := p.conn.receiveMessage() + if !p.conn.resultReader.closed { + _, err := p.conn.resultReader.Close() if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) + return nil, err } + } - switch msg := msg.(type) { - case *pgproto3.RowDescription: - p.conn.resultReader = ResultReader{ - pgConn: p.conn, - pipeline: p, - ctx: p.ctx, - fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg), - } - return &p.conn.resultReader, nil - case *pgproto3.CommandComplete: - p.conn.resultReader = ResultReader{ - commandTag: p.conn.makeCommandTag(msg.CommandTag), - commandConcluded: true, - closed: true, - } - return &p.conn.resultReader, nil - case *pgproto3.ParseComplete: - peekedMsg, err := p.conn.peekMessage() - if err != nil { - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) - } - if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { - return p.getResultsPrepare() - } - case *pgproto3.CloseComplete: - return &CloseComplete{}, nil - case *pgproto3.ReadyForQuery: - p.state.HandleReadyForQuery() - return &PipelineSync{}, nil - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - p.state.HandleError(pgErr) - return nil, pgErr - } + currentRequestType := p.state.ExtractFrontRequestType() + switch currentRequestType { + case pipelineNil: + return nil, nil + case pipelinePrepare: + return p.getResultsPrepare() + case pipelineQueryParams: + return p.getResultsQueryParams() + case pipelineQueryPrepared: + return p.getResultsQueryPrepared() + case pipelineQueryStatement: + return p.getResultsQueryStatement() + case pipelineDeallocate: + return p.getResultsDeallocate() + case pipelineSyncRequest: + return p.getResultsSync() + case pipelineFlushRequest: + return nil, errors.New("BUG: pipelineFlushRequest should not be in request queue") + default: + return nil, errors.New("BUG: unknown pipeline request type") } } func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { + err := p.receiveParseComplete("Prepare") + if err != nil { + return nil, err + } + psd := &StatementDescription{} + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Prepare ParameterDescription", msg) + } + + msg, err = p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + psd.Fields = make([]FieldDescription, len(msg.Fields)) + convertRowDescription(psd.Fields, msg) + return psd, nil + + // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING + // clause. + case *pgproto3.NoData: + return psd, nil + + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Prepare RowDescription", msg) + } +} + +func (p *Pipeline) getResultsQueryParams() (*ResultReader, error) { + err := p.receiveParseComplete("QueryParams") + if err != nil { + return nil, err + } + + err = p.receiveBindComplete("QueryParams") + if err != nil { + return nil, err + } + + return p.receiveDescribedResultReader("QueryParams") +} + +func (p *Pipeline) getResultsQueryPrepared() (*ResultReader, error) { + err := p.receiveBindComplete("QueryPrepared") + if err != nil { + return nil, err + } + + return p.receiveDescribedResultReader("QueryPrepared") +} + +func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { + err := p.receiveBindComplete("QueryStatement") + if err != nil { + return nil, err + } + + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + sd, resultFormats := p.state.ExtractFrontStatementData() + if sd == nil { + return nil, errors.New("BUG: missing statement description or result formats for QueryStatement") + } + sdFields := sd.Fields + fieldDescriptions := p.conn.getFieldDescriptionSlice(len(sdFields)) + err = combineFieldDescriptionsAndResultFormats(fieldDescriptions, sdFields, resultFormats) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr := ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: fieldDescriptions, + } + rr.preloadRowValues(msg.Values) + p.conn.resultReader = rr + return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + fieldDescriptions: fieldDescriptions, + } + return &p.conn.resultReader, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("QueryStatement", msg) + } +} + +func (p *Pipeline) getResultsDeallocate() (*CloseComplete, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.CloseComplete: + return &CloseComplete{}, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Deallocate", msg) + } +} + +func (p *Pipeline) getResultsSync() (*PipelineSync, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + p.state.HandleReadyForQuery() + return &PipelineSync{}, nil + case *pgproto3.ErrorResponse: + // Error message that is received while expecting a Sync message still consumes the expected Sync. Put it back. + p.state.requestEventQueue.PushFront(pipelineRequestEvent{RequestType: pipelineSyncRequest, WasSentToServer: true, BeforeFlushOrSync: true}) + + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Sync", msg) + } +} + +func (p *Pipeline) receiveParseComplete(errStr string) error { + msg, err := p.receiveMessage() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ParseComplete: + return nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return pgErr + default: + return p.handleUnexpectedMessage(fmt.Sprintf("%s Parse", errStr), msg) + } +} + +func (p *Pipeline) receiveBindComplete(errStr string) error { + msg, err := p.receiveMessage() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.BindComplete: + return nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return pgErr + default: + return p.handleUnexpectedMessage(fmt.Sprintf("%s Bind", errStr), msg) + } +} + +func (p *Pipeline) receiveDescribedResultReader(errStr string) (*ResultReader, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), + } + convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) + return &p.conn.resultReader, nil + case *pgproto3.NoData: + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage(fmt.Sprintf("%s RowDescription or NoData", errStr), msg) + } + + msg, err = p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return &p.conn.resultReader, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage(fmt.Sprintf("%s CommandComplete", errStr), msg) + } +} + +func (p *Pipeline) receiveMessage() (pgproto3.BackendMessage, error) { for { msg, err := p.conn.receiveMessage() if err != nil { + p.err = err p.conn.asyncClose() return nil, normalizeTimeoutError(p.ctx, err) } switch msg := msg.(type) { - case *pgproto3.ParameterDescription: - psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) - copy(psd.ParamOIDs, msg.ParameterOIDs) - case *pgproto3.RowDescription: - psd.Fields = p.conn.convertRowDescription(nil, msg) - return psd, nil - - // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING - // clause. - case *pgproto3.NoData: - return psd, nil - - // These should never happen here. But don't take chances that could lead to a deadlock. - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - p.state.HandleError(pgErr) - return nil, pgErr - case *pgproto3.CommandComplete: - p.conn.asyncClose() - return nil, errors.New("BUG: received CommandComplete while handling Describe") - case *pgproto3.ReadyForQuery: - p.conn.asyncClose() - return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse, *pgproto3.NotificationResponse: + // Filter these message types out in pipeline mode. The normal processing is handled by PgConn.receiveMessage. + default: + return msg, nil } } } +func (p *Pipeline) handleUnexpectedMessage(errStr string, msg pgproto3.BackendMessage) error { + p.err = fmt.Errorf("pipeline: %s: received unexpected message type %T", errStr, msg) + p.conn.asyncClose() + return p.err +} + // Close closes the pipeline and returns the connection to normal mode. func (p *Pipeline) Close() error { if p.closed { @@ -2418,7 +2848,7 @@ func (p *Pipeline) Close() error { } for p.state.ExpectedReadyForQuery() > 0 { - _, err := p.getResults() + results, err := p.getResults() if err != nil { p.err = err var pgErr *PgError @@ -2426,6 +2856,15 @@ func (p *Pipeline) Close() error { p.conn.asyncClose() break } + } else if results == nil { + // getResults returns (nil, nil) when the request queue is exhausted but + // ExpectedReadyForQuery is still > 0. This can happen when FATAL errors consume + // queued request slots without the server ever sending ReadyForQuery. + p.conn.asyncClose() + if p.err == nil { + p.err = errors.New("pipeline: no more results but expected ReadyForQuery") + } + break } } @@ -2502,3 +2941,32 @@ func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() { h.Conn.conn.SetDeadline(time.Time{}) } + +func combineFieldDescriptionsAndResultFormats(outputFields, inputFields []FieldDescription, resultFormats []int16) error { + switch { + case len(resultFormats) == 0: + // No format codes provided means text format for all columns. + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = pgtype.TextFormatCode + } + case len(resultFormats) == 1: + // Single format code applies to all columns. + format := resultFormats[0] + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = format + } + case len(resultFormats) == len(inputFields): + // One format code per column. + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = resultFormats[i] + } + default: + // This should not occur if Bind validation is correct, but handle gracefully + return fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(inputFields)) + } + + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go index e66580f..69e2282 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go @@ -33,6 +33,7 @@ func (dst *AuthenticationSASL) Decode(src []byte) error { return errors.New("bad auth type") } + dst.AuthMechanisms = dst.AuthMechanisms[:0] authMechanisms := src[4:] for len(authMechanisms) > 1 { idx := bytes.IndexByte(authMechanisms, 0) diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go index 28cff04..65388ad 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go @@ -46,8 +46,8 @@ type Backend struct { } const ( - minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. - maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10_000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. ) // NewBackend creates a new Backend. @@ -123,7 +123,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { if err != nil { return nil, err } - msgSize := int(binary.BigEndian.Uint32(buf) - 4) + msgSize := int(int32(binary.BigEndian.Uint32(buf)) - 4) if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) @@ -137,7 +137,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { code := binary.BigEndian.Uint32(buf) switch code { - case ProtocolVersionNumber: + case ProtocolVersion30, ProtocolVersion32: err = b.startupMessage.Decode(buf) if err != nil { return nil, err @@ -176,7 +176,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { b.msgType = header[0] - msgLength := int(binary.BigEndian.Uint32(header[1:])) + msgLength := int(int32(binary.BigEndian.Uint32(header[1:]))) if msgLength < 4 { return nil, fmt.Errorf("invalid message length: %d", msgLength) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go b/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go index 23f5da6..c73b2da 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/hex" "encoding/json" "github.com/jackc/pgx/v5/internal/pgio" @@ -9,7 +10,7 @@ import ( type BackendKeyData struct { ProcessID uint32 - SecretKey uint32 + SecretKey []byte } // Backend identifies this message as sendable by the PostgreSQL backend. @@ -18,12 +19,13 @@ func (*BackendKeyData) Backend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *BackendKeyData) Decode(src []byte) error { - if len(src) != 8 { + if len(src) < 8 { return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} } dst.ProcessID = binary.BigEndian.Uint32(src[:4]) - dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + dst.SecretKey = make([]byte, len(src)-4) + copy(dst.SecretKey, src[4:]) return nil } @@ -32,7 +34,7 @@ func (dst *BackendKeyData) Decode(src []byte) error { func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) - dst = pgio.AppendUint32(dst, src.SecretKey) + dst = append(dst, src.SecretKey...) return finishMessage(dst, sp) } @@ -41,10 +43,29 @@ func (src BackendKeyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProcessID uint32 - SecretKey uint32 + SecretKey string }{ Type: "BackendKeyData", ProcessID: src.ProcessID, - SecretKey: src.SecretKey, + SecretKey: hex.EncodeToString(src.SecretKey), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *BackendKeyData) UnmarshalJSON(data []byte) error { + var msg struct { + ProcessID uint32 + SecretKey string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.ProcessID = msg.ProcessID + secretKey, err := hex.DecodeString(msg.SecretKey) + if err != nil { + return err + } + dst.SecretKey = secretKey + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go b/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go index ad6ac48..fb56e4d 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go @@ -54,7 +54,7 @@ func (dst *Bind) Decode(src []byte) error { if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { return &invalidMessageFormatErr{messageType: "Bind"} } - for i := 0; i < parameterFormatCodeCount; i++ { + for i := range parameterFormatCodeCount { dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 } @@ -69,7 +69,7 @@ func (dst *Bind) Decode(src []byte) error { if parameterCount > 0 { dst.Parameters = make([][]byte, parameterCount) - for i := 0; i < parameterCount; i++ { + for i := range parameterCount { if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "Bind"} } @@ -82,7 +82,7 @@ func (dst *Bind) Decode(src []byte) error { continue } - if len(src[rp:]) < msgSize { + if msgSize < 0 || len(src[rp:]) < msgSize { return &invalidMessageFormatErr{messageType: "Bind"} } @@ -101,7 +101,7 @@ func (dst *Bind) Decode(src []byte) error { if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { return &invalidMessageFormatErr{messageType: "Bind"} } - for i := 0; i < resultFormatCodeCount; i++ { + for i := range resultFormatCodeCount { dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go b/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go index 6b52dd9..63ebe5c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/hex" "encoding/json" "errors" @@ -12,35 +13,42 @@ const cancelRequestCode = 80877102 type CancelRequest struct { ProcessID uint32 - SecretKey uint32 + SecretKey []byte } // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*CancelRequest) Frontend() {} func (dst *CancelRequest) Decode(src []byte) error { - if len(src) != 12 { - return errors.New("bad cancel request size") + if len(src) < 12 { + return errors.New("cancel request too short") + } + if len(src) > 264 { + return errors.New("cancel request too long") } requestCode := binary.BigEndian.Uint32(src) - if requestCode != cancelRequestCode { return errors.New("bad cancel request code") } dst.ProcessID = binary.BigEndian.Uint32(src[4:]) - dst.SecretKey = binary.BigEndian.Uint32(src[8:]) + dst.SecretKey = make([]byte, len(src)-8) + copy(dst.SecretKey, src[8:]) return nil } // Encode encodes src into dst. dst will include the 4 byte message length. func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { - dst = pgio.AppendInt32(dst, 16) + if len(src.SecretKey) > 256 { + return nil, errors.New("secret key too long") + } + msgLen := int32(12 + len(src.SecretKey)) + dst = pgio.AppendInt32(dst, msgLen) dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendUint32(dst, src.ProcessID) - dst = pgio.AppendUint32(dst, src.SecretKey) + dst = append(dst, src.SecretKey...) return dst, nil } @@ -49,10 +57,29 @@ func (src CancelRequest) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProcessID uint32 - SecretKey uint32 + SecretKey string }{ Type: "CancelRequest", ProcessID: src.ProcessID, - SecretKey: src.SecretKey, + SecretKey: hex.EncodeToString(src.SecretKey), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CancelRequest) UnmarshalJSON(data []byte) error { + var msg struct { + ProcessID uint32 + SecretKey string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.ProcessID = msg.ProcessID + secretKey, err := hex.DecodeString(msg.SecretKey) + if err != nil { + return err + } + dst.SecretKey = secretKey + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go index 99e1afe..e2a402f 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go @@ -35,7 +35,7 @@ func (dst *CopyBothResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go index 72a85fd..f8a00b8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go @@ -15,6 +15,10 @@ func (*CopyFail) Frontend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *CopyFail) Decode(src []byte) error { + if len(src) == 0 { + return &invalidMessageFormatErr{messageType: "CopyFail"} + } + idx := bytes.IndexByte(src, 0) if idx != len(src)-1 { return &invalidMessageFormatErr{messageType: "CopyFail"} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go index 06cf99c..0633935 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go @@ -35,7 +35,7 @@ func (dst *CopyInResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go index 549e916..006864a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go @@ -34,7 +34,7 @@ func (dst *CopyOutResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go b/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go index fdfb0f7..54418d5 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go @@ -31,16 +31,13 @@ func (dst *DataRow) Decode(src []byte) error { // large reallocate. This is too avoid one row with many columns from // permanently allocating memory. if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { - newCap := 32 - if newCap < fieldCount { - newCap = fieldCount - } + newCap := max(32, fieldCount) dst.Values = make([][]byte, fieldCount, newCap) } else { dst.Values = dst.Values[:fieldCount] } - for i := 0; i < fieldCount; i++ { + for i := range fieldCount { if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "DataRow"} } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go index 056e547..3d66518 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go @@ -52,6 +52,7 @@ type Frontend struct { readyForQuery ReadyForQuery rowDescription RowDescription portalSuspended PortalSuspended + negotiateProtocolVersion NegotiateProtocolVersion bodyLen int maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. @@ -230,7 +231,7 @@ func (f *Frontend) SendExecute(msg *Execute) { f.wbuf = newBuf if f.tracer != nil { - f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) + f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg) } } @@ -312,7 +313,7 @@ func (f *Frontend) Receive() (BackendMessage, error) { f.msgType = header[0] - msgLength := int(binary.BigEndian.Uint32(header[1:])) + msgLength := int(int32(binary.BigEndian.Uint32(header[1:]))) if msgLength < 4 { return nil, fmt.Errorf("invalid message length: %d", msgLength) } @@ -383,6 +384,8 @@ func (f *Frontend) Receive() (BackendMessage, error) { msg = &f.copyBothResponse case 'Z': msg = &f.readyForQuery + case 'v': + msg = &f.negotiateProtocolVersion default: return nil, fmt.Errorf("unknown message type: %c", f.msgType) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go index 7d83579..23bbd8b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go @@ -23,6 +23,11 @@ func (*FunctionCall) Frontend() {} func (dst *FunctionCall) Decode(src []byte) error { *dst = FunctionCall{} rp := 0 + + if len(src) < 8 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + // Specifies the object ID of the function to call. dst.Function = binary.BigEndian.Uint32(src[rp:]) rp += 4 @@ -32,8 +37,13 @@ func (dst *FunctionCall) Decode(src []byte) error { // or it can equal the actual number of arguments. nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 + + if len(src[rp:]) < nArgumentCodes*2+2 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + argumentCodes := make([]uint16, nArgumentCodes) - for i := 0; i < nArgumentCodes; i++ { + for i := range nArgumentCodes { // The argument format codes. Each must presently be zero (text) or one (binary). ac := binary.BigEndian.Uint16(src[rp:]) if ac != 0 && ac != 1 { @@ -48,14 +58,22 @@ func (dst *FunctionCall) Decode(src []byte) error { nArguments := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 arguments := make([][]byte, nArguments) - for i := 0; i < nArguments; i++ { + for i := range nArguments { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } // The length of the argument value, in bytes (this count does not include itself). Can be zero. // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. - argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + argumentLength := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 if argumentLength == -1 { arguments[i] = nil + } else if argumentLength < 0 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} } else { + if len(src[rp:]) < argumentLength { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } // The value of the argument, in the format indicated by the associated format code. n is the above length. argumentValue := src[rp : rp+argumentLength] rp += argumentLength @@ -64,6 +82,9 @@ func (dst *FunctionCall) Decode(src []byte) error { } dst.Arguments = arguments // The format code for the function result. Must presently be zero (text) or one (binary). + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } resultFormatCode := binary.BigEndian.Uint16(src[rp:]) if resultFormatCode != 0 && resultFormatCode != 1 { return &invalidMessageFormatErr{messageType: "FunctionCall"} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go index 1f27349..6b6ed8b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go @@ -22,7 +22,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } rp := 0 - resultSize := int(binary.BigEndian.Uint32(src[rp:])) + resultSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 if resultSize == -1 { @@ -30,7 +30,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { return nil } - if len(src[rp:]) != resultSize { + if resultSize < 0 || len(src[rp:]) != resultSize { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go b/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go new file mode 100644 index 0000000..43bd7ec --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go @@ -0,0 +1,93 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type NegotiateProtocolVersion struct { + NewestMinorProtocol uint32 + UnrecognizedOptions []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NegotiateProtocolVersion) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NegotiateProtocolVersion) Decode(src []byte) error { + if len(src) < 8 { + return &invalidMessageLenErr{messageType: "NegotiateProtocolVersion", expectedLen: 8, actualLen: len(src)} + } + + dst.NewestMinorProtocol = binary.BigEndian.Uint32(src[:4]) + optionCount := int(binary.BigEndian.Uint32(src[4:8])) + + rp := 8 + + // Use the remaining message size as an upper bound for capacity to prevent + // malicious optionCount values from causing excessive memory allocation. + capHint := optionCount + if remaining := len(src) - rp; capHint > remaining { + capHint = remaining + } + dst.UnrecognizedOptions = make([]string, 0, capHint) + for i := 0; i < optionCount; i++ { + if rp >= len(src) { + return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"} + } + end := rp + for end < len(src) && src[end] != 0 { + end++ + } + if end >= len(src) { + return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"} + } + dst.UnrecognizedOptions = append(dst.UnrecognizedOptions, string(src[rp:end])) + rp = end + 1 + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NegotiateProtocolVersion) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'v') + dst = pgio.AppendUint32(dst, src.NewestMinorProtocol) + dst = pgio.AppendUint32(dst, uint32(len(src.UnrecognizedOptions))) + for _, option := range src.UnrecognizedOptions { + dst = append(dst, option...) + dst = append(dst, 0) + } + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src NegotiateProtocolVersion) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + NewestMinorProtocol uint32 + UnrecognizedOptions []string + }{ + Type: "NegotiateProtocolVersion", + NewestMinorProtocol: src.NewestMinorProtocol, + UnrecognizedOptions: src.UnrecognizedOptions, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *NegotiateProtocolVersion) UnmarshalJSON(data []byte) error { + var msg struct { + NewestMinorProtocol uint32 + UnrecognizedOptions []string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.NewestMinorProtocol = msg.NewestMinorProtocol + dst.UnrecognizedOptions = msg.UnrecognizedOptions + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go b/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go index 1ef27b7..58eb26e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go @@ -33,7 +33,7 @@ func (dst *ParameterDescription) Decode(src []byte) error { *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} - for i := 0; i < parameterCount; i++ { + for i := range parameterCount { dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go b/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go index 6ba3486..8fb8de5 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go @@ -43,7 +43,7 @@ func (dst *Parse) Decode(src []byte) error { } parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) - for i := 0; i < parameterOIDCount; i++ { + for range parameterOIDCount { if buf.Len() < 4 { return &invalidMessageFormatErr{messageType: "Parse"} } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/query.go b/vendor/github.com/jackc/pgx/v5/pgproto3/query.go index aebdfde..9e16465 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/query.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/query.go @@ -15,6 +15,10 @@ func (*Query) Frontend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *Query) Decode(src []byte) error { + if len(src) == 0 { + return &invalidMessageFormatErr{messageType: "Query"} + } + i := bytes.IndexByte(src, 0) if i != len(src)-1 { return &invalidMessageFormatErr{messageType: "Query"} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go b/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go index c40a226..b46f510 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go @@ -64,7 +64,7 @@ func (dst *RowDescription) Decode(src []byte) error { dst.Fields = dst.Fields[0:0] - for i := 0; i < fieldCount; i++ { + for range fieldCount { var fd FieldDescription idx := bytes.IndexByte(src[rp:], 0) diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go index 9eb1b6a..123f3cd 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go @@ -32,6 +32,9 @@ func (dst *SASLInitialResponse) Decode(src []byte) error { dst.AuthMechanism = string(src[rp:idx]) rp = idx + 1 + if len(src[rp:]) < 4 { + return errors.New("invalid SASLInitialResponse") + } rp += 4 // The rest of the message is data so we can just skip the size dst.Data = src[rp:] diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go b/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go index 3af4587..6caab3e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go @@ -10,7 +10,11 @@ import ( "github.com/jackc/pgx/v5/internal/pgio" ) -const ProtocolVersionNumber = 196608 // 3.0 +const ( + ProtocolVersion30 = 196608 // 3.0 + ProtocolVersion32 = 196610 // 3.2 + ProtocolVersionNumber = ProtocolVersion30 // Default is still 3.0 +) type StartupMessage struct { ProtocolVersion uint32 @@ -30,8 +34,8 @@ func (dst *StartupMessage) Decode(src []byte) error { dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 - if dst.ProtocolVersion != ProtocolVersionNumber { - return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + if dst.ProtocolVersion != ProtocolVersion30 && dst.ProtocolVersion != ProtocolVersion32 { + return fmt.Errorf("Bad startup message version number. Expected %d or %d, got %d", ProtocolVersion30, ProtocolVersion32, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go b/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go index 6cc7d3e..2f9da62 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go @@ -82,7 +82,7 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { case *ErrorResponse: t.traceErrorResponse(sender, encodedLen, msg) case *Execute: - t.TraceQueryute(sender, encodedLen, msg) + t.traceExecute(sender, encodedLen, msg) case *Flush: t.traceFlush(sender, encodedLen, msg) case *FunctionCall: @@ -260,7 +260,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes t.writeTrace(sender, encodedLen, "ErrorResponse", nil) } -func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { +func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) { t.writeTrace(sender, encodedLen, "Execute", func() { fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) }) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array.go b/vendor/github.com/jackc/pgx/v5/pgtype/array.go index 872a088..10b96e7 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array.go @@ -38,6 +38,10 @@ func cardinality(dimensions []ArrayDimension) int { elementCount *= int(d.Length) } + if elementCount < 0 { + return 0 + } + return elementCount } @@ -51,16 +55,20 @@ func (dst *arrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { numDims := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 + if numDims > 6 { + return 0, fmt.Errorf("array has too many dimensions: %d", numDims) + } + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 dst.ElementOID = binary.BigEndian.Uint32(src[rp:]) rp += 4 - dst.Dimensions = make([]ArrayDimension, numDims) if len(src) < 12+numDims*8 { return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } + dst.Dimensions = make([]ArrayDimension, numDims) for i := range dst.Dimensions { dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -299,7 +307,7 @@ func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { return "", false, err } case '"': - r, _, err = buf.ReadRune() + _, _, err = buf.ReadRune() if err != nil { return "", false, err } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go b/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go index bf5f698..f6b36f4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go @@ -118,7 +118,7 @@ func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, var encodePlan EncodePlan var lastElemType reflect.Type inElemBuf := make([]byte, 0, 32) - for i := 0; i < elementCount; i++ { + for i := range elementCount { if i > 0 { buf = append(buf, ',') } @@ -131,7 +131,7 @@ func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, elem := array.Index(i) var elemBuf []byte - if elem != nil { + if isNil, _ := isNilDriverValuer(elem); !isNil { elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType @@ -189,13 +189,13 @@ func (p *encodePlanArrayCodecBinary) Encode(value any, buf []byte) (newBuf []byt var encodePlan EncodePlan var lastElemType reflect.Type - for i := 0; i < elementCount; i++ { + for i := range elementCount { sp := len(buf) buf = pgio.AppendInt32(buf, -1) elem := array.Index(i) var elemBuf []byte - if elem != nil { + if isNil, _ := isNilDriverValuer(elem); !isNil { elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType @@ -270,7 +270,7 @@ func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array Arr elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) } - for i := 0; i < elementCount; i++ { + for i := range elementCount { elem := array.ScanIndex(i) elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -388,7 +388,7 @@ func isRagged(slice reflect.Value) bool { sliceLen := slice.Len() innerLen := 0 - for i := 0; i < sliceLen; i++ { + for i := range sliceLen { if i == 0 { innerLen = slice.Index(i).Len() } else { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go b/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go index 8496442..126e0be 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go @@ -892,7 +892,7 @@ func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type sliceLen := int(dimensions[0].Length) slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen) - for i := 0; i < sliceLen; i++ { + for i := range sliceLen { subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length))) slice.Index(i).Set(subSlice) } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/composite.go b/vendor/github.com/jackc/pgx/v5/pgtype/composite.go index 598cf7a..4667036 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/composite.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/composite.go @@ -289,7 +289,7 @@ type CompositeBinaryScanner struct { err error } -// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +// NewCompositeBinaryScanner a scanner over a binary encoded composite value. func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner { rp := 0 if len(src[rp:]) < 4 { @@ -476,7 +476,7 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) { return } - if field == nil { + if isNil, _ := isNilDriverValuer(field); isNil { b.buf = pgio.AppendUint32(b.buf, oid) b.buf = pgio.AppendInt32(b.buf, -1) b.fieldCount++ @@ -533,7 +533,7 @@ func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) { return } - if field == nil { + if isNil, _ := isNilDriverValuer(field); isNil { b.buf = append(b.buf, ',') return } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go index 8a9cee9..5cfc0ea 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go @@ -90,19 +90,19 @@ func GetAssignToDstType(dst any) (any, bool) { func init() { kindTypes = map[reflect.Kind]reflect.Type{ - reflect.Bool: reflect.TypeOf(false), - reflect.Float32: reflect.TypeOf(float32(0)), - reflect.Float64: reflect.TypeOf(float64(0)), - reflect.Int: reflect.TypeOf(int(0)), - reflect.Int8: reflect.TypeOf(int8(0)), - reflect.Int16: reflect.TypeOf(int16(0)), - reflect.Int32: reflect.TypeOf(int32(0)), - reflect.Int64: reflect.TypeOf(int64(0)), - reflect.Uint: reflect.TypeOf(uint(0)), - reflect.Uint8: reflect.TypeOf(uint8(0)), - reflect.Uint16: reflect.TypeOf(uint16(0)), - reflect.Uint32: reflect.TypeOf(uint32(0)), - reflect.Uint64: reflect.TypeOf(uint64(0)), - reflect.String: reflect.TypeOf(""), + reflect.Bool: reflect.TypeFor[bool](), + reflect.Float32: reflect.TypeFor[float32](), + reflect.Float64: reflect.TypeFor[float64](), + reflect.Int: reflect.TypeFor[int](), + reflect.Int8: reflect.TypeFor[int8](), + reflect.Int16: reflect.TypeFor[int16](), + reflect.Int32: reflect.TypeFor[int32](), + reflect.Int64: reflect.TypeFor[int64](), + reflect.Uint: reflect.TypeFor[uint](), + reflect.Uint8: reflect.TypeFor[uint8](), + reflect.Uint16: reflect.TypeFor[uint16](), + reflect.Uint32: reflect.TypeFor[uint32](), + reflect.Uint64: reflect.TypeFor[uint64](), + reflect.String: reflect.TypeFor[string](), } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/date.go b/vendor/github.com/jackc/pgx/v5/pgtype/date.go index 4470568..68c9585 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/date.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/date.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "encoding/json" "fmt" - "regexp" "strconv" "time" @@ -271,8 +270,6 @@ func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst any) error { type scanPlanTextAnyToDateScanner struct{} -var dateRegexp = regexp.MustCompile(`^(\d{4,})-(\d\d)-(\d\d)( BC)?$`) - func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error { scanner := (dst).(DateScanner) @@ -280,41 +277,104 @@ func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error { return scanner.ScanDate(Date{}) } - sbuf := string(src) - match := dateRegexp.FindStringSubmatch(sbuf) - if match != nil { - year, err := strconv.ParseInt(match[1], 10, 32) - if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (year): %w", err) - } + // Check infinity cases first + if len(src) == 8 && string(src) == "infinity" { + return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) + } + if len(src) == 9 && string(src) == "-infinity" { + return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) + } - month, err := strconv.ParseInt(match[2], 10, 32) - if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err) - } + // Format: YYYY-MM-DD or YYYY...-MM-DD BC + // Minimum: 10 chars (2000-01-01), with BC: 13 chars + if len(src) < 10 { + return fmt.Errorf("invalid date format") + } - day, err := strconv.ParseInt(match[3], 10, 32) - if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err) - } + // Check for BC suffix + bc := false + datePart := src + if len(src) >= 13 && string(src[len(src)-3:]) == " BC" { + bc = true + datePart = src[:len(src)-3] + } - // BC matched - if len(match[4]) > 0 { - year = -year + 1 + // Find year-month separator (first dash after at least 4 digits) + yearEnd := -1 + for i := 4; i < len(datePart); i++ { + if datePart[i] == '-' { + yearEnd = i + break + } + if datePart[i] < '0' || datePart[i] > '9' { + return fmt.Errorf("invalid date format") } + } + if yearEnd == -1 || yearEnd+6 > len(datePart) { + return fmt.Errorf("invalid date format") + } - t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC) - return scanner.ScanDate(Date{Time: t, Valid: true}) + // Validate: -MM-DD structure after year + if datePart[yearEnd+3] != '-' { + return fmt.Errorf("invalid date format") } - switch sbuf { - case "infinity": - return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) - case "-infinity": - return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) - default: + // Parse year + year, err := parseDigits(datePart[:yearEnd]) + if err != nil { + return fmt.Errorf("invalid date format") + } + + // Parse month (2 digits) + month, err := parse2Digits(datePart[yearEnd+1 : yearEnd+3]) + if err != nil { + return fmt.Errorf("invalid date format") + } + + // Parse day (2 digits) + day, err := parse2Digits(datePart[yearEnd+4 : yearEnd+6]) + if err != nil { + return fmt.Errorf("invalid date format") + } + + // Ensure nothing extra after day + if yearEnd+6 != len(datePart) { return fmt.Errorf("invalid date format") } + + if bc { + year = -year + 1 + } + + t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC) + return scanner.ScanDate(Date{Time: t, Valid: true}) +} + +// parse2Digits parses exactly 2 ASCII digits. +func parse2Digits(b []byte) (int64, error) { + if len(b) != 2 { + return 0, fmt.Errorf("expected 2 digits") + } + d1, d2 := b[0], b[1] + if d1 < '0' || d1 > '9' || d2 < '0' || d2 > '9' { + return 0, fmt.Errorf("expected digits") + } + return int64(d1-'0')*10 + int64(d2-'0'), nil +} + +// parseDigits parses a sequence of ASCII digits. +func parseDigits(b []byte) (int64, error) { + if len(b) == 0 { + return 0, fmt.Errorf("empty") + } + var n int64 + for _, c := range b { + if c < '0' || c > '9' { + return 0, fmt.Errorf("non-digit") + } + n = n*10 + int64(c-'0') + } + return n, nil } func (c DateCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go index ef86492..c5fa22c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go @@ -198,17 +198,24 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += uint32Len + if pairCount < 0 { + return fmt.Errorf("hstore invalid pair count: %d", pairCount) + } + hstore := make(Hstore, pairCount) // one allocation for all *string, rather than one per string, just like text parsing valueStrings := make([]string, pairCount) - for i := 0; i < pairCount; i++ { + for i := range pairCount { if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += uint32Len + if keyLen < 0 { + return fmt.Errorf("hstore invalid key length: %d", keyLen) + } if len(src[rp:]) < keyLen { return fmt.Errorf("hstore incomplete %v", src) } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/int.go b/vendor/github.com/jackc/pgx/v5/pgtype/int.go index d1b8eb6..95032e5 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/int.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/int.go @@ -78,7 +78,7 @@ func (dst *Int2) Scan(src any) error { } if n < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) + return fmt.Errorf("%d is less than minimum value for Int2", n) } if n > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", n) @@ -641,7 +641,7 @@ func (dst *Int4) Scan(src any) error { } if n < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n) + return fmt.Errorf("%d is less than minimum value for Int4", n) } if n > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", n) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/interval.go b/vendor/github.com/jackc/pgx/v5/pgtype/interval.go index ba5e818..b1bc785 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/interval.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/interval.go @@ -11,7 +11,7 @@ import ( ) const ( - microsecondsPerSecond = 1000000 + microsecondsPerSecond = 1_000_000 microsecondsPerMinute = 60 * microsecondsPerSecond microsecondsPerHour = 60 * microsecondsPerMinute microsecondsPerDay = 24 * microsecondsPerHour @@ -223,6 +223,8 @@ func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { months += int32(scalar) case "day", "days": days = int32(scalar) + default: + return fmt.Errorf("bad interval format: %q", parts[i+1]) } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/json.go b/vendor/github.com/jackc/pgx/v5/pgtype/json.go index 60aa2b7..bf70735 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/json.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/json.go @@ -157,7 +157,7 @@ func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, d case BytesScanner: return &scanPlanBinaryBytesToBytesScanner{} case sql.Scanner: - return &scanPlanSQLScanner{formatCode: formatCode} + return &scanPlanCodecSQLScanner{c: c, m: m, oid: oid, formatCode: formatCode} } rv := reflect.ValueOf(target) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go b/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go index 4fe6dd4..0c02575 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go @@ -98,7 +98,7 @@ func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf [] var encodePlan EncodePlan var lastElemType reflect.Type inElemBuf := make([]byte, 0, 32) - for i := 0; i < elementCount; i++ { + for i := range elementCount { if i > 0 { buf = append(buf, ',') } @@ -151,7 +151,7 @@ func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf var encodePlan EncodePlan var lastElemType reflect.Type - for i := 0; i < elementCount; i++ { + for i := range elementCount { sp := len(buf) buf = pgio.AppendInt32(buf, -1) @@ -210,6 +210,11 @@ func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, elementCount := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 + // Each element requires at least 4 bytes for its length prefix. + if elementCount > len(src)/4 { + return fmt.Errorf("multirange element count %d exceeds available data", elementCount) + } + err := multirange.SetLen(elementCount) if err != nil { return err @@ -224,7 +229,7 @@ func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) } - for i := 0; i < elementCount; i++ { + for i := range elementCount { elem := multirange.ScanIndex(i) elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go index 7d23690..c9022ab 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go @@ -14,7 +14,7 @@ import ( ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 -const nbase = 10000 +const nbase = 10_000 const ( pgNumericNaN = 0x00000000c0000000 @@ -28,7 +28,6 @@ const ( ) var ( - big0 *big.Int = big.NewInt(0) big1 *big.Int = big.NewInt(1) big10 *big.Int = big.NewInt(10) big100 *big.Int = big.NewInt(100) @@ -129,7 +128,7 @@ func (n Numeric) Int64Value() (Int8, error) { } func (n *Numeric) ScanScientific(src string) error { - if !strings.ContainsAny("eE", src) { + if !strings.ContainsAny(src, "eE") { return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n) } @@ -166,7 +165,7 @@ func (n *Numeric) toBigInt() (*big.Int, error) { div.Exp(big10, big.NewInt(int64(-n.Exp)), nil) remainder := &big.Int{} num.DivMod(num, div, remainder) - if remainder.Cmp(big0) != 0 { + if remainder.Sign() != 0 { return nil, fmt.Errorf("cannot convert %v to integer", n) } return num, nil @@ -194,14 +193,11 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { } func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { - digits := len(src) / 2 - if digits > 4 { - digits = 4 - } + digits := min(len(src)/2, 4) rp := 0 - for i := 0; i < digits; i++ { + for i := range digits { if i > 0 { accum *= nbase } @@ -268,6 +264,10 @@ func (n *Numeric) UnmarshalJSON(src []byte) error { // numberString returns a string of the number. undefined if NaN, infinite, or NULL func (n Numeric) numberTextBytes() []byte { + if n.Int == nil { + return []byte("0") + } + intStr := n.Int.String() buf := &bytes.Buffer{} @@ -280,14 +280,14 @@ func (n Numeric) numberTextBytes() []byte { exp := int(n.Exp) if exp > 0 { buf.WriteString(intStr) - for i := 0; i < exp; i++ { + for range exp { buf.WriteByte('0') } } else if exp < 0 { if len(intStr) <= -exp { buf.WriteString("0.") leadingZeros := -exp - len(intStr) - for i := 0; i < leadingZeros; i++ { + for range leadingZeros { buf.WriteByte('0') } buf.WriteString(intStr) @@ -409,7 +409,7 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { } var sign int16 - if n.Int.Cmp(big0) < 0 { + if n.Int != nil && n.Int.Sign() < 0 { sign = 16384 } @@ -417,7 +417,9 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { wholePart := &big.Int{} fracPart := &big.Int{} remainder := &big.Int{} - absInt.Abs(n.Int) + if n.Int != nil { + absInt.Abs(n.Int) + } // Normalize absInt and exp to where exp is always a multiple of 4. This makes // converting to 16-bit base 10,000 digits easier. @@ -447,12 +449,12 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { var wholeDigits, fracDigits []int16 - for wholePart.Cmp(big0) != 0 { + for wholePart.Sign() != 0 { wholePart.DivMod(wholePart, bigNBase, remainder) wholeDigits = append(wholeDigits, int16(remainder.Int64())) } - if fracPart.Cmp(big0) != 0 { + if fracPart.Sign() != 0 { for fracPart.Cmp(big1) != 0 { fracPart.DivMod(fracPart, bigNBase, remainder) fracDigits = append(fracDigits, int16(remainder.Int64())) @@ -658,18 +660,19 @@ func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { exp := (int32(weight) - int32(ndigits) + 1) * 4 if dscale > 0 { - fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1) + fracNBaseDigits := int(ndigits) - int(weight) - 1 fracDecimalDigits := fracNBaseDigits * 4 + dscaleInt := int(dscale) - if dscale > fracDecimalDigits { - multCount := int(dscale - fracDecimalDigits) - for i := 0; i < multCount; i++ { + if dscaleInt > fracDecimalDigits { + multCount := dscaleInt - fracDecimalDigits + for range multCount { accum.Mul(accum, big10) exp-- } - } else if dscale < fracDecimalDigits { - divCount := int(fracDecimalDigits - dscale) - for i := 0; i < divCount; i++ { + } else if dscaleInt < fracDecimalDigits { + divCount := fracDecimalDigits - dscaleInt + for range divCount { accum.Div(accum, big10) exp++ } @@ -681,7 +684,7 @@ func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { if exp >= 0 { for { reduced.DivMod(accum, big10, remainder) - if remainder.Cmp(big0) != 0 { + if remainder.Sign() != 0 { break } accum.Set(reduced) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/path.go b/vendor/github.com/jackc/pgx/v5/pgtype/path.go index 81dc1e5..685996a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/path.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/path.go @@ -195,7 +195,7 @@ func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst any) error { } points := make([]Vec2, pointCount) - for i := 0; i < len(points); i++ { + for i := range points { x := binary.BigEndian.Uint64(src[rp:]) rp += 8 y := binary.BigEndian.Uint64(src[rp:]) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go index b3ef320..29721a4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go @@ -96,6 +96,8 @@ const ( RecordArrayOID = 2287 UUIDOID = 2950 UUIDArrayOID = 2951 + TSVectorOID = 3614 + TSVectorArrayOID = 3643 JSONBOID = 3802 JSONBArrayOID = 3807 DaterangeOID = 3912 @@ -523,20 +525,20 @@ type SkipUnderlyingTypePlanner interface { } var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ - reflect.Int: reflect.TypeOf(new(int)), - reflect.Int8: reflect.TypeOf(new(int8)), - reflect.Int16: reflect.TypeOf(new(int16)), - reflect.Int32: reflect.TypeOf(new(int32)), - reflect.Int64: reflect.TypeOf(new(int64)), - reflect.Uint: reflect.TypeOf(new(uint)), - reflect.Uint8: reflect.TypeOf(new(uint8)), - reflect.Uint16: reflect.TypeOf(new(uint16)), - reflect.Uint32: reflect.TypeOf(new(uint32)), - reflect.Uint64: reflect.TypeOf(new(uint64)), - reflect.Float32: reflect.TypeOf(new(float32)), - reflect.Float64: reflect.TypeOf(new(float64)), - reflect.String: reflect.TypeOf(new(string)), - reflect.Bool: reflect.TypeOf(new(bool)), + reflect.Int: reflect.TypeFor[*int](), + reflect.Int8: reflect.TypeFor[*int8](), + reflect.Int16: reflect.TypeFor[*int16](), + reflect.Int32: reflect.TypeFor[*int32](), + reflect.Int64: reflect.TypeFor[*int64](), + reflect.Uint: reflect.TypeFor[*uint](), + reflect.Uint8: reflect.TypeFor[*uint8](), + reflect.Uint16: reflect.TypeFor[*uint16](), + reflect.Uint32: reflect.TypeFor[*uint32](), + reflect.Uint64: reflect.TypeFor[*uint64](), + reflect.Float32: reflect.TypeFor[*float32](), + reflect.Float64: reflect.TypeFor[*float64](), + reflect.String: reflect.TypeFor[*string](), + reflect.Bool: reflect.TypeFor[*bool](), } type underlyingTypeScanPlan struct { @@ -901,7 +903,7 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst any) error { return nil } -// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +// TryWrapStructScanPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { targetValue := reflect.ValueOf(target) if targetValue.Kind() != reflect.Ptr { @@ -1135,10 +1137,18 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) Scan } } - if dt != nil { - if _, ok := target.(*any); ok { - return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode} + if _, ok := target.(*any); ok { + var codec Codec + if dt != nil { + codec = dt.Codec + } else { + if formatCode == TextFormatCode { + codec = TextCodec{} + } else { + codec = ByteaCodec{} + } } + return &pointerEmptyInterfaceScanPlan{codec: codec, m: m, oid: oid, formatCode: formatCode} } return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} @@ -1364,23 +1374,23 @@ func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, } var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ - reflect.Int: reflect.TypeOf(int(0)), - reflect.Int8: reflect.TypeOf(int8(0)), - reflect.Int16: reflect.TypeOf(int16(0)), - reflect.Int32: reflect.TypeOf(int32(0)), - reflect.Int64: reflect.TypeOf(int64(0)), - reflect.Uint: reflect.TypeOf(uint(0)), - reflect.Uint8: reflect.TypeOf(uint8(0)), - reflect.Uint16: reflect.TypeOf(uint16(0)), - reflect.Uint32: reflect.TypeOf(uint32(0)), - reflect.Uint64: reflect.TypeOf(uint64(0)), - reflect.Float32: reflect.TypeOf(float32(0)), - reflect.Float64: reflect.TypeOf(float64(0)), - reflect.String: reflect.TypeOf(""), - reflect.Bool: reflect.TypeOf(false), -} - -var byteSliceType = reflect.TypeOf([]byte{}) + reflect.Int: reflect.TypeFor[int](), + reflect.Int8: reflect.TypeFor[int8](), + reflect.Int16: reflect.TypeFor[int16](), + reflect.Int32: reflect.TypeFor[int32](), + reflect.Int64: reflect.TypeFor[int64](), + reflect.Uint: reflect.TypeFor[uint](), + reflect.Uint8: reflect.TypeFor[uint8](), + reflect.Uint16: reflect.TypeFor[uint16](), + reflect.Uint32: reflect.TypeFor[uint32](), + reflect.Uint64: reflect.TypeFor[uint64](), + reflect.Float32: reflect.TypeFor[float32](), + reflect.Float64: reflect.TypeFor[float64](), + reflect.String: reflect.TypeFor[string](), + reflect.Bool: reflect.TypeFor[bool](), +} + +var byteSliceType = reflect.TypeFor[[]byte]() type underlyingTypeEncodePlan struct { nextValueType reflect.Type @@ -1743,7 +1753,7 @@ func (plan *wrapFmtStringerEncodePlan) Encode(value any, buf []byte) (newBuf []b return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf) } -// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +// TryWrapStructEncodePlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. func TryWrapStructEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if _, ok := value.(driver.Valuer); ok { return nil, nil, false @@ -1999,7 +2009,7 @@ func (w *sqlScannerWrapper) Scan(src any) error { case []byte: bufSrc = src default: - bufSrc = []byte(fmt.Sprint(bufSrc)) + bufSrc = fmt.Append(nil, bufSrc) } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go index 5648d89..42b39d8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go @@ -81,6 +81,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "tsvector", OID: TSVectorOID, Codec: TSVectorCodec{}}) defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}}) @@ -164,6 +165,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}}) defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}}) defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_tsvector", OID: TSVectorArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TSVectorOID]}}) defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}}) defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) @@ -242,6 +244,7 @@ func initDefaultMap() { registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange") registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange") registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange") + registerDefaultPgTypeVariants[TSVector](defaultMap, "tsvector") registerDefaultPgTypeVariants[UUID](defaultMap, "uuid") defaultMap.buildReflectTypeToType() diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go b/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go index a84b25f..e18c9da 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go @@ -178,7 +178,7 @@ func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst any) error { } points := make([]Vec2, pointCount) - for i := 0; i < len(points); i++ { + for i := range points { x := binary.BigEndian.Uint64(src[rp:]) rp += 8 y := binary.BigEndian.Uint64(src[rp:]) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go index 861fa88..de500a1 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go @@ -111,9 +111,9 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error { case "-infinity": *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: - // Parse time with or without timezonr + // Parse time with or without timezone tss := *s - // PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestampt + // PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestamp tim, err := time.Parse(time.RFC3339Nano, tss) if err == nil { *ts = Timestamp{Time: tim, Valid: true} @@ -176,7 +176,7 @@ func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []by switch ts.InfinityModifier { case Finite: t := discardTimeZone(ts.Time) - microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceUnixEpoch := t.Unix()*1_000_000 + int64(t.Nanosecond())/1000 microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K case Infinity: microsecSinceY2K = infinityMicrosecondOffset @@ -279,8 +279,8 @@ func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: tim := time.Unix( - microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, - (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + microsecFromUnixEpochToY2K/1_000_000+microsecSinceY2K/1_000_000, + (microsecFromUnixEpochToY2K%1_000_000*1_000)+(microsecSinceY2K%1_000_000*1000), ).UTC() if plan.location != nil { tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go index 5d67e47..4d055bf 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go @@ -15,7 +15,7 @@ const ( pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" - microsecFromUnixEpochToY2K = 946684800 * 1000000 + microsecFromUnixEpochToY2K = 946_684_800 * 1_000_000 ) const ( @@ -270,8 +270,8 @@ func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: tim := time.Unix( - microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, - (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + microsecFromUnixEpochToY2K/1_000_000+microsecSinceY2K/1_000_000, + (microsecFromUnixEpochToY2K%1_000_000*1_000)+(microsecSinceY2K%1_000_000*1_000), ) if plan.location != nil { tim = tim.In(plan.location) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go b/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go new file mode 100644 index 0000000..b357948 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go @@ -0,0 +1,507 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type TSVectorScanner interface { + ScanTSVector(TSVector) error +} + +type TSVectorValuer interface { + TSVectorValue() (TSVector, error) +} + +// TSVector represents a PostgreSQL tsvector value. +type TSVector struct { + Lexemes []TSVectorLexeme + Valid bool +} + +// TSVectorLexeme represents a lexeme within a tsvector, consisting of a word and its positions. +type TSVectorLexeme struct { + Word string + Positions []TSVectorPosition +} + +// ScanTSVector implements the [TSVectorScanner] interface. +func (t *TSVector) ScanTSVector(v TSVector) error { + *t = v + return nil +} + +// TSVectorValue implements the [TSVectorValuer] interface. +func (t TSVector) TSVectorValue() (TSVector, error) { + return t, nil +} + +func (t TSVector) String() string { + buf, _ := encodePlanTSVectorCodecText{}.Encode(t, nil) + return string(buf) +} + +// Scan implements the [database/sql.Scanner] interface. +func (t *TSVector) Scan(src any) error { + if src == nil { + *t = TSVector{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToTSVectorScanner{}.scanString(src, t) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (t TSVector) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + + buf, err := TSVectorCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil) + if err != nil { + return nil, err + } + + return string(buf), nil +} + +// TSVectorWeight represents the weight label of a lexeme position in a tsvector. +type TSVectorWeight byte + +const ( + TSVectorWeightA = TSVectorWeight('A') + TSVectorWeightB = TSVectorWeight('B') + TSVectorWeightC = TSVectorWeight('C') + TSVectorWeightD = TSVectorWeight('D') +) + +// tsvectorWeightToBinary converts a TSVectorWeight to the 2-bit binary encoding used by PostgreSQL. +func tsvectorWeightToBinary(w TSVectorWeight) uint16 { + switch w { + case TSVectorWeightA: + return 3 + case TSVectorWeightB: + return 2 + case TSVectorWeightC: + return 1 + default: + return 0 // D or unset + } +} + +// tsvectorWeightFromBinary converts a 2-bit binary weight value to a TSVectorWeight. +func tsvectorWeightFromBinary(b uint16) TSVectorWeight { + switch b { + case 3: + return TSVectorWeightA + case 2: + return TSVectorWeightB + case 1: + return TSVectorWeightC + default: + return TSVectorWeightD + } +} + +// TSVectorPosition represents a lexeme position and its optional weight within a tsvector. +type TSVectorPosition struct { + Position uint16 + Weight TSVectorWeight +} + +func (p TSVectorPosition) String() string { + s := strconv.FormatUint(uint64(p.Position), 10) + if p.Weight != 0 && p.Weight != TSVectorWeightD { + s += string(p.Weight) + } + return s +} + +type TSVectorCodec struct{} + +func (TSVectorCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TSVectorCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TSVectorCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TSVectorValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTSVectorCodecBinary{} + case TextFormatCode: + return encodePlanTSVectorCodecText{} + } + + return nil +} + +type encodePlanTSVectorCodecBinary struct{} + +func (encodePlanTSVectorCodecBinary) Encode(value any, buf []byte) ([]byte, error) { + tsv, err := value.(TSVectorValuer).TSVectorValue() + if err != nil { + return nil, err + } + + if !tsv.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(tsv.Lexemes))) + + for _, entry := range tsv.Lexemes { + buf = append(buf, entry.Word...) + buf = append(buf, 0x00) + buf = pgio.AppendUint16(buf, uint16(len(entry.Positions))) + + // Each position is a uint16: weight (2 bits) | position (14 bits) + for _, pos := range entry.Positions { + packed := tsvectorWeightToBinary(pos.Weight)<<14 | uint16(pos.Position)&0x3FFF + buf = pgio.AppendUint16(buf, packed) + } + } + + return buf, nil +} + +type scanPlanBinaryTSVectorToTSVectorScanner struct{} + +func (scanPlanBinaryTSVectorToTSVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TSVectorScanner) + + if src == nil { + return scanner.ScanTSVector(TSVector{}) + } + + rp := 0 + + const ( + uint16Len = 2 + uint32Len = 4 + ) + + if len(src[rp:]) < uint32Len { + return fmt.Errorf("tsvector incomplete %v", src) + } + entryCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += uint32Len + + var tsv TSVector + if entryCount > 0 { + tsv.Lexemes = make([]TSVectorLexeme, entryCount) + } + + for i := range entryCount { + nullIndex := bytes.IndexByte(src[rp:], 0x00) + if nullIndex == -1 { + return fmt.Errorf("invalid tsvector binary format: missing null terminator") + } + + lexeme := TSVectorLexeme{Word: string(src[rp : rp+nullIndex])} + rp += nullIndex + 1 // skip past null terminator + + // Read position count. + if len(src[rp:]) < uint16Len { + return fmt.Errorf("invalid tsvector binary format: incomplete position count") + } + + numPositions := int(binary.BigEndian.Uint16(src[rp:])) + rp += uint16Len + + // Read each packed position: weight (2 bits) | position (14 bits) + if len(src[rp:]) < numPositions*uint16Len { + return fmt.Errorf("invalid tsvector binary format: incomplete positions") + } + + if numPositions > 0 { + lexeme.Positions = make([]TSVectorPosition, numPositions) + for pos := range numPositions { + packed := binary.BigEndian.Uint16(src[rp:]) + rp += uint16Len + lexeme.Positions[pos] = TSVectorPosition{ + Position: packed & 0x3FFF, + Weight: tsvectorWeightFromBinary(packed >> 14), + } + } + } + + tsv.Lexemes[i] = lexeme + } + tsv.Valid = true + + return scanner.ScanTSVector(tsv) +} + +var tsvectorLexemeReplacer = strings.NewReplacer( + `\`, `\\`, + `'`, `\'`, +) + +type encodePlanTSVectorCodecText struct{} + +func (encodePlanTSVectorCodecText) Encode(value any, buf []byte) ([]byte, error) { + tsv, err := value.(TSVectorValuer).TSVectorValue() + if err != nil { + return nil, err + } + + if !tsv.Valid { + return nil, nil + } + + if buf == nil { + buf = []byte{} + } + + for i, lex := range tsv.Lexemes { + if i > 0 { + buf = append(buf, ' ') + } + + buf = append(buf, '\'') + buf = append(buf, tsvectorLexemeReplacer.Replace(lex.Word)...) + buf = append(buf, '\'') + + sep := byte(':') + for _, p := range lex.Positions { + buf = append(buf, sep) + buf = append(buf, p.String()...) + sep = ',' + } + } + + return buf, nil +} + +func (TSVectorCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case TSVectorScanner: + return scanPlanBinaryTSVectorToTSVectorScanner{} + } + case TextFormatCode: + switch target.(type) { + case TSVectorScanner: + return scanPlanTextAnyToTSVectorScanner{} + } + } + + return nil +} + +type scanPlanTextAnyToTSVectorScanner struct{} + +func (s scanPlanTextAnyToTSVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TSVectorScanner) + + if src == nil { + return scanner.ScanTSVector(TSVector{}) + } + + return s.scanString(string(src), scanner) +} + +func (scanPlanTextAnyToTSVectorScanner) scanString(src string, scanner TSVectorScanner) error { + tsv, err := parseTSVector(src) + if err != nil { + return err + } + return scanner.ScanTSVector(tsv) +} + +func (c TSVectorCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c TSVectorCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var tsv TSVector + err := codecScan(c, m, oid, format, src, &tsv) + if err != nil { + return nil, err + } + return tsv, nil +} + +type tsvectorParser struct { + str string + pos int +} + +func (p *tsvectorParser) atEnd() bool { + return p.pos >= len(p.str) +} + +func (p *tsvectorParser) peek() byte { + return p.str[p.pos] +} + +func (p *tsvectorParser) consume() (byte, bool) { + if p.pos >= len(p.str) { + return 0, true + } + b := p.str[p.pos] + p.pos++ + return b, false +} + +func (p *tsvectorParser) consumeSpaces() { + for !p.atEnd() && p.peek() == ' ' { + p.consume() + } +} + +// consumeLexeme consumes a single-quoted lexeme, handling single quotes and backslash escapes. +func (p *tsvectorParser) consumeLexeme() (string, error) { + ch, end := p.consume() + if end || ch != '\'' { + return "", fmt.Errorf("invalid tsvector format: lexeme must start with a single quote") + } + + var buf strings.Builder + for { + ch, end := p.consume() + if end { + return "", fmt.Errorf("invalid tsvector format: unterminated quoted lexeme") + } + + switch ch { + case '\'': + // Escaped quote ('') — write a literal single quote + if !p.atEnd() && p.peek() == '\'' { + p.consume() + buf.WriteByte('\'') + } else { + // Closing quote — lexeme is complete + return buf.String(), nil + } + case '\\': + next, end := p.consume() + if end { + return "", fmt.Errorf("invalid tsvector format: unexpected end after backslash") + } + buf.WriteByte(next) + default: + buf.WriteByte(ch) + } + } +} + +// consumePositions consumes a comma-separated list of position[weight] values. +func (p *tsvectorParser) consumePositions() ([]TSVectorPosition, error) { + var positions []TSVectorPosition + + for { + pos, err := p.consumePosition() + if err != nil { + return nil, err + } + positions = append(positions, pos) + + if p.atEnd() || p.peek() != ',' { + break + } + + p.consume() // skip ',' + } + + return positions, nil +} + +// consumePosition consumes a single position number with optional weight letter. +func (p *tsvectorParser) consumePosition() (TSVectorPosition, error) { + start := p.pos + + for !p.atEnd() && p.peek() >= '0' && p.peek() <= '9' { + p.consume() + } + + if p.pos == start { + return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: expected position number") + } + + num, err := strconv.ParseUint(p.str[start:p.pos], 10, 16) + if err != nil { + return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: invalid position number %q", p.str[start:p.pos]) + } + + pos := TSVectorPosition{Position: uint16(num), Weight: TSVectorWeightD} + + // Check for optional weight letter + if !p.atEnd() { + switch p.peek() { + case 'A', 'a': + pos.Weight = TSVectorWeightA + case 'B', 'b': + pos.Weight = TSVectorWeightB + case 'C', 'c': + pos.Weight = TSVectorWeightC + case 'D', 'd': + pos.Weight = TSVectorWeightD + default: + return pos, nil + } + p.consume() + } + + return pos, nil +} + +// parseTSVector parses a PostgreSQL tsvector text representation. +func parseTSVector(s string) (TSVector, error) { + result := TSVector{} + p := &tsvectorParser{str: strings.TrimSpace(s), pos: 0} + + for !p.atEnd() { + p.consumeSpaces() + if p.atEnd() { + break + } + + word, err := p.consumeLexeme() + if err != nil { + return TSVector{}, err + } + + entry := TSVectorLexeme{Word: word} + + // Check for optional positions after ':' + if !p.atEnd() && p.peek() == ':' { + p.consume() // skip ':' + + positions, err := p.consumePositions() + if err != nil { + return TSVector{}, err + } + entry.Positions = positions + } + + result.Lexemes = append(result.Lexemes, entry) + } + + result.Valid = true + + return result, nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go b/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go index addfb41..8291ed8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go +++ b/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go @@ -3,7 +3,7 @@ package pgxpool import ( "context" "errors" - "math/rand" + "math/rand/v2" "runtime" "strconv" "sync" @@ -97,6 +97,10 @@ type Pool struct { maxConnLifetimeJitter time.Duration maxConnIdleTime time.Duration healthCheckPeriod time.Duration + pingTimeout time.Duration + + healthCheckMu sync.Mutex + healthCheckTimer *time.Timer healthCheckChan chan struct{} @@ -166,6 +170,10 @@ type Config struct { // MaxConnIdleTime is the duration after which an idle connection will be automatically closed by the health check. MaxConnIdleTime time.Duration + // PingTimeout is the maximum amount of time to wait for a connection to pong before considering it as unhealthy and + // destroying it. If zero, the default is no timeout. + PingTimeout time.Duration + // MaxConns is the maximum size of the pool. The default is the greater of 4 or runtime.NumCPU(). MaxConns int32 @@ -238,6 +246,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { maxConnLifetime: config.MaxConnLifetime, maxConnLifetimeJitter: config.MaxConnLifetimeJitter, maxConnIdleTime: config.MaxConnIdleTime, + pingTimeout: config.PingTimeout, healthCheckPeriod: config.HealthCheckPeriod, healthCheckChan: make(chan struct{}, 1), closeChan: make(chan struct{}), @@ -458,15 +467,25 @@ func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool { } func (p *Pool) triggerHealthCheck() { - go func() { + const healthCheckDelay = 500 * time.Millisecond + + p.healthCheckMu.Lock() + defer p.healthCheckMu.Unlock() + + if p.healthCheckTimer == nil { // Destroy is asynchronous so we give it time to actually remove itself from // the pool otherwise we might try to check the pool size too soon - time.Sleep(500 * time.Millisecond) - select { - case p.healthCheckChan <- struct{}{}: - default: - } - }() + p.healthCheckTimer = time.AfterFunc(healthCheckDelay, func() { + select { + case <-p.closeChan: + case p.healthCheckChan <- struct{}{}: + default: + } + }) + return + } + + p.healthCheckTimer.Reset(healthCheckDelay) } func (p *Pool) backgroundHealthCheck() { @@ -539,7 +558,8 @@ func (p *Pool) checkMinConns() error { // off this check // Create the number of connections needed to get to both minConns and minIdleConns - toCreate := max(p.minConns-p.Stat().TotalConns(), p.minIdleConns-p.Stat().IdleConns()) + stat := p.Stat() + toCreate := max(p.minConns-stat.TotalConns(), p.minIdleConns-stat.IdleConns()) if toCreate > 0 { return p.createIdleResources(context.Background(), int(toCreate)) } @@ -552,7 +572,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in errs := make(chan error, targetResources) - for i := 0; i < targetResources; i++ { + for range targetResources { go func() { err := p.p.CreateResource(ctx) // Ignore ErrNotAvailable since it means that the pool has become full since we started creating resource. @@ -564,7 +584,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in } var firstError error - for i := 0; i < targetResources; i++ { + for range targetResources { err := <-errs if err != nil && firstError == nil { cancel() @@ -591,7 +611,7 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { // Try to acquire from the connection pool up to maxConns + 1 times, so that // any that fatal errors would empty the pool and still at least try 1 fresh // connection. - for range p.maxConns + 1 { + for range int(p.maxConns) + 1 { res, err := p.p.Acquire(ctx) if err != nil { return nil, err @@ -601,7 +621,15 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { shouldPingParams := ShouldPingParams{Conn: cr.conn, IdleDuration: res.IdleDuration()} if p.shouldPing(ctx, shouldPingParams) { - err := cr.conn.Ping(ctx) + err := func() error { + pingCtx := ctx + if p.pingTimeout > 0 { + var cancel context.CancelFunc + pingCtx, cancel = context.WithTimeout(ctx, p.pingTimeout) + defer cancel() + } + return cr.conn.Ping(pingCtx) + }() if err != nil { res.Destroy() continue @@ -626,7 +654,7 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { return cr.getConn(p, res), nil } - return nil, errors.New("pgxpool: detected infinite loop acquiring connection; likely bug in PrepareConn or BeforeAcquire hook") + return nil, errors.New("pgxpool: too many failed attempts acquiring connection; likely bug in PrepareConn, BeforeAcquire, or ShouldPing hook") } // AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go index a5725fd..d74518d 100644 --- a/vendor/github.com/jackc/pgx/v5/rows.go +++ b/vendor/github.com/jackc/pgx/v5/rows.go @@ -29,9 +29,9 @@ type Rows interface { // to call Close after rows is already closed. Close() - // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by - // calling Close or by Next returning false). If it is called early it may return nil even if there was an error - // executing the query. + // Err returns any error that occurred while executing a query or reading its results. Err must be called after the + // Rows is closed (either by calling Close or by Next returning false) to check if the query was successful. If it is + // called before the Rows is closed it may return nil even if the query failed on the server. Err() error // CommandTag returns the command tag from this query. It is only available after Rows is closed. @@ -529,7 +529,7 @@ func RowTo[T any](row CollectableRow) (T, error) { return value, err } -// RowTo returns a the address of a T scanned from row. +// RowToAddrOf returns the address of a T scanned from row. func RowToAddrOf[T any](row CollectableRow) (*T, error) { var value T err := row.Scan(&value) @@ -848,7 +848,7 @@ func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, normalize } } } - return + return i } // structRowField describes a field of a struct. diff --git a/vendor/github.com/jackc/pgx/v5/test.sh b/vendor/github.com/jackc/pgx/v5/test.sh new file mode 100644 index 0000000..8bab2d2 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/test.sh @@ -0,0 +1,170 @@ +#!/usr/bin/env bash +set -euo pipefail + +# test.sh - Run pgx tests against specific database targets +# +# Usage: +# ./test.sh [target] [go test flags...] +# +# Targets: +# pg14 - PostgreSQL 14 (port 5414) +# pg15 - PostgreSQL 15 (port 5415) +# pg16 - PostgreSQL 16 (port 5416) +# pg17 - PostgreSQL 17 (port 5417) +# pg18 - PostgreSQL 18 (port 5432) [default] +# crdb - CockroachDB (port 26257) +# all - Run against all targets sequentially +# +# Examples: +# ./test.sh # Test against PG18 +# ./test.sh pg14 # Test against PG14 +# ./test.sh crdb # Test against CockroachDB +# ./test.sh all # Test against all targets +# ./test.sh pg16 -run TestConnect # Test specific test against PG16 +# ./test.sh pg18 -count=1 -v # Verbose, no cache, PG18 + +# Color output (disabled if not a terminal) +if [ -t 1 ]; then + GREEN='\033[0;32m' + RED='\033[0;31m' + BLUE='\033[0;34m' + NC='\033[0m' +else + GREEN='' + RED='' + BLUE='' + NC='' +fi + +log_info() { echo -e "${BLUE}==> $*${NC}"; } +log_ok() { echo -e "${GREEN}==> $*${NC}"; } +log_err() { echo -e "${RED}==> $*${NC}" >&2; } + +# Wait for a database to accept connections +wait_for_ready() { + local connstr="$1" + local label="$2" + local max_attempts=30 + local attempt=0 + + log_info "Waiting for $label to be ready..." + while ! psql "$connstr" -c "SELECT 1" > /dev/null 2>&1; do + attempt=$((attempt + 1)) + if [ "$attempt" -ge "$max_attempts" ]; then + log_err "$label did not become ready after $max_attempts attempts" + return 1 + fi + sleep 1 + done + log_ok "$label is ready" +} + +# Directory containing this script (used to locate testsetup/) +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CERTS_DIR="$SCRIPT_DIR/testsetup/certs" + +# Copy client certificates to /tmp for TLS tests +setup_client_certs() { + if [ -d "$CERTS_DIR" ]; then + base64 -d "$CERTS_DIR/ca.pem.b64" > /tmp/ca.pem + base64 -d "$CERTS_DIR/pgx_sslcert.crt.b64" > /tmp/pgx_sslcert.crt + base64 -d "$CERTS_DIR/pgx_sslcert.key.b64" > /tmp/pgx_sslcert.key + fi +} + +# Initialize CockroachDB (create database if not exists) +init_crdb() { + local connstr="postgresql://root@localhost:26257/?sslmode=disable" + wait_for_ready "$connstr" "CockroachDB" + log_info "Ensuring pgx_test database exists on CockroachDB..." + psql "$connstr" -c "CREATE DATABASE IF NOT EXISTS pgx_test" 2>/dev/null || true +} + +# Run tests against a single target +run_tests() { + local target="$1" + shift + local extra_args=("$@") + + local label="" + local port="" + + case "$target" in + pg14) label="PostgreSQL 14"; port=5414 ;; + pg15) label="PostgreSQL 15"; port=5415 ;; + pg16) label="PostgreSQL 16"; port=5416 ;; + pg17) label="PostgreSQL 17"; port=5417 ;; + pg18) label="PostgreSQL 18"; port=5432 ;; + crdb) + label="CockroachDB (port 26257)" + init_crdb + log_info "Testing against $label" + if ! PGX_TEST_DATABASE="postgresql://root@localhost:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" \ + go test -count=1 "${extra_args[@]}" ./...; then + log_err "Tests FAILED against $label" + return 1 + fi + log_ok "Tests passed against $label" + return 0 + ;; + *) + log_err "Unknown target: $target" + log_err "Valid targets: pg14, pg15, pg16, pg17, pg18, crdb, all" + return 1 + ;; + esac + + setup_client_certs + + log_info "Testing against $label (port $port)" + if ! PGX_TEST_DATABASE="host=localhost port=$port user=postgres password=postgres dbname=pgx_test" \ + PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql port=$port user=postgres dbname=pgx_test" \ + PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \ + PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \ + PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" \ + PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" \ + PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_pw password=secret dbname=pgx_test" \ + PGX_TEST_TLS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" \ + PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost port=$port user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" \ + PGX_SSL_PASSWORD=certpw \ + go test -count=1 "${extra_args[@]}" ./...; then + log_err "Tests FAILED against $label" + return 1 + fi + log_ok "Tests passed against $label" +} + +# Main +main() { + local target="${1:-pg18}" + + if [ "$target" = "all" ]; then + shift || true + local targets=(pg14 pg15 pg16 pg17 pg18 crdb) + local failed=() + + for t in "${targets[@]}"; do + echo "" + log_info "==========================================" + log_info "Target: $t" + log_info "==========================================" + if ! run_tests "$t" "$@"; then + failed+=("$t") + log_err "FAILED: $t" + fi + done + + echo "" + if [ ${#failed[@]} -gt 0 ]; then + log_err "Failed targets: ${failed[*]}" + return 1 + else + log_ok "All targets passed" + fi + else + shift || true + run_tests "$target" "$@" + fi +} + +main "$@" diff --git a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go deleted file mode 100644 index 28cd99c..0000000 --- a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -/* -Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC -2898 / PKCS #5 v2.0. - -A key derivation function is useful when encrypting data based on a password -or any other not-fully-random data. It uses a pseudorandom function to derive -a secure encryption key based on the password. - -While v2.0 of the standard defines only one pseudorandom function to use, -HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved -Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To -choose, you can pass the `New` functions from the different SHA packages to -pbkdf2.Key. -*/ -package pbkdf2 - -import ( - "crypto/hmac" - "hash" -) - -// Key derives a key from the password, salt and iteration count, returning a -// []byte of length keylen that can be used as cryptographic key. The key is -// derived based on the method described as PBKDF2 with the HMAC variant using -// the supplied hash function. -// -// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you -// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by -// doing: -// -// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New) -// -// Remember to get a good random salt. At least 8 bytes is recommended by the -// RFC. -// -// Using a higher iteration count will increase the cost of an exhaustive -// search but will also make derivation proportionally slower. -func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { - prf := hmac.New(h, password) - hashLen := prf.Size() - numBlocks := (keyLen + hashLen - 1) / hashLen - - var buf [4]byte - dk := make([]byte, 0, numBlocks*hashLen) - U := make([]byte, hashLen) - for block := 1; block <= numBlocks; block++ { - // N.B.: || means concatenation, ^ means XOR - // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter - // U_1 = PRF(password, salt || uint(i)) - prf.Reset() - prf.Write(salt) - buf[0] = byte(block >> 24) - buf[1] = byte(block >> 16) - buf[2] = byte(block >> 8) - buf[3] = byte(block) - prf.Write(buf[:4]) - dk = prf.Sum(dk) - T := dk[len(dk)-hashLen:] - copy(U, T) - - // U_n = PRF(password, U_(n-1)) - for n := 2; n <= iter; n++ { - prf.Reset() - prf.Write(U) - U = U[:0] - U = prf.Sum(U) - for x := range U { - T[x] ^= U[x] - } - } - } - return dk[:keyLen] -} diff --git a/vendor/modules.txt b/vendor/modules.txt index 0d35bd3..9761101 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -149,8 +149,8 @@ github.com/jackc/pgpassfile # github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 ## explicit; go 1.14 github.com/jackc/pgservicefile -# github.com/jackc/pgx/v5 v5.7.6 -## explicit; go 1.23.0 +# github.com/jackc/pgx/v5 v5.9.0 +## explicit; go 1.25.0 github.com/jackc/pgx/v5 github.com/jackc/pgx/v5/internal/iobufpool github.com/jackc/pgx/v5/internal/pgio @@ -317,7 +317,6 @@ golang.org/x/crypto/chacha20 golang.org/x/crypto/curve25519 golang.org/x/crypto/internal/alias golang.org/x/crypto/internal/poly1305 -golang.org/x/crypto/pbkdf2 golang.org/x/crypto/sha3 golang.org/x/crypto/ssh golang.org/x/crypto/ssh/internal/bcrypt_pbkdf