diff --git a/copy_from.go b/copy_from.go index abcd22396..7d6164391 100644 --- a/copy_from.go +++ b/copy_from.go @@ -113,6 +113,7 @@ type copyFrom struct { rowSrc CopyFromSource readerErrChan chan error mode QueryExecMode + firstRow []any } func (ct *copyFrom) run(ctx context.Context) (int64, error) { @@ -158,6 +159,23 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode) } + // Peek at the first row to determine whether all columns can be encoded in binary format. + hasFirstRow := ct.rowSrc.Next() + if hasFirstRow { + var err error + ct.firstRow, err = ct.rowSrc.Values() + if err != nil { + return 0, err + } + if len(ct.firstRow) != len(ct.columnNames) { + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(ct.firstRow)) + } + } else if ct.rowSrc.Err() != nil { + return 0, ct.rowSrc.Err() + } + + useBinary := ct.canUseBinaryFormat(sd) + r, w := io.Pipe() doneChan := make(chan struct{}) @@ -167,39 +185,57 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. buf := ct.conn.wbuf - buf = append(buf, "PGCOPY\n\377\r\n\000"...) - buf = pgio.AppendInt32(buf, 0) - buf = pgio.AppendInt32(buf, 0) + if useBinary { + buf = append(buf, "PGCOPY\n\377\r\n\000"...) 
+ buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) + } - moreRows := true - for moreRows { - var err error - moreRows, buf, err = ct.buildCopyBuf(buf, sd) - if err != nil { - w.CloseWithError(err) - return - } + moreRows := hasFirstRow + for { + if moreRows { + var err error + if useBinary { + moreRows, buf, err = ct.buildCopyBuf(buf, sd) + } else { + moreRows, buf, err = ct.buildCopyBufText(buf, sd) + } + if err != nil { + w.CloseWithError(err) + return + } - if ct.rowSrc.Err() != nil { - w.CloseWithError(ct.rowSrc.Err()) - return + if ct.rowSrc.Err() != nil { + w.CloseWithError(ct.rowSrc.Err()) + return + } } if len(buf) > 0 { - _, err = w.Write(buf) + _, err := w.Write(buf) if err != nil { w.Close() return } } + if !moreRows { + break + } + buf = buf[:0] } w.Close() }() - commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + var copySQL string + if useBinary { + copySQL = fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames) + } else { + copySQL = fmt.Sprintf("copy %s ( %s ) from stdin;", quotedTableName, quotedColumnNames) + } + commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, copySQL) r.Close() <-doneChan @@ -214,11 +250,64 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { return commandTag.RowsAffected(), err } +// canUseBinaryFormat checks whether all columns can be encoded in binary format. It checks both codec-level support and +// whether the first row's values can be planned for binary encoding. 
+func (ct *copyFrom) canUseBinaryFormat(sd *pgconn.StatementDescription) bool { + m := ct.conn.typeMap + for _, field := range sd.Fields { + typ, ok := m.TypeForOID(field.DataTypeOID) + if !ok { + return false + } + if !typ.Codec.FormatSupported(BinaryFormatCode) { + return false + } + } + + if ct.firstRow != nil { + for i, val := range ct.firstRow { + if val == nil { + continue + } + plan := m.PlanEncode(sd.Fields[i].DataTypeOID, BinaryFormatCode, val) + if plan == nil { + return false + } + } + } + + return true +} + func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { const sendBufSize = 65536 - 5 // The packet has a 5-byte header lastBufLen := 0 largestRowLen := 0 + encodeBinaryRow := func(values []any) ([]byte, error) { + buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) + for i, val := range values { + var err error + buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) + if err != nil { + return nil, err + } + } + return buf, nil + } + + // Process buffered first row. 
+ if ct.firstRow != nil { + lastBufLen = len(buf) + var err error + buf, err = encodeBinaryRow(ct.firstRow) + if err != nil { + return false, nil, err + } + ct.firstRow = nil + largestRowLen = len(buf) - lastBufLen + } + for ct.rowSrc.Next() { lastBufLen = len(buf) @@ -230,12 +319,9 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) } - buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) - for i, val := range values { - buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) - if err != nil { - return false, nil, err - } + buf, err = encodeBinaryRow(values) + if err != nil { + return false, nil, err } rowLen := len(buf) - lastBufLen @@ -254,11 +340,73 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b return false, buf, nil } +func (ct *copyFrom) buildCopyBufText(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { + const sendBufSize = 65536 - 5 // The packet has a 5-byte header + lastBufLen := 0 + largestRowLen := 0 + + encodeTextRow := func(values []any) ([]byte, error) { + for i, val := range values { + if i > 0 { + buf = append(buf, '\t') + } + var err error + buf, err = encodeCopyValueText(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) + if err != nil { + return nil, err + } + } + buf = append(buf, '\n') + return buf, nil + } + + // Process buffered first row. 
+ if ct.firstRow != nil { + lastBufLen = len(buf) + var err error + buf, err = encodeTextRow(ct.firstRow) + if err != nil { + return false, nil, err + } + ct.firstRow = nil + largestRowLen = len(buf) - lastBufLen + } + + for ct.rowSrc.Next() { + lastBufLen = len(buf) + + values, err := ct.rowSrc.Values() + if err != nil { + return false, nil, err + } + if len(values) != len(ct.columnNames) { + return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + buf, err = encodeTextRow(values) + if err != nil { + return false, nil, err + } + + rowLen := len(buf) - lastBufLen + if rowLen > largestRowLen { + largestRowLen = rowLen + } + + if len(buf) > sendBufSize-largestRowLen { + return true, buf, nil + } + } + + return false, buf, nil +} + // CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and // an error. // -// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered -// for the type of each column. Almost all types implemented by pgx support the binary format. +// CopyFrom automatically uses the binary format when all columns support it. If any column's type only supports text +// format, or if the first row's values cannot be encoded in binary, CopyFrom transparently falls back to the text +// format. // // Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with // Conn.LoadType and pgtype.Map.RegisterType. diff --git a/copy_from_test.go b/copy_from_test.go index d3264c122..aa5b9b6b0 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -937,6 +937,276 @@ func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) { ensureConnValid(t, conn) } +// TestConnCopyFromTextFormatFallback tests that CopyFrom falls back to text format when a column type does not support +// binary encoding (e.g., jsonpath which uses TextFormatOnlyCodec). 
+func TestConnCopyFromTextFormatFallback(t *testing.T) {
+	t.Parallel()
+	// Bound the whole test, including the COPY itself, to 120 seconds.
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+	// The connection string is read from the PGX_TEST_DATABASE environment variable.
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+	// The table below has a jsonpath column, so the test is skipped on CockroachDB.
+	pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath")
+	// Column b is the jsonpath column; a and c are ordinary int4 and text columns.
+	mustExec(t, conn, `create temporary table foo(
+		a int4,
+		b jsonpath,
+		c text
+	)`)
+	// The jsonpath values are supplied as plain Go strings.
+	inputRows := [][]any{
+		{int32(1), "$.store.book[*].author", "hello"},
+		{int32(2), "$.store.price", "world"},
+		{int32(3), "$[0]", "test"},
+	}
+	// CopyFrom must succeed and report one copied row per input row.
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c"}, pgx.CopyFromRows(inputRows))
+	require.NoError(t, err)
+	require.EqualValues(t, len(inputRows), copyCount)
+	// b is cast to text so it can be scanned into a Go string. The Query error is discarded here;
+	rows, _ := conn.Query(ctx, "select a, b::text, c from foo order by a") // presumably it resurfaces from CollectRows — TODO confirm
+	outputRows, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) ([]any, error) {
+		var a int32
+		var b, c string
+		err := row.Scan(&a, &b, &c)
+		return []any{a, b, c}, err
+	})
+	require.NoError(t, err)
+	// The expected jsonpath text has its keys double-quoted, unlike the unquoted input strings.
+	require.Equal(t, [][]any{
+		{int32(1), "$.\"store\".\"book\"[*].\"author\"", "hello"},
+		{int32(2), "$.\"store\".\"price\"", "world"},
+		{int32(3), "$[0]", "test"},
+	}, outputRows)
+	// ensureConnValid is defined outside this hunk; presumably it checks the connection is still usable after the COPY.
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyFromTextFormatFallbackWithNulls tests text format fallback with NULL values across columns.
+func TestConnCopyFromTextFormatFallbackWithNulls(t *testing.T) {
+	t.Parallel()
+	// Bound the whole test, including the COPY itself, to 120 seconds.
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+	// The connection string is read from the PGX_TEST_DATABASE environment variable.
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+	// The table below has a jsonpath column, so the test is skipped on CockroachDB.
+	pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath")
+	// Same three-column layout as the basic fallback test: int4, jsonpath, text.
+	mustExec(t, conn, `create temporary table foo(
+		a int4,
+		b jsonpath,
+		c text
+	)`)
+	// nil appears in every column position, including one row that is nil in all three columns.
+	inputRows := [][]any{
+		{int32(1), "$.x", "hello"},
+		{int32(2), nil, nil},
+		{nil, "$.y", nil},
+		{nil, nil, nil},
+	}
+	// CopyFrom must succeed and report one copied row per input row.
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c"}, pgx.CopyFromRows(inputRows))
+	require.NoError(t, err)
+	require.EqualValues(t, len(inputRows), copyCount)
+	// Pointer scan targets are used so that NULL columns can be scanned. The Query error is discarded here;
+	rows, _ := conn.Query(ctx, "select a, b::text, c from foo order by a nulls last") // presumably it resurfaces from CollectRows — TODO confirm
+	outputRows, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) ([]any, error) {
+		var a *int32
+		var b, c *string
+		err := row.Scan(&a, &b, &c)
+		return []any{a, b, c}, err
+	})
+	require.NoError(t, err)
+	require.Len(t, outputRows, 4) // NOTE(review): only the row count is asserted; which columns came back NULL is not checked
+	// ensureConnValid is defined outside this hunk; presumably it checks the connection is still usable after the COPY.
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyFromTextFormatSpecialCharEscaping tests that text format properly escapes special characters.
+func TestConnCopyFromTextFormatSpecialCharEscaping(t *testing.T) {
+	t.Parallel()
+	// Bound the whole test, including the COPY itself, to 120 seconds.
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+	// The connection string is read from the PGX_TEST_DATABASE environment variable.
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+	// The table below has a jsonpath column, so the test is skipped on CockroachDB.
+	pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath")
+	// Column a is jsonpath; column b (text) carries the strings whose escaping is under test.
+	mustExec(t, conn, `create temporary table foo(
+		a jsonpath,
+		b text
+	)`)
+	// The b values contain tab, newline, carriage return and backslash, plus one empty string.
+	inputRows := [][]any{
+		{"$.a", "tab\there"},
+		{"$.b", "new\nline"},
+		{"$.c", "carriage\rreturn"},
+		{"$.d", "back\\slash"},
+		{"$.e", "mixed\t\n\r\\all"},
+		{"$.f", ""},
+	}
+	// CopyFrom must succeed and report one copied row per input row.
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
+	require.NoError(t, err)
+	require.EqualValues(t, len(inputRows), copyCount)
+	// Rows are ordered by the text form of a; presumably that matches insertion order ($.a … $.f) — TODO confirm.
+	rows, _ := conn.Query(ctx, "select b from foo order by a::text") // Query error discarded; presumably resurfaces from CollectRows
+	texts, err := pgx.CollectRows(rows, pgx.RowTo[string])
+	require.NoError(t, err)
+	// Every b value must come back exactly as it was supplied, in the same order as inputRows.
+	require.Equal(t, []string{
+		"tab\there",
+		"new\nline",
+		"carriage\rreturn",
+		"back\\slash",
+		"mixed\t\n\r\\all",
+		"",
+	}, texts)
+	// ensureConnValid is defined outside this hunk; presumably it checks the connection is still usable after the COPY.
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyFromTextFormatLarge tests text format fallback with a large number of rows to exercise buffer management.
+func TestConnCopyFromTextFormatLarge(t *testing.T) {
+	t.Parallel()
+	// Bound the whole test, including the COPY itself, to 120 seconds.
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+	// The connection string is read from the PGX_TEST_DATABASE environment variable.
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+	// The table below has a jsonpath column, so the test is skipped on CockroachDB.
+	pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath")
+	// Same three-column layout as the basic fallback test: int4, jsonpath, text.
+	mustExec(t, conn, `create temporary table foo(
+		a int4,
+		b jsonpath,
+		c text
+	)`)
+	// 10,000 short rows; presumably enough to need more than one send buffer — TODO confirm against the buffer size.
+	const rowCount = 10_000
+	inputRows := make([][]any, rowCount)
+	for i := range rowCount {
+		inputRows[i] = []any{int32(i), "$.x", fmt.Sprintf("row-%d", i)}
+	}
+	// CopyFrom must succeed and report rowCount copied rows.
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c"}, pgx.CopyFromRows(inputRows))
+	require.NoError(t, err)
+	require.EqualValues(t, rowCount, copyCount)
+	// Cross-check the reported count against what is actually in the table; row contents are not inspected.
+	var count int64
+	err = conn.QueryRow(ctx, "select count(*) from foo").Scan(&count)
+	require.NoError(t, err)
+	require.EqualValues(t, rowCount, count)
+	// ensureConnValid is defined outside this hunk; presumably it checks the connection is still usable after the COPY.
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyFromTextFormatAllQueryExecModes tests that text format fallback works with all query exec modes.
+func TestConnCopyFromTextFormatAllQueryExecModes(t *testing.T) {
+	for _, mode := range pgxtest.AllQueryExecModes { // one subtest per query exec mode
+		t.Run(mode.String(), func(t *testing.T) {
+			ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) // 120-second bound per subtest
+			defer cancel()
+			// Each subtest opens its own connection with this mode set as the default query exec mode.
+			cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
+			cfg.DefaultQueryExecMode = mode
+			conn := mustConnect(t, cfg)
+			defer closeConn(t, conn)
+			// The table below has a jsonpath column, so the subtest is skipped on CockroachDB.
+			pgxtest.SkipCockroachDB(t, conn, "CockroachDB does not support jsonpath")
+			// The temporary table is created on this subtest's own connection.
+			mustExec(t, conn, `create temporary table foo(
+				a int4,
+				b jsonpath
+			)`)
+			// Two rows pairing an int4 value with a jsonpath value given as a Go string.
+			inputRows := [][]any{
+				{int32(1), "$.x"},
+				{int32(2), "$.y"},
+			}
+			// CopyFrom must succeed and report one copied row per input row.
+			copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
+			require.NoError(t, err)
+			require.EqualValues(t, len(inputRows), copyCount)
+			// Cross-check the reported count against what is actually in the table; row contents are not inspected.
+			var count int64
+			err = conn.QueryRow(ctx, "select count(*) from foo").Scan(&count)
+			require.NoError(t, err)
+			require.EqualValues(t, len(inputRows), count)
+			// ensureConnValid is defined outside this hunk; presumably it checks the connection is still usable after the COPY.
+			ensureConnValid(t, conn)
+		})
+	}
+}
+
+// TestConnCopyFromTextFormatStringToInt tests that string values for integer columns trigger text format fallback and
+// PostgreSQL parses them correctly.
+func TestConnCopyFromTextFormatStringToInt(t *testing.T) {
+	t.Parallel()
+	// Bound the whole test, including the COPY itself, to 120 seconds.
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+	// The connection string is read from the PGX_TEST_DATABASE environment variable.
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+	// No jsonpath column here, so there is no CockroachDB skip; only int4 and text columns are used.
+	mustExec(t, conn, `create temporary table foo(
+		a int4,
+		b text
+	)`)
+	// The int4 column a is fed Go strings rather than integers.
+	inputRows := [][]any{
+		{"42", "hello"},
+		{"7", "world"},
+	}
+	// CopyFrom must succeed and report one copied row per input row.
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
+	require.NoError(t, err)
+	require.EqualValues(t, len(inputRows), copyCount)
+	// Column a is scanned into an int64. The Query error is discarded here;
+	rows, _ := conn.Query(ctx, "select a, b from foo order by a") // presumably it resurfaces from CollectRows — TODO confirm
+	outputRows, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) ([]any, error) {
+		var a int64
+		var b string
+		err := row.Scan(&a, &b)
+		return []any{a, b}, err
+	})
+	require.NoError(t, err)
+	// 7 is expected before 42: the ordering is numeric, so the strings were stored as integers.
+	require.Equal(t, [][]any{
+		{int64(7), "world"},
+		{int64(42), "hello"},
+	}, outputRows)
+	// ensureConnValid is defined outside this hunk; presumably it checks the connection is still usable after the COPY.
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyFromEmptyRows tests that CopyFrom handles zero rows correctly in both binary and text paths.
+func TestConnCopyFromEmptyRows(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo(a int4)`) + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows([][]any{})) + require.NoError(t, err) + require.EqualValues(t, 0, copyCount) + + ensureConnValid(t, conn) +} + func TestCopyFromFunc(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index 6e2ff3003..4c2d509e8 100644 --- a/values.go +++ b/values.go @@ -1,8 +1,6 @@ package pgx import ( - "errors" - "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" ) @@ -29,11 +27,7 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er buf = pgio.AppendInt32(buf, -1) argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) if err != nil { - if argBuf2, err2 := tryScanStringCopyValueThenEncode(m, buf, oid, arg); err2 == nil { - argBuf = argBuf2 - } else { - return nil, err - } + return nil, err } if argBuf != nil { @@ -43,21 +37,36 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er return buf, nil } -func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { - s, ok := arg.(string) - if !ok { - textBuf, err := m.Encode(oid, TextFormatCode, arg, nil) - if err != nil { - return nil, errors.New("not a string and cannot be encoded as text") - } - s = string(textBuf) - } - - var v any - err := m.Scan(oid, TextFormatCode, []byte(s), &v) +func encodeCopyValueText(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { + // Encode into a separate buffer to distinguish NULL (nil return) from empty string (empty non-nil return). Using a + // non-nil empty slice ensures that encoding an empty value (e.g. empty string) returns non-nil. 
This also avoids + // aliasing issues when the subsequent escaping step expands special characters in-place. + textBuf, err := m.Encode(oid, TextFormatCode, arg, []byte{}) if err != nil { return nil, err } + if textBuf == nil { + // NULL is represented as \N in text COPY format. + return append(buf, '\\', 'N'), nil + } + return appendTextCopyEscaped(buf, textBuf), nil +} - return m.Encode(oid, BinaryFormatCode, v, buf) +// appendTextCopyEscaped appends src to buf, escaping characters that are special in PostgreSQL text COPY format. +func appendTextCopyEscaped(buf []byte, src []byte) []byte { + for _, b := range src { + switch b { + case '\\': + buf = append(buf, '\\', '\\') + case '\n': + buf = append(buf, '\\', 'n') + case '\r': + buf = append(buf, '\\', 'r') + case '\t': + buf = append(buf, '\\', 't') + default: + buf = append(buf, b) + } + } + return buf }