diff --git a/rows.go b/rows.go index 73eaf59d..0ca6b042 100644 --- a/rows.go +++ b/rows.go @@ -20,11 +20,14 @@ import ( "errors" "fmt" "io" + "math/big" + "reflect" "sync" "time" "cloud.google.com/go/civil" "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/uuid" "github.com/googleapis/go-sql-spanner/connectionstate" @@ -557,6 +560,207 @@ func (r *rows) nextStats(dest []driver.Value) error { return nil } +func (r *rows) ColumnTypeScanType(index int) reflect.Type { + if index < 0 || index >= len(r.colTypes) { + return nil + } + t := r.colTypes[index] + if r.decodeOption == DecodeOptionProto { + return reflect.TypeOf(spanner.GenericColumnValue{}) + } + return scanType(t, r.decodeToNativeArrays, r.state) +} + +func scanType(t *sppb.Type, decodeToNativeArrays bool, state *connectionstate.ConnectionState) reflect.Type { + switch t.Code { + case sppb.TypeCode_INT64, sppb.TypeCode_ENUM: + return reflect.TypeOf(int64(0)) + case sppb.TypeCode_FLOAT32: + return reflect.TypeOf(float32(0)) + case sppb.TypeCode_FLOAT64: + return reflect.TypeOf(float64(0)) + case sppb.TypeCode_NUMERIC: + if propertyDecodeNumericToString.GetValueOrDefault(state) { + return reflect.TypeOf("") + } + return reflect.TypeOf(big.Rat{}) + case sppb.TypeCode_STRING, sppb.TypeCode_DATE, sppb.TypeCode_UUID: + return reflect.TypeOf("") + case sppb.TypeCode_BYTES, sppb.TypeCode_PROTO: + return reflect.TypeOf([]byte{}) + case sppb.TypeCode_BOOL: + return reflect.TypeOf(true) + case sppb.TypeCode_TIMESTAMP: + return reflect.TypeOf(time.Time{}) + case sppb.TypeCode_JSON: + if t.TypeAnnotation == sppb.TypeAnnotationCode_PG_JSONB { + return reflect.TypeOf(spanner.PGJsonB{}) + } + return reflect.TypeOf(spanner.NullJSON{}) + case sppb.TypeCode_ARRAY: + if t.ArrayElementType == nil { + return reflect.TypeOf([]any{}).Elem() + } + et := scanType(t.ArrayElementType, decodeToNativeArrays, state) + if decodeToNativeArrays { + return reflect.SliceOf(et) + } + switch t.ArrayElementType.Code { + case sppb.TypeCode_INT64, sppb.TypeCode_ENUM: + return reflect.TypeOf([]spanner.NullInt64{}) + case sppb.TypeCode_FLOAT32: + return reflect.TypeOf([]spanner.NullFloat32{}) + case sppb.TypeCode_FLOAT64: + return reflect.TypeOf([]spanner.NullFloat64{}) + case sppb.TypeCode_BOOL: + return reflect.TypeOf([]spanner.NullBool{}) + case sppb.TypeCode_STRING: + return reflect.TypeOf([]spanner.NullString{}) + case sppb.TypeCode_DATE: + return reflect.TypeOf([]spanner.NullDate{}) + case sppb.TypeCode_TIMESTAMP: + return reflect.TypeOf([]spanner.NullTime{}) + case sppb.TypeCode_NUMERIC: + if t.ArrayElementType.TypeAnnotation == sppb.TypeAnnotationCode_PG_NUMERIC { + return reflect.TypeOf([]spanner.PGNumeric{}) + } + return reflect.TypeOf([]spanner.NullNumeric{}) + case sppb.TypeCode_JSON: + if t.ArrayElementType.TypeAnnotation == sppb.TypeAnnotationCode_PG_JSONB { + return reflect.TypeOf([]spanner.PGJsonB{}) + } + return reflect.TypeOf([]spanner.NullJSON{}) + case sppb.TypeCode_UUID: + return reflect.TypeOf([]spanner.NullUUID{}) + case sppb.TypeCode_BYTES, sppb.TypeCode_PROTO: + return reflect.TypeOf([][]byte{}) + default: + return reflect.SliceOf(et) + } + default: + return reflect.TypeOf([]any{}).Elem() + } +} + +func (r *rows) ColumnTypeDatabaseTypeName(index int) string { + if index < 0 || index >= len(r.colTypes) { + return "" + } + t := r.colTypes[index] + dialect := propertyDatabaseDialect.GetValueOrDefault(r.state) + return databaseTypeName(t, dialect) +} + +func databaseTypeName(t *sppb.Type, dialect databasepb.DatabaseDialect) string { + isPG := dialect == databasepb.DatabaseDialect_POSTGRESQL + switch t.Code { + case sppb.TypeCode_INT64, sppb.TypeCode_ENUM: + if isPG { + return "bigint" + } + return "INT64" + case sppb.TypeCode_FLOAT32: + if isPG { + return "real" + } + return "FLOAT32" + case sppb.TypeCode_FLOAT64: + if isPG { + return "double precision" + } + return "FLOAT64" + case sppb.TypeCode_NUMERIC: + if isPG { + return "numeric" + } + return "NUMERIC" + case sppb.TypeCode_STRING: + if isPG { + return "text" + } + return "STRING" + case sppb.TypeCode_DATE: + if isPG { + return "date" + } + return "DATE" + case sppb.TypeCode_UUID: + if isPG { + return "uuid" + } + return "UUID" + case sppb.TypeCode_BYTES, sppb.TypeCode_PROTO: + if isPG { + return "bytea" + } + return "BYTES" + case sppb.TypeCode_BOOL: + if isPG { + return "boolean" + } + return "BOOL" + case sppb.TypeCode_TIMESTAMP: + if isPG { + return "timestamp with time zone" + } + return "TIMESTAMP" + case sppb.TypeCode_JSON: + if isPG { + return "jsonb" + } + return "JSON" + case sppb.TypeCode_ARRAY: + if t.ArrayElementType == nil { + return "ARRAY" + } + if isPG { + switch t.ArrayElementType.Code { + case sppb.TypeCode_STRING: + return "_text" + case sppb.TypeCode_INT64: + return "_int8" + case sppb.TypeCode_FLOAT64: + return "_float8" + case sppb.TypeCode_BOOL: + return "_bool" + case sppb.TypeCode_BYTES: + return "_bytea" + case sppb.TypeCode_DATE: + return "_date" + case sppb.TypeCode_TIMESTAMP: + return "_timestamptz" + case sppb.TypeCode_NUMERIC: + return "_numeric" + case sppb.TypeCode_JSON: + return "_jsonb" + default: + return "_" + databaseTypeName(t.ArrayElementType, dialect) + } + } + return "ARRAY<" + databaseTypeName(t.ArrayElementType, dialect) + ">" + default: + return "" + } +} + +func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + if index < 0 || index >= len(r.colTypes) { + return 0, 0, false + } + if r.colTypes[index].Code == sppb.TypeCode_NUMERIC { + return 38, 9, true + } + return 0, 0, false +} + +func (r *rows) ColumnTypeLength(index int) (length int64, ok bool) { + return 0, false +} + +func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { + return false, false +} + var _ driver.Rows = (*emptyRows)(nil) var _ driver.RowsNextResultSet = (*emptyRows)(nil) var emptyRowsMetadata = &sppb.ResultSetMetadata{RowType: &sppb.StructType{Fields: []*sppb.StructType_Field{}}} diff --git a/rows_test.go b/rows_test.go index fafe33aa..4276f591 100644 --- a/rows_test.go +++ b/rows_test.go @@ -20,12 +20,15 @@ import ( "errors" "fmt" "io" + "math/big" "reflect" "testing" "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/go-cmp/cmp" + "github.com/googleapis/go-sql-spanner/connectionstate" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/protobuf/types/known/structpb" ) @@ -286,3 +289,95 @@ func TestScanNumericAsFloat32(t *testing.T) { }) } } + +func TestRowsColumnType(t *testing.T) { + fmt.Println("Starting TestRowsColumnType") + fields := []*sppb.StructType_Field{ + {Name: "COL1", Type: &sppb.Type{Code: sppb.TypeCode_INT64}}, + {Name: "COL2", Type: &sppb.Type{Code: sppb.TypeCode_STRING}}, + {Name: "COL3", Type: &sppb.Type{Code: sppb.TypeCode_NUMERIC}}, + {Name: "COL4", Type: &sppb.Type{Code: sppb.TypeCode_ARRAY, ArrayElementType: &sppb.Type{Code: sppb.TypeCode_STRING}}}, + } + it := &testIterator{ + metadata: &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{Fields: fields}, + }, + } + r := rows{ + it: it, + colTypes: []*sppb.Type{fields[0].Type, fields[1].Type, fields[2].Type, fields[3].Type}, + state: createInitialConnectionState(connectionstate.TypeNonTransactional, nil), + } + + // Test ColumnTypeScanType + if g, w := r.ColumnTypeScanType(0), reflect.TypeOf(int64(0)); g != w { + t.Errorf("ColumnTypeScanType(0) mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := r.ColumnTypeScanType(1), reflect.TypeOf(""); g != w { + t.Errorf("ColumnTypeScanType(1) mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := r.ColumnTypeScanType(2), reflect.TypeOf(big.Rat{}); g != w { + t.Errorf("ColumnTypeScanType(2) mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := r.ColumnTypeScanType(3), reflect.TypeOf([]spanner.NullString{}); g != w { + t.Errorf("ColumnTypeScanType(3) mismatch\n Got: %v\nWant: %v", g, w) + } + + // Test ColumnTypeDatabaseTypeName (GoogleSQL) + if g, w := r.ColumnTypeDatabaseTypeName(0), "INT64"; g != w { + t.Errorf("ColumnTypeDatabaseTypeName(0) mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := r.ColumnTypeDatabaseTypeName(1), "STRING"; g != w { + t.Errorf("ColumnTypeDatabaseTypeName(1) mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := r.ColumnTypeDatabaseTypeName(2), "NUMERIC"; g != w { + t.Errorf("ColumnTypeDatabaseTypeName(2) mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := r.ColumnTypeDatabaseTypeName(3), "ARRAY"; g != w { + t.Errorf("ColumnTypeDatabaseTypeName(3) mismatch\n Got: %v\nWant: %v", g, w) + } + + // Test ColumnTypeDatabaseTypeName (PostgreSQL) + pgState := createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{ + propertyDatabaseDialect.Key(): propertyDatabaseDialect.CreateTypedInitialValue(databasepb.DatabaseDialect_POSTGRESQL), + }) + rpg := rows{ + it: it, + colTypes: r.colTypes, + state: pgState, + } + if g, w := rpg.ColumnTypeDatabaseTypeName(0), "bigint"; g != w { + t.Errorf("ColumnTypeDatabaseTypeName(0) PG mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := rpg.ColumnTypeDatabaseTypeName(1), "text"; g != w { + t.Errorf("ColumnTypeDatabaseTypeName(1) PG mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := rpg.ColumnTypeDatabaseTypeName(2), "numeric"; g != w { + t.Errorf("ColumnTypeDatabaseTypeName(2) PG mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := rpg.ColumnTypeDatabaseTypeName(3), "_text"; g != w { + t.Errorf("ColumnTypeDatabaseTypeName(3) PG mismatch\n Got: %v\nWant: %v", g, w) + } + + // Test ColumnTypePrecisionScale + p, s, ok := r.ColumnTypePrecisionScale(2) + if !ok || p != 38 || s != 9 { + t.Errorf("ColumnTypePrecisionScale(2) mismatch\n Got: %v, %v, %v\nWant: 38, 9, true", p, s, ok) + } + p, s, ok = r.ColumnTypePrecisionScale(0) + if ok || p != 0 || s != 0 { + t.Errorf("ColumnTypePrecisionScale(0) mismatch\n Got: %v, %v, %v\nWant: 0, 0, false", p, s, ok) + } + + // Test ColumnTypeLength + l, ok := r.ColumnTypeLength(1) + if ok || l != 0 { + t.Errorf("ColumnTypeLength(1) mismatch\n Got: %v, %v\nWant: 0, false", l, ok) + } + + // Test ColumnTypeNullable + n, ok := r.ColumnTypeNullable(1) + if ok || n != false { + t.Errorf("ColumnTypeNullable(1) mismatch\n Got: %v, %v\nWant: false, false", n, ok) + } +}