From b7f225df998856d193696e982019a6886dfc32b0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 20 Mar 2026 01:16:11 +0000 Subject: [PATCH] CopyFrom: auto-detect binary/text format with text fallback CopyFrom previously hardcoded binary format and used tryScanStringCopyValueThenEncode as a workaround when binary encoding failed. This was fragile and couldn't handle types that only support text format (e.g. jsonpath, aclitem). Now CopyFrom peeks at the first row and checks both codec-level format support and value-level encode plan availability. If any column cannot use binary, it transparently falls back to text format COPY, letting PostgreSQL handle the parsing natively. - Remove tryScanStringCopyValueThenEncode - Add encodeCopyValueText with proper COPY text escaping - Add canUseBinaryFormat two-level detection - Add buildCopyBufText for text format row encoding - Buffer first row for format decision without data loss - Add tests for text fallback, special char escaping, NULLs, large datasets, all query exec modes, string-to-int conversion, and empty row sets Co-Authored-By: Claude Opus 4.6 --- copy_from.go | 196 ++++++++++++++++++++++++++++----- copy_from_test.go | 270 ++++++++++++++++++++++++++++++++++++++++++++++ values.go | 49 +++++---- 3 files changed, 471 insertions(+), 44 deletions(-) diff --git a/copy_from.go b/copy_from.go index abcd22396..7d6164391 100644 --- a/copy_from.go +++ b/copy_from.go @@ -113,6 +113,7 @@ type copyFrom struct { rowSrc CopyFromSource readerErrChan chan error mode QueryExecMode + firstRow []any } func (ct *copyFrom) run(ctx context.Context) (int64, error) { @@ -158,6 +159,23 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode) } + // Peek at the first row to determine whether all columns can be encoded in binary format. + hasFirstRow := ct.rowSrc.Next() + if hasFirstRow { + var err error + ct.firstRow, err = ct.rowSrc.Values() + if err != nil { + return 0, err + } + if len(ct.firstRow) != len(ct.columnNames) { + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(ct.firstRow)) + } + } else if ct.rowSrc.Err() != nil { + return 0, ct.rowSrc.Err() + } + + useBinary := ct.canUseBinaryFormat(sd) + r, w := io.Pipe() doneChan := make(chan struct{}) @@ -167,39 +185,57 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. buf := ct.conn.wbuf - buf = append(buf, "PGCOPY\n\377\r\n\000"...) - buf = pgio.AppendInt32(buf, 0) - buf = pgio.AppendInt32(buf, 0) + if useBinary { + buf = append(buf, "PGCOPY\n\377\r\n\000"...) + buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) + } - moreRows := true - for moreRows { - var err error - moreRows, buf, err = ct.buildCopyBuf(buf, sd) - if err != nil { - w.CloseWithError(err) - return - } + moreRows := hasFirstRow + for { + if moreRows { + var err error + if useBinary { + moreRows, buf, err = ct.buildCopyBuf(buf, sd) + } else { + moreRows, buf, err = ct.buildCopyBufText(buf, sd) + } + if err != nil { + w.CloseWithError(err) + return + } - if ct.rowSrc.Err() != nil { - w.CloseWithError(ct.rowSrc.Err()) - return + if ct.rowSrc.Err() != nil { + w.CloseWithError(ct.rowSrc.Err()) + return + } } if len(buf) > 0 { - _, err = w.Write(buf) + _, err := w.Write(buf) if err != nil { w.Close() return } } + if !moreRows { + break + } + buf = buf[:0] } w.Close() }() - commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + var copySQL string + if useBinary { + copySQL = fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames) + } else { + copySQL = fmt.Sprintf("copy %s ( %s ) from stdin;", quotedTableName, quotedColumnNames) + } + commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, copySQL) r.Close() <-doneChan @@ -214,11 +250,64 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { return commandTag.RowsAffected(), err } +// canUseBinaryFormat checks whether all columns can be encoded in binary format. It checks both codec-level support and +// whether the first row's values can be planned for binary encoding. +func (ct *copyFrom) canUseBinaryFormat(sd *pgconn.StatementDescription) bool { + m := ct.conn.typeMap + for _, field := range sd.Fields { + typ, ok := m.TypeForOID(field.DataTypeOID) + if !ok { + return false + } + if !typ.Codec.FormatSupported(BinaryFormatCode) { + return false + } + } + + if ct.firstRow != nil { + for i, val := range ct.firstRow { + if val == nil { + continue + } + plan := m.PlanEncode(sd.Fields[i].DataTypeOID, BinaryFormatCode, val) + if plan == nil { + return false + } + } + } + + return true +} + func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { const sendBufSize = 65536 - 5 // The packet has a 5-byte header lastBufLen := 0 largestRowLen := 0 + encodeBinaryRow := func(values []any) ([]byte, error) { + buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) + for i, val := range values { + var err error + buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) + if err != nil { + return nil, err + } + } + return buf, nil + } + + // Process buffered first row. + if ct.firstRow != nil { + lastBufLen = len(buf) + var err error + buf, err = encodeBinaryRow(ct.firstRow) + if err != nil { + return false, nil, err + } + ct.firstRow = nil + largestRowLen = len(buf) - lastBufLen + } + for ct.rowSrc.Next() { lastBufLen = len(buf) @@ -230,12 +319,9 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) } - buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) - for i, val := range values { - buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) - if err != nil { - return false, nil, err - } + buf, err = encodeBinaryRow(values) + if err != nil { + return false, nil, err } rowLen := len(buf) - lastBufLen @@ -254,11 +340,73 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b return false, buf, nil } +func (ct *copyFrom) buildCopyBufText(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { + const sendBufSize = 65536 - 5 // The packet has a 5-byte header + lastBufLen := 0 + largestRowLen := 0 + + encodeTextRow := func(values []any) ([]byte, error) { + for i, val := range values { + if i > 0 { + buf = append(buf, '\t') + } + var err error + buf, err = encodeCopyValueText(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) + if err != nil { + return nil, err + } + } + buf = append(buf, '\n') + return buf, nil + } + + // Process buffered first row. + if ct.firstRow != nil { + lastBufLen = len(buf) + var err error + buf, err = encodeTextRow(ct.firstRow) + if err != nil { + return false, nil, err + } + ct.firstRow = nil + largestRowLen = len(buf) - lastBufLen + } + + for ct.rowSrc.Next() { + lastBufLen = len(buf) + + values, err := ct.rowSrc.Values() + if err != nil { + return false, nil, err + } + if len(values) != len(ct.columnNames) { + return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + buf, err = encodeTextRow(values) + if err != nil { + return false, nil, err + } + + rowLen := len(buf) - lastBufLen + if rowLen > largestRowLen { + largestRowLen = rowLen + } + + if len(buf) > sendBufSize-largestRowLen { + return true, buf, nil + } + } + + return false, buf, nil +} + // CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and // an error. // -// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered -// for the type of each column. Almost all types implemented by pgx support the binary format. +// CopyFrom automatically uses the binary format when all columns support it. If any column's type only supports text +// format, or if the first row's values cannot be encoded in binary, CopyFrom transparently falls back to the text +// format. // // Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with // Conn.LoadType and pgtype.Map.RegisterType. diff --git a/copy_from_test.go b/copy_from_test.go index d3264c122..aa5b9b6b0 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -937,6 +937,276 @@ func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) { ensureConnValid(t, conn) } +// TestConnCopyFromTextFormatFallback tests that CopyFrom falls back to text format when a column type does not support +// binary encoding (e.g., jsonpath which uses TextFormatOnlyCodec). +func TestConnCopyFromTextFormatFallback(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath") + + mustExec(t, conn, `create temporary table foo( + a int4, + b jsonpath, + c text + )`) + + inputRows := [][]any{ + {int32(1), "$.store.book[*].author", "hello"}, + {int32(2), "$.store.price", "world"}, + {int32(3), "$[0]", "test"}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, _ := conn.Query(ctx, "select a, b::text, c from foo order by a") + outputRows, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) ([]any, error) { + var a int32 + var b, c string + err := row.Scan(&a, &b, &c) + return []any{a, b, c}, err + }) + require.NoError(t, err) + + require.Equal(t, [][]any{ + {int32(1), "$.\"store\".\"book\"[*].\"author\"", "hello"}, + {int32(2), "$.\"store\".\"price\"", "world"}, + {int32(3), "$[0]", "test"}, + }, outputRows) + + ensureConnValid(t, conn) +} + +// TestConnCopyFromTextFormatFallbackWithNulls tests text format fallback with NULL values across columns. +func TestConnCopyFromTextFormatFallbackWithNulls(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath") + + mustExec(t, conn, `create temporary table foo( + a int4, + b jsonpath, + c text + )`) + + inputRows := [][]any{ + {int32(1), "$.x", "hello"}, + {int32(2), nil, nil}, + {nil, "$.y", nil}, + {nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, _ := conn.Query(ctx, "select a, b::text, c from foo order by a nulls last") + outputRows, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) ([]any, error) { + var a *int32 + var b, c *string + err := row.Scan(&a, &b, &c) + return []any{a, b, c}, err + }) + require.NoError(t, err) + require.Len(t, outputRows, 4) + + ensureConnValid(t, conn) +} + +// TestConnCopyFromTextFormatSpecialCharEscaping tests that text format properly escapes special characters. +func TestConnCopyFromTextFormatSpecialCharEscaping(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath") + + mustExec(t, conn, `create temporary table foo( + a jsonpath, + b text + )`) + + inputRows := [][]any{ + {"$.a", "tab\there"}, + {"$.b", "new\nline"}, + {"$.c", "carriage\rreturn"}, + {"$.d", "back\\slash"}, + {"$.e", "mixed\t\n\r\\all"}, + {"$.f", ""}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, _ := conn.Query(ctx, "select b from foo order by a::text") + texts, err := pgx.CollectRows(rows, pgx.RowTo[string]) + require.NoError(t, err) + + require.Equal(t, []string{ + "tab\there", + "new\nline", + "carriage\rreturn", + "back\\slash", + "mixed\t\n\r\\all", + "", + }, texts) + + ensureConnValid(t, conn) +} + +// TestConnCopyFromTextFormatLarge tests text format fallback with a large number of rows to exercise buffer management. +func TestConnCopyFromTextFormatLarge(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath") + + mustExec(t, conn, `create temporary table foo( + a int4, + b jsonpath, + c text + )`) + + const rowCount = 10_000 + inputRows := make([][]any, rowCount) + for i := range rowCount { + inputRows[i] = []any{int32(i), "$.x", fmt.Sprintf("row-%d", i)} + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, rowCount, copyCount) + + var count int64 + err = conn.QueryRow(ctx, "select count(*) from foo").Scan(&count) + require.NoError(t, err) + require.EqualValues(t, rowCount, count) + + ensureConnValid(t, conn) +} + +// TestConnCopyFromTextFormatAllQueryExecModes tests that text format fallback works with all query exec modes. +func TestConnCopyFromTextFormatAllQueryExecModes(t *testing.T) { + for _, mode := range pgxtest.AllQueryExecModes { + t.Run(mode.String(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + cfg.DefaultQueryExecMode = mode + conn := mustConnect(t, cfg) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath") + + mustExec(t, conn, `create temporary table foo( + a int4, + b jsonpath + )`) + + inputRows := [][]any{ + {int32(1), "$.x"}, + {int32(2), "$.y"}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + var count int64 + err = conn.QueryRow(ctx, "select count(*) from foo").Scan(&count) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), count) + + ensureConnValid(t, conn) + }) + } +} + +// TestConnCopyFromTextFormatStringToInt tests that string values for integer columns trigger text format fallback and +// PostgreSQL parses them correctly. +func TestConnCopyFromTextFormatStringToInt(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b text + )`) + + inputRows := [][]any{ + {"42", "hello"}, + {"7", "world"}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, _ := conn.Query(ctx, "select a, b from foo order by a") + outputRows, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) ([]any, error) { + var a int64 + var b string + err := row.Scan(&a, &b) + return []any{a, b}, err + }) + require.NoError(t, err) + + require.Equal(t, [][]any{ + {int64(7), "world"}, + {int64(42), "hello"}, + }, outputRows) + + ensureConnValid(t, conn) +} + +// TestConnCopyFromEmptyRows tests that CopyFrom handles zero rows correctly in both binary and text paths. +func TestConnCopyFromEmptyRows(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo(a int4)`) + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows([][]any{})) + require.NoError(t, err) + require.EqualValues(t, 0, copyCount) + + ensureConnValid(t, conn) +} + func TestCopyFromFunc(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index 6e2ff3003..4c2d509e8 100644 --- a/values.go +++ b/values.go @@ -1,8 +1,6 @@ package pgx import ( - "errors" - "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" ) @@ -29,11 +27,7 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er buf = pgio.AppendInt32(buf, -1) argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) if err != nil { - if argBuf2, err2 := tryScanStringCopyValueThenEncode(m, buf, oid, arg); err2 == nil { - argBuf = argBuf2 - } else { - return nil, err - } + return nil, err } if argBuf != nil { @@ -43,21 +37,36 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er return buf, nil } -func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { - s, ok := arg.(string) - if !ok { - textBuf, err := m.Encode(oid, TextFormatCode, arg, nil) - if err != nil { - return nil, errors.New("not a string and cannot be encoded as text") - } - s = string(textBuf) - } - - var v any - err := m.Scan(oid, TextFormatCode, []byte(s), &v) +func encodeCopyValueText(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { + // Encode into a separate buffer to distinguish NULL (nil return) from empty string (empty non-nil return). Using a + // non-nil empty slice ensures that encoding an empty value (e.g. empty string) returns non-nil. This also avoids + // aliasing issues when the subsequent escaping step expands special characters in-place. + textBuf, err := m.Encode(oid, TextFormatCode, arg, []byte{}) if err != nil { return nil, err } + if textBuf == nil { + // NULL is represented as \N in text COPY format. + return append(buf, '\\', 'N'), nil + } + return appendTextCopyEscaped(buf, textBuf), nil +} - return m.Encode(oid, BinaryFormatCode, v, buf) +// appendTextCopyEscaped appends src to buf, escaping characters that are special in PostgreSQL text COPY format. +func appendTextCopyEscaped(buf []byte, src []byte) []byte { + for _, b := range src { + switch b { + case '\\': + buf = append(buf, '\\', '\\') + case '\n': + buf = append(buf, '\\', 'n') + case '\r': + buf = append(buf, '\\', 'r') + case '\t': + buf = append(buf, '\\', 't') + default: + buf = append(buf, b) + } + } + return buf }