Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 204 additions & 0 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ import (
"errors"
"fmt"
"io"
"math/big"
"reflect"
"sync"
"time"

"cloud.google.com/go/civil"
"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/google/uuid"
"github.com/googleapis/go-sql-spanner/connectionstate"
Expand Down Expand Up @@ -557,6 +560,207 @@ func (r *rows) nextStats(dest []driver.Value) error {
return nil
}

func (r *rows) ColumnTypeScanType(index int) reflect.Type {
if index < 0 || index >= len(r.colTypes) {
return nil
}
t := r.colTypes[index]
if r.decodeOption == DecodeOptionProto {
return reflect.TypeOf(spanner.GenericColumnValue{})
}
return scanType(t, r.decodeToNativeArrays, r.state)
}

func scanType(t *sppb.Type, decodeToNativeArrays bool, state *connectionstate.ConnectionState) reflect.Type {
switch t.Code {
case sppb.TypeCode_INT64, sppb.TypeCode_ENUM:
return reflect.TypeOf(int64(0))
case sppb.TypeCode_FLOAT32:
return reflect.TypeOf(float32(0))
case sppb.TypeCode_FLOAT64:
return reflect.TypeOf(float64(0))
case sppb.TypeCode_NUMERIC:
if propertyDecodeNumericToString.GetValueOrDefault(state) {
return reflect.TypeOf("")
}
return reflect.TypeOf(big.Rat{})
case sppb.TypeCode_STRING, sppb.TypeCode_DATE, sppb.TypeCode_UUID:
return reflect.TypeOf("")
case sppb.TypeCode_BYTES, sppb.TypeCode_PROTO:
return reflect.TypeOf([]byte{})
case sppb.TypeCode_BOOL:
return reflect.TypeOf(true)
case sppb.TypeCode_TIMESTAMP:
return reflect.TypeOf(time.Time{})
case sppb.TypeCode_JSON:
if t.TypeAnnotation == sppb.TypeAnnotationCode_PG_JSONB {
return reflect.TypeOf(spanner.PGJsonB{})
}
return reflect.TypeOf(spanner.NullJSON{})
case sppb.TypeCode_ARRAY:
if t.ArrayElementType == nil {
return reflect.TypeOf([]any{}).Elem()
}
Comment on lines +601 to +603
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When the array element type is unknown, ColumnTypeScanType should return a slice type (e.g., []any) rather than the element type any. reflect.TypeOf([]any{}).Elem() evaluates to interface{}, which represents a scalar value.

Suggested change
if t.ArrayElementType == nil {
return reflect.TypeOf([]any{}).Elem()
}
if t.ArrayElementType == nil {
return reflect.TypeOf([]any{})
}

et := scanType(t.ArrayElementType, decodeToNativeArrays, state)
if decodeToNativeArrays {
return reflect.SliceOf(et)
}
Comment on lines +604 to +607
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is an inconsistency in how DATE and UUID types are handled between scalar and array decoding when decodeToNativeArrays is true. Scalar DATE and UUID values are decoded as string (see lines 323 and 352), but their array counterparts are decoded as []civil.Date and []uuid.UUID (see lines 464 and 498). The current implementation of scanType would incorrectly return []string for these arrays because it recursively calls scanType on the element type.

		et := scanType(t.ArrayElementType, decodeToNativeArrays, state)
		if decodeToNativeArrays {
			switch t.ArrayElementType.Code {
			case sppb.TypeCode_DATE:
				return reflect.TypeOf([]civil.Date{})
			case sppb.TypeCode_UUID:
				return reflect.TypeOf([]uuid.UUID{})
			}
			return reflect.SliceOf(et)
		}

switch t.ArrayElementType.Code {
case sppb.TypeCode_INT64, sppb.TypeCode_ENUM:
return reflect.TypeOf([]spanner.NullInt64{})
case sppb.TypeCode_FLOAT32:
return reflect.TypeOf([]spanner.NullFloat32{})
case sppb.TypeCode_FLOAT64:
return reflect.TypeOf([]spanner.NullFloat64{})
case sppb.TypeCode_BOOL:
return reflect.TypeOf([]spanner.NullBool{})
case sppb.TypeCode_STRING:
return reflect.TypeOf([]spanner.NullString{})
case sppb.TypeCode_DATE:
return reflect.TypeOf([]spanner.NullDate{})
case sppb.TypeCode_TIMESTAMP:
return reflect.TypeOf([]spanner.NullTime{})
case sppb.TypeCode_NUMERIC:
if t.ArrayElementType.TypeAnnotation == sppb.TypeAnnotationCode_PG_NUMERIC {
return reflect.TypeOf([]spanner.PGNumeric{})
}
return reflect.TypeOf([]spanner.NullNumeric{})
case sppb.TypeCode_JSON:
if t.ArrayElementType.TypeAnnotation == sppb.TypeAnnotationCode_PG_JSONB {
return reflect.TypeOf([]spanner.PGJsonB{})
}
return reflect.TypeOf([]spanner.NullJSON{})
case sppb.TypeCode_UUID:
return reflect.TypeOf([]spanner.NullUUID{})
case sppb.TypeCode_BYTES, sppb.TypeCode_PROTO:
return reflect.TypeOf([][]byte{})
default:
return reflect.SliceOf(et)
}
default:
return reflect.TypeOf([]any{}).Elem()
}
}

func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
if index < 0 || index >= len(r.colTypes) {
return ""
}
t := r.colTypes[index]
dialect := propertyDatabaseDialect.GetValueOrDefault(r.state)
return databaseTypeName(t, dialect)
}

func databaseTypeName(t *sppb.Type, dialect databasepb.DatabaseDialect) string {
isPG := dialect == databasepb.DatabaseDialect_POSTGRESQL
switch t.Code {
case sppb.TypeCode_INT64, sppb.TypeCode_ENUM:
if isPG {
return "bigint"
}
return "INT64"
case sppb.TypeCode_FLOAT32:
if isPG {
return "real"
}
return "FLOAT32"
case sppb.TypeCode_FLOAT64:
if isPG {
return "double precision"
}
return "FLOAT64"
case sppb.TypeCode_NUMERIC:
if isPG {
return "numeric"
}
return "NUMERIC"
case sppb.TypeCode_STRING:
if isPG {
return "text"
}
return "STRING"
case sppb.TypeCode_DATE:
if isPG {
return "date"
}
return "DATE"
case sppb.TypeCode_UUID:
if isPG {
return "uuid"
}
return "UUID"
case sppb.TypeCode_BYTES, sppb.TypeCode_PROTO:
if isPG {
return "bytea"
}
return "BYTES"
case sppb.TypeCode_BOOL:
if isPG {
return "boolean"
}
return "BOOL"
case sppb.TypeCode_TIMESTAMP:
if isPG {
return "timestamp with time zone"
}
return "TIMESTAMP"
case sppb.TypeCode_JSON:
if isPG {
return "jsonb"
}
return "JSON"
case sppb.TypeCode_ARRAY:
if t.ArrayElementType == nil {
return "ARRAY"
}
if isPG {
switch t.ArrayElementType.Code {
case sppb.TypeCode_STRING:
return "_text"
case sppb.TypeCode_INT64:
return "_int8"
case sppb.TypeCode_FLOAT64:
return "_float8"
case sppb.TypeCode_BOOL:
return "_bool"
case sppb.TypeCode_BYTES:
return "_bytea"
case sppb.TypeCode_DATE:
return "_date"
case sppb.TypeCode_TIMESTAMP:
return "_timestamptz"
case sppb.TypeCode_NUMERIC:
return "_numeric"
case sppb.TypeCode_JSON:
return "_jsonb"
default:
return "_" + databaseTypeName(t.ArrayElementType, dialect)
}
}
Comment on lines +716 to +739
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The database type names for PostgreSQL arrays are inconsistent with the scalar type names and use internal PostgreSQL names (e.g., _int8 for bigint[]). It is better to use the standard typename[] syntax, which is consistent with the scalar names returned by this function and more familiar to users. This also simplifies the implementation by removing the need for a separate switch statement for PG arrays.

		if isPG {
			return databaseTypeName(t.ArrayElementType, dialect) + "[]"
		}

return "ARRAY<" + databaseTypeName(t.ArrayElementType, dialect) + ">"
default:
return ""
}
}

func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
if index < 0 || index >= len(r.colTypes) {
return 0, 0, false
}
if r.colTypes[index].Code == sppb.TypeCode_NUMERIC {
return 38, 9, true
}
return 0, 0, false
}

func (r *rows) ColumnTypeLength(index int) (length int64, ok bool) {
return 0, false
}

func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
return false, false
}

var _ driver.Rows = (*emptyRows)(nil)
var _ driver.RowsNextResultSet = (*emptyRows)(nil)
var emptyRowsMetadata = &sppb.ResultSetMetadata{RowType: &sppb.StructType{Fields: []*sppb.StructType_Field{}}}
Expand Down
95 changes: 95 additions & 0 deletions rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ import (
"errors"
"fmt"
"io"
"math/big"
"reflect"
"testing"

"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/go-sql-spanner/connectionstate"
"github.com/googleapis/go-sql-spanner/testutil"
"google.golang.org/protobuf/types/known/structpb"
)
Expand Down Expand Up @@ -286,3 +289,95 @@ func TestScanNumericAsFloat32(t *testing.T) {
})
}
}

func TestRowsColumnType(t *testing.T) {
fmt.Println("Starting TestRowsColumnType")
fields := []*sppb.StructType_Field{
{Name: "COL1", Type: &sppb.Type{Code: sppb.TypeCode_INT64}},
{Name: "COL2", Type: &sppb.Type{Code: sppb.TypeCode_STRING}},
{Name: "COL3", Type: &sppb.Type{Code: sppb.TypeCode_NUMERIC}},
{Name: "COL4", Type: &sppb.Type{Code: sppb.TypeCode_ARRAY, ArrayElementType: &sppb.Type{Code: sppb.TypeCode_STRING}}},
}
it := &testIterator{
metadata: &sppb.ResultSetMetadata{
RowType: &sppb.StructType{Fields: fields},
},
}
r := rows{
it: it,
colTypes: []*sppb.Type{fields[0].Type, fields[1].Type, fields[2].Type, fields[3].Type},
state: createInitialConnectionState(connectionstate.TypeNonTransactional, nil),
}

// Test ColumnTypeScanType
if g, w := r.ColumnTypeScanType(0), reflect.TypeOf(int64(0)); g != w {
t.Errorf("ColumnTypeScanType(0) mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := r.ColumnTypeScanType(1), reflect.TypeOf(""); g != w {
t.Errorf("ColumnTypeScanType(1) mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := r.ColumnTypeScanType(2), reflect.TypeOf(big.Rat{}); g != w {
t.Errorf("ColumnTypeScanType(2) mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := r.ColumnTypeScanType(3), reflect.TypeOf([]spanner.NullString{}); g != w {
t.Errorf("ColumnTypeScanType(3) mismatch\n Got: %v\nWant: %v", g, w)
}

// Test ColumnTypeDatabaseTypeName (GoogleSQL)
if g, w := r.ColumnTypeDatabaseTypeName(0), "INT64"; g != w {
t.Errorf("ColumnTypeDatabaseTypeName(0) mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := r.ColumnTypeDatabaseTypeName(1), "STRING"; g != w {
t.Errorf("ColumnTypeDatabaseTypeName(1) mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := r.ColumnTypeDatabaseTypeName(2), "NUMERIC"; g != w {
t.Errorf("ColumnTypeDatabaseTypeName(2) mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := r.ColumnTypeDatabaseTypeName(3), "ARRAY<STRING>"; g != w {
t.Errorf("ColumnTypeDatabaseTypeName(3) mismatch\n Got: %v\nWant: %v", g, w)
}

// Test ColumnTypeDatabaseTypeName (PostgreSQL)
pgState := createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{
propertyDatabaseDialect.Key(): propertyDatabaseDialect.CreateTypedInitialValue(databasepb.DatabaseDialect_POSTGRESQL),
})
rpg := rows{
it: it,
colTypes: r.colTypes,
state: pgState,
}
if g, w := rpg.ColumnTypeDatabaseTypeName(0), "bigint"; g != w {
t.Errorf("ColumnTypeDatabaseTypeName(0) PG mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := rpg.ColumnTypeDatabaseTypeName(1), "text"; g != w {
t.Errorf("ColumnTypeDatabaseTypeName(1) PG mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := rpg.ColumnTypeDatabaseTypeName(2), "numeric"; g != w {
t.Errorf("ColumnTypeDatabaseTypeName(2) PG mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := rpg.ColumnTypeDatabaseTypeName(3), "_text"; g != w {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If the PostgreSQL array naming is updated to use the typename[] syntax as suggested in rows.go, this test expectation should be updated from _text to text[].

Suggested change
if g, w := rpg.ColumnTypeDatabaseTypeName(3), "_text"; g != w {
if g, w := rpg.ColumnTypeDatabaseTypeName(3), "text[]"; g != w {

t.Errorf("ColumnTypeDatabaseTypeName(3) PG mismatch\n Got: %v\nWant: %v", g, w)
}

// Test ColumnTypePrecisionScale
p, s, ok := r.ColumnTypePrecisionScale(2)
if !ok || p != 38 || s != 9 {
t.Errorf("ColumnTypePrecisionScale(2) mismatch\n Got: %v, %v, %v\nWant: 38, 9, true", p, s, ok)
}
p, s, ok = r.ColumnTypePrecisionScale(0)
if ok || p != 0 || s != 0 {
t.Errorf("ColumnTypePrecisionScale(0) mismatch\n Got: %v, %v, %v\nWant: 0, 0, false", p, s, ok)
}

// Test ColumnTypeLength
l, ok := r.ColumnTypeLength(1)
if ok || l != 0 {
t.Errorf("ColumnTypeLength(1) mismatch\n Got: %v, %v\nWant: 0, false", l, ok)
}

// Test ColumnTypeNullable
n, ok := r.ColumnTypeNullable(1)
if ok || n != false {
t.Errorf("ColumnTypeNullable(1) mismatch\n Got: %v, %v\nWant: false, false", n, ok)
}
}
Loading