Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 172 additions & 24 deletions copy_from.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ type copyFrom struct {
rowSrc CopyFromSource
readerErrChan chan error
mode QueryExecMode
firstRow []any
}

func (ct *copyFrom) run(ctx context.Context) (int64, error) {
Expand Down Expand Up @@ -158,6 +159,23 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
}

// Peek at the first row to determine whether all columns can be encoded in binary format.
hasFirstRow := ct.rowSrc.Next()
if hasFirstRow {
var err error
ct.firstRow, err = ct.rowSrc.Values()
if err != nil {
return 0, err
}
if len(ct.firstRow) != len(ct.columnNames) {
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(ct.firstRow))
}
} else if ct.rowSrc.Err() != nil {
return 0, ct.rowSrc.Err()
}

useBinary := ct.canUseBinaryFormat(sd)

r, w := io.Pipe()
doneChan := make(chan struct{})

Expand All @@ -167,39 +185,57 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
buf := ct.conn.wbuf

buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
if useBinary {
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
}

moreRows := true
for moreRows {
var err error
moreRows, buf, err = ct.buildCopyBuf(buf, sd)
if err != nil {
w.CloseWithError(err)
return
}
moreRows := hasFirstRow
for {
if moreRows {
var err error
if useBinary {
moreRows, buf, err = ct.buildCopyBuf(buf, sd)
} else {
moreRows, buf, err = ct.buildCopyBufText(buf, sd)
}
if err != nil {
w.CloseWithError(err)
return
}

if ct.rowSrc.Err() != nil {
w.CloseWithError(ct.rowSrc.Err())
return
if ct.rowSrc.Err() != nil {
w.CloseWithError(ct.rowSrc.Err())
return
}
}

if len(buf) > 0 {
_, err = w.Write(buf)
_, err := w.Write(buf)
if err != nil {
w.Close()
return
}
}

if !moreRows {
break
}

buf = buf[:0]
}

w.Close()
}()

commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
var copySQL string
if useBinary {
copySQL = fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)
} else {
copySQL = fmt.Sprintf("copy %s ( %s ) from stdin;", quotedTableName, quotedColumnNames)
}
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, copySQL)

r.Close()
<-doneChan
Expand All @@ -214,11 +250,64 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
return commandTag.RowsAffected(), err
}

// canUseBinaryFormat checks whether all columns can be encoded in binary format. It checks both codec-level support and
// whether the first row's values can be planned for binary encoding.
// canUseBinaryFormat reports whether the COPY can be performed in binary
// format. All destination columns must have a registered type whose codec
// supports binary encoding, and — when a first row was buffered by run() —
// each of its non-nil values must be plannable for binary encoding.
func (ct *copyFrom) canUseBinaryFormat(sd *pgconn.StatementDescription) bool {
	tm := ct.conn.typeMap

	// Codec-level check: every column type must be known and support binary.
	for _, f := range sd.Fields {
		t, known := tm.TypeForOID(f.DataTypeOID)
		if !known || !t.Codec.FormatSupported(BinaryFormatCode) {
			return false
		}
	}

	// Value-level check against the buffered first row, if any. nil values
	// need no encode plan, so they are skipped.
	if ct.firstRow == nil {
		return true
	}
	for i, v := range ct.firstRow {
		if v == nil {
			continue
		}
		if tm.PlanEncode(sd.Fields[i].DataTypeOID, BinaryFormatCode, v) == nil {
			return false
		}
	}

	return true
}

func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
const sendBufSize = 65536 - 5 // The packet has a 5-byte header
lastBufLen := 0
largestRowLen := 0

encodeBinaryRow := func(values []any) ([]byte, error) {
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
var err error
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
if err != nil {
return nil, err
}
}
return buf, nil
}

// Process buffered first row.
if ct.firstRow != nil {
lastBufLen = len(buf)
var err error
buf, err = encodeBinaryRow(ct.firstRow)
if err != nil {
return false, nil, err
}
ct.firstRow = nil
largestRowLen = len(buf) - lastBufLen
}

for ct.rowSrc.Next() {
lastBufLen = len(buf)

Expand All @@ -230,12 +319,9 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}

buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
if err != nil {
return false, nil, err
}
buf, err = encodeBinaryRow(values)
if err != nil {
return false, nil, err
}

rowLen := len(buf) - lastBufLen
Expand All @@ -254,11 +340,73 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
return false, buf, nil
}

// buildCopyBufText appends rows from ct.rowSrc to buf in the COPY text
// format (tab-separated columns, newline-terminated rows). It returns
// whether more rows remain, the filled buffer, and any encoding error.
func (ct *copyFrom) buildCopyBufText(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
	// Stay under a 64 KiB CopyData packet; the packet has a 5-byte header.
	const sendBufSize = 65536 - 5

	prevLen := 0
	maxRowLen := 0

	// appendRow encodes one row: columns separated by tabs, terminated by a
	// newline. Column text encoding is delegated to encodeCopyValueText.
	appendRow := func(values []any) ([]byte, error) {
		for i, v := range values {
			if i > 0 {
				buf = append(buf, '\t')
			}
			var err error
			buf, err = encodeCopyValueText(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, v)
			if err != nil {
				return nil, err
			}
		}
		buf = append(buf, '\n')
		return buf, nil
	}

	// Flush the row that run() buffered while probing for binary support.
	if ct.firstRow != nil {
		prevLen = len(buf)
		row, err := appendRow(ct.firstRow)
		if err != nil {
			return false, nil, err
		}
		buf = row
		ct.firstRow = nil
		maxRowLen = len(buf) - prevLen
	}

	for ct.rowSrc.Next() {
		prevLen = len(buf)

		values, err := ct.rowSrc.Values()
		if err != nil {
			return false, nil, err
		}
		if len(values) != len(ct.columnNames) {
			return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
		}

		if buf, err = appendRow(values); err != nil {
			return false, nil, err
		}

		if rowLen := len(buf) - prevLen; rowLen > maxRowLen {
			maxRowLen = rowLen
		}

		// Hand the buffer to the writer once another worst-case row could
		// overflow the packet budget; report that more rows remain.
		if len(buf) > sendBufSize-maxRowLen {
			return true, buf, nil
		}
	}

	return false, buf, nil
}

// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
// an error.
//
// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
// for the type of each column. Almost all types implemented by pgx support the binary format.
// CopyFrom automatically uses the binary format when all columns support it. If any column's type only supports text
// format, or if the first row's values cannot be encoded in binary, CopyFrom transparently falls back to the text
// format.
//
// Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with
// Conn.LoadType and pgtype.Map.RegisterType.
Expand Down
Loading
Loading